diff options
| author | Donggyu | 2016-09-08 15:48:22 -0700 |
|---|---|---|
| committer | GitHub | 2016-09-08 15:48:22 -0700 |
| commit | 765b880d4a56875c1ed07f4a0e8904c74a92dc0b (patch) | |
| tree | 6f11b15ed7516bc8816ec0d45f505dd0e4014613 /src | |
| parent | 303bad7db4354429c1992233fe0bfd1e8ce7f93e (diff) | |
| parent | 864a3978cf94f336187831773dfc2c9f9ea064c8 (diff) | |
Merge pull request #283 from ucb-bar/refactor_expand_whens
Refactor Passes
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/Emitter.scala | 20 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Utils.scala | 22 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/ExpandWhens.scala | 208 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/InferTypes.scala | 161 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/InferWidths.scala | 328 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Passes.scala | 941 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/RemoveAccesses.scala | 30 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/RemoveCHIRRTL.scala | 256 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Resolves.scala | 163 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/ReplSeqMemTests.scala | 30 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/UnitTests.scala | 5 |
11 files changed, 1061 insertions, 1103 deletions
diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala index 378eac6d..b5d212e4 100644 --- a/src/main/scala/firrtl/Emitter.scala +++ b/src/main/scala/firrtl/Emitter.scala @@ -62,6 +62,26 @@ case class VRandom(width: BigInt) extends Expression { } class VerilogEmitter extends Emitter { val tab = " " + def AND(e1: WrappedExpression, e2: WrappedExpression): Expression = { + if (e1 == e2) e1.e1 + else if ((e1 == we(zero)) | (e2 == we(zero))) zero + else if (e1 == we(one)) e2.e1 + else if (e2 == we(one)) e1.e1 + else DoPrim(And, Seq(e1.e1, e2.e1), Nil, UIntType(IntWidth(1))) + } + def OR(e1: WrappedExpression, e2: WrappedExpression): Expression = { + if (e1 == e2) e1.e1 + else if ((e1 == we(one)) | (e2 == we(one))) one + else if (e1 == we(zero)) e2.e1 + else if (e2 == we(zero)) e1.e1 + else DoPrim(Or, Seq(e1.e1, e2.e1), Nil, UIntType(IntWidth(1))) + } + def NOT(e: WrappedExpression): Expression = { + if (e == we(one)) zero + else if (e == we(zero)) one + else DoPrim(Eq, Seq(e.e1, zero), Nil, UIntType(IntWidth(1))) + } + def wref(n: String, t: Type) = WRef(n, t, ExpKind(), UNKNOWNGENDER) def remove_root(ex: Expression): Expression = ex match { case ex: WSubField => ex.exp match { diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index 29c37294..572d1ccc 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -98,28 +98,6 @@ object Utils extends LazyLogging { val ix = if (i < 0) ((-1 * i) - 1) else i ceil_log2(ix + 1) + 1 } - def EQV (e1:Expression,e2:Expression) : Expression = - DoPrim(Eq, Seq(e1, e2), Nil, e1.tpe) - // TODO: these should be fixed - def AND (e1:WrappedExpression,e2:WrappedExpression) : Expression = { - if (e1 == e2) e1.e1 - else if ((e1 == we(zero)) | (e2 == we(zero))) zero - else if (e1 == we(one)) e2.e1 - else if (e2 == we(one)) e1.e1 - else DoPrim(And,Seq(e1.e1,e2.e1),Seq(),UIntType(IntWidth(1))) - } - def OR (e1:WrappedExpression,e2:WrappedExpression) : Expression = { - if (e1 == e2) e1.e1 - else if ((e1 == we(one)) | (e2 == we(one))) one - else if (e1 == we(zero)) e2.e1 - else if (e2 == we(zero)) e1.e1 - else DoPrim(Or,Seq(e1.e1,e2.e1),Seq(),UIntType(IntWidth(1))) - } - def NOT (e1:WrappedExpression) : Expression = { - if (e1 == we(one)) zero - else if (e1 == we(zero)) one - else DoPrim(Eq,Seq(e1.e1,zero),Seq(),UIntType(IntWidth(1))) - } def create_mask(dt: Type): Type = dt match { case t: VectorType => VectorType(create_mask(t.tpe),t.size) diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala index 3d26298a..7c013b51 100644 --- a/src/main/scala/firrtl/passes/ExpandWhens.scala +++ b/src/main/scala/firrtl/passes/ExpandWhens.scala @@ -34,11 +34,6 @@ import firrtl.Mappers._ import firrtl.PrimOps._ import firrtl.WrappedExpression._ -// Datastructures -import scala.collection.mutable.HashMap -import scala.collection.mutable.LinkedHashMap -import scala.collection.mutable.ArrayBuffer - import annotation.tailrec /** Expand Whens @@ -50,138 +45,129 @@ import annotation.tailrec */ object ExpandWhens extends Pass { def name = "Expand Whens" + type NodeMap = collection.mutable.HashMap[MemoizedHash[Expression], String] + type Netlist = collection.mutable.LinkedHashMap[WrappedExpression, Expression] + type Simlist = collection.mutable.ArrayBuffer[Statement] + type Defaults = Seq[collection.mutable.Map[WrappedExpression, Expression]] // ========== Expand When Utilz ========== - private def getEntries( - hash: LinkedHashMap[WrappedExpression, Expression], - exps: Seq[Expression]): LinkedHashMap[WrappedExpression, Expression] = { - val hashx = LinkedHashMap[WrappedExpression, Expression]() - exps foreach (e => if (hash.contains(e)) hashx(e) = hash(e)) - hashx - } private def getFemaleRefs(n: String, t: Type, g: Gender): Seq[Expression] = { def getGender(t: Type, i: Int, g: Gender): Gender = times(g, get_flip(t, i, Default)) val exps = create_exps(WRef(n, t, ExpKind(), g)) - val expsx = ArrayBuffer[Expression]() - for (j <- 0 until exps.size) { - getGender(t, j, g) match { - case (BIGENDER | FEMALE) => expsx += exps(j) - case _ => + (exps.zipWithIndex foldLeft Seq[Expression]()){ + case (expsx, (exp, j)) => getGender(t, j, g) match { + case (BIGENDER | FEMALE) => expsx :+ exp + case _ => expsx } } - expsx } - private def expandNetlist(netlist: LinkedHashMap[WrappedExpression, Expression]) = - netlist map { case (k, v) => - v match { - case WInvalid() => IsInvalid(NoInfo, k.e1) - case _ => Connect(NoInfo, k.e1, v) - } + private def expandNetlist(netlist: Netlist) = + netlist map { + case (k, WInvalid()) => IsInvalid(NoInfo, k.e1) + case (k, v) => Connect(NoInfo, k.e1, v) } // Searches nested scopes of defaults for lvalue // defaults uses mutable Map because we are searching LinkedHashMaps and conversion to immutable is VERY slow @tailrec - private def getDefault( - lvalue: WrappedExpression, - defaults: Seq[collection.mutable.Map[WrappedExpression, Expression]]): Option[Expression] = { - if (defaults.isEmpty) None - else if (defaults.head.contains(lvalue)) defaults.head.get(lvalue) - else getDefault(lvalue, defaults.tail) + private def getDefault(lvalue: WrappedExpression, defaults: Defaults): Option[Expression] = { + defaults match { + case Nil => None + case head :: tail => head get lvalue match { + case Some(p) => Some(p) + case None => getDefault(lvalue, tail) + } + } } + private def AND(e1: Expression, e2: Expression) = + DoPrim(And, Seq(e1, e2), Nil, UIntType(IntWidth(1))) + private def NOT(e: Expression) = + DoPrim(Eq, Seq(e, zero), Nil, UIntType(IntWidth(1))) + // ------------ Pass ------------------- def run(c: Circuit): Circuit = { - def expandWhens(m: Module): (LinkedHashMap[WrappedExpression, Expression], ArrayBuffer[Statement], Statement) = { + def expandWhens(m: Module): (Netlist, Simlist, Statement) = { val namespace = Namespace(m) - val simlist = ArrayBuffer[Statement]() + val simlist = new Simlist + val nodes = new NodeMap // defaults ideally would be immutable.Map but conversion from mutable.LinkedHashMap to mutable.Map is VERY slow - def expandWhens( - netlist: LinkedHashMap[WrappedExpression, Expression], - defaults: Seq[collection.mutable.Map[WrappedExpression, Expression]], - p: Expression) - (s: Statement): Statement = { - s match { - case w: DefWire => - getFemaleRefs(w.name, w.tpe, BIGENDER) foreach (ref => netlist(ref) = WVoid()) - w - case r: DefRegister => - getFemaleRefs(r.name, r.tpe, BIGENDER) foreach (ref => netlist(ref) = ref) - r - case c: Connect => - netlist(c.loc) = c.expr - EmptyStmt - case c: IsInvalid => - netlist(c.expr) = WInvalid() - EmptyStmt - case s: Conditionally => - val memos = ArrayBuffer[Statement]() - - val conseqNetlist = LinkedHashMap[WrappedExpression, Expression]() - val altNetlist = LinkedHashMap[WrappedExpression, Expression]() - val conseqStmt = expandWhens(conseqNetlist, netlist +: defaults, AND(p, s.pred))(s.conseq) - val altStmt = expandWhens(altNetlist, netlist +: defaults, AND(p, NOT(s.pred)))(s.alt) + def expandWhens(netlist: Netlist, + defaults: Defaults, + p: Expression) + (s: Statement): Statement = s match { + case w: DefWire => + netlist ++= (getFemaleRefs(w.name, w.tpe, BIGENDER) map (ref => we(ref) -> WVoid())) + w + case r: DefRegister => + netlist ++= (getFemaleRefs(r.name, r.tpe, BIGENDER) map (ref => we(ref) -> ref)) + r + case c: Connect => + netlist(c.loc) = c.expr + EmptyStmt + case c: IsInvalid => + netlist(c.expr) = WInvalid() + EmptyStmt + case s: Conditionally => + val conseqNetlist = new Netlist + val altNetlist = new Netlist + val conseqStmt = expandWhens(conseqNetlist, netlist +: defaults, AND(p, s.pred))(s.conseq) + val altStmt = expandWhens(altNetlist, netlist +: defaults, AND(p, NOT(s.pred)))(s.alt) - (conseqNetlist.keySet ++ altNetlist.keySet) foreach { lvalue => - // Defaults in netlist get priority over those in defaults - val default = if (netlist.contains(lvalue)) netlist.get(lvalue) else getDefault(lvalue, defaults) - val res = default match { - case Some(defaultValue) => - val trueValue = conseqNetlist.getOrElse(lvalue, defaultValue) - val falseValue = altNetlist.getOrElse(lvalue, defaultValue) - (trueValue, falseValue) match { - case (WInvalid(), WInvalid()) => WInvalid() - case (WInvalid(), fv) => ValidIf(NOT(s.pred), fv, fv.tpe) - case (tv, WInvalid()) => ValidIf(s.pred, tv, tv.tpe) - case (tv, fv) => Mux(s.pred, tv, fv, mux_type_and_widths(tv, fv)) - } - case None => - // Since not in netlist, lvalue must be declared in EXACTLY one of conseq or alt - conseqNetlist.getOrElse(lvalue, altNetlist(lvalue)) - } - - val memoNode = DefNode(s.info, namespace.newTemp, res) - val memoExpr = WRef(memoNode.name, res.tpe, NodeKind(), MALE) - memos += memoNode - netlist(lvalue) = memoExpr + val memos = (conseqNetlist.keys ++ altNetlist.keys) map { lvalue => + // Defaults in netlist get priority over those in defaults + val default = netlist get lvalue match { + case Some(v) => Some(v) + case None => getDefault(lvalue, defaults) } - Block(Seq(conseqStmt, altStmt) ++ memos) - - case s: Print => - if(weq(p, one)) { - simlist += s - } else { - simlist += Print(s.info, s.string, s.args, s.clk, AND(p, s.en)) + val res = default match { + case Some(defaultValue) => + val trueValue = conseqNetlist getOrElse (lvalue, defaultValue) + val falseValue = altNetlist getOrElse (lvalue, defaultValue) + (trueValue, falseValue) match { + case (WInvalid(), WInvalid()) => WInvalid() + case (WInvalid(), fv) => ValidIf(NOT(s.pred), fv, fv.tpe) + case (tv, WInvalid()) => ValidIf(s.pred, tv, tv.tpe) + case (tv, fv) => Mux(s.pred, tv, fv, mux_type_and_widths(tv, fv)) + } + case None => + // Since not in netlist, lvalue must be declared in EXACTLY one of conseq or alt + conseqNetlist getOrElse (lvalue, altNetlist(lvalue)) } - EmptyStmt - case s: Stop => - if (weq(p, one)) { - simlist += s - } else { - simlist += Stop(s.info, s.ret, s.clk, AND(p, s.en)) + + nodes get res match { + case Some(name) => + netlist(lvalue) = WRef(name, res.tpe, NodeKind(), MALE) + EmptyStmt + case None => + val name = namespace.newTemp + nodes(res) = name + netlist(lvalue) = WRef(name, res.tpe, NodeKind(), MALE) + DefNode(s.info, name, res) } - EmptyStmt - case s => s map expandWhens(netlist, defaults, p) - } + } + Block(Seq(conseqStmt, altStmt) ++ memos) + case s: Print => + simlist += (if (weq(p, one)) s else Print(s.info, s.string, s.args, s.clk, AND(p, s.en))) + EmptyStmt + case s: Stop => + simlist += (if (weq(p, one)) s else Stop(s.info, s.ret, s.clk, AND(p, s.en))) + EmptyStmt + case s => s map expandWhens(netlist, defaults, p) } - val netlist = LinkedHashMap[WrappedExpression, Expression]() - + val netlist = new Netlist // Add ports to netlist - m.ports foreach { port => - getFemaleRefs(port.name, port.tpe, to_gender(port.direction)) foreach (ref => netlist(ref) = WVoid()) - } - val bodyx = expandWhens(netlist, Seq(netlist), one)(m.body) - - (netlist, simlist, bodyx) + netlist ++= (m.ports flatMap { case Port(_, name, dir, tpe) => + getFemaleRefs(name, tpe, to_gender(dir)) map (ref => we(ref) -> WVoid()) + }) + (netlist, simlist, expandWhens(netlist, Seq(netlist), one)(m.body)) } - val modulesx = c.modules map { m => - m match { - case m: ExtModule => m - case m: Module => - val (netlist, simlist, bodyx) = expandWhens(m) - val newBody = Block(Seq(squashEmpty(bodyx)) ++ expandNetlist(netlist) ++ simlist) - Module(m.info, m.name, m.ports, newBody) - } + val modulesx = c.modules map { + case m: ExtModule => m + case m: Module => + val (netlist, simlist, bodyx) = expandWhens(m) + val newBody = Block(Seq(squashEmpty(bodyx)) ++ expandNetlist(netlist) ++ simlist) + Module(m.info, m.name, m.ports, newBody) } Circuit(c.info, modulesx, c.main) } diff --git a/src/main/scala/firrtl/passes/InferTypes.scala b/src/main/scala/firrtl/passes/InferTypes.scala new file mode 100644 index 00000000..b36298e8 --- /dev/null +++ b/src/main/scala/firrtl/passes/InferTypes.scala @@ -0,0 +1,161 @@ +/* +Copyright (c) 2014 - 2016 The Regents of the University of +California (Regents). All Rights Reserved. Redistribution and use in +source and binary forms, with or without modification, are permitted +provided that the following conditions are met: + * Redistributions of source code must retain the above + copyright notice, this list of conditions and the following + two paragraphs of disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + two paragraphs of disclaimer in the documentation and/or other materials + provided with the distribution. + * Neither the name of the Regents nor the names of its contributors + may be used to endorse or promote products derived from this + software without specific prior written permission. +IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, +SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, +ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF +REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF +ANY, PROVIDED HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION +TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR +MODIFICATIONS. +*/ + +package firrtl.passes + +import firrtl._ +import firrtl.ir._ +import firrtl.Utils._ +import firrtl.Mappers._ + +object InferTypes extends Pass { + def name = "Infer Types" + type TypeMap = collection.mutable.LinkedHashMap[String, Type] + + def run(c: Circuit): Circuit = { + val namespace = Namespace() + val mtypes = (c.modules map (m => m.name -> module_type(m))).toMap + + def remove_unknowns_w(w: Width): Width = w match { + case UnknownWidth => VarWidth(namespace.newName("w")) + case w => w + } + + def remove_unknowns(t: Type): Type = + t map remove_unknowns map remove_unknowns_w + + def infer_types_e(types: TypeMap)(e: Expression): Expression = + e map infer_types_e(types) match { + case e: WRef => e copy (tpe = types(e.name)) + case e: WSubField => e copy (tpe = field_type(e.exp.tpe, e.name)) + case e: WSubIndex => e copy (tpe = sub_type(e.exp.tpe)) + case e: WSubAccess => e copy (tpe = sub_type(e.exp.tpe)) + case e: DoPrim => PrimOps.set_primop_type(e) + case e: Mux => e copy (tpe = mux_type_and_widths(e.tval, e.fval)) + case e: ValidIf => e copy (tpe = e.value.tpe) + case e @ (_: UIntLiteral | _: SIntLiteral) => e + } + + def infer_types_s(types: TypeMap)(s: Statement): Statement = s match { + case s: WDefInstance => + val t = mtypes(s.module) + types(s.name) = t + s copy (tpe = t) + case s: DefWire => + val t = remove_unknowns(get_type(s)) + types(s.name) = t + s copy (tpe = t) + case s: DefNode => + val sx = s map infer_types_e(types) + val t = remove_unknowns(get_type(sx)) + types(s.name) = t + sx map infer_types_e(types) + case s: DefRegister => + val t = remove_unknowns(get_type(s)) + types(s.name) = t + s copy (tpe = t) map infer_types_e(types) + case s: DefMemory => + val t = remove_unknowns(get_type(s)) + types(s.name) = t + s copy (dataType = remove_unknowns(s.dataType)) + case s => s map infer_types_s(types) map infer_types_e(types) + } + + def infer_types_p(types: TypeMap)(p: Port): Port = { + val t = remove_unknowns(p.tpe) + types(p.name) = t + p copy (tpe = t) + } + + def infer_types(m: DefModule): DefModule = { + val types = new TypeMap + m map infer_types_p(types) map infer_types_s(types) + } + + c copy (modules = (c.modules map infer_types)) + } +} + +object CInferTypes extends Pass { + def name = "CInfer Types" + type TypeMap = collection.mutable.LinkedHashMap[String, Type] + + def run(c: Circuit): Circuit = { + val namespace = Namespace() + val mtypes = (c.modules map (m => m.name -> module_type(m))).toMap + + def infer_types_e(types: TypeMap)(e: Expression) : Expression = + e map infer_types_e(types) match { + case (e: Reference) => e copy (tpe = (types getOrElse (e.name, UnknownType))) + case (e: SubField) => e copy (tpe = field_type(e.expr.tpe, e.name)) + case (e: SubIndex) => e copy (tpe = sub_type(e.expr.tpe)) + case (e: SubAccess) => e copy (tpe = sub_type(e.expr.tpe)) + case (e: DoPrim) => PrimOps.set_primop_type(e) + case (e: Mux) => e copy (tpe = mux_type(e.tval,e.tval)) + case (e: ValidIf) => e copy (tpe = e.value.tpe) + case e @ (_: UIntLiteral | _: SIntLiteral) => e + } + + def infer_types_s(types: TypeMap)(s: Statement): Statement = s match { + case (s: DefRegister) => + types(s.name) = s.tpe + s map infer_types_e(types) + case (s: DefWire) => + types(s.name) = s.tpe + s + case (s: DefNode) => + types(s.name) = get_type(s) + s + case (s: DefMemory) => + types(s.name) = get_type(s) + s + case (s: CDefMPort) => + val t = types getOrElse(s.mem, UnknownType) + types(s.name) = t + s copy (tpe = t) + case (s: CDefMemory) => + types(s.name) = s.tpe + s + case (s: DefInstance) => + types(s.name) = mtypes(s.module) + s + case (s) => s map infer_types_s(types) map infer_types_e(types) + } + + def infer_types_p(types: TypeMap)(p: Port): Port = { + types(p.name) = p.tpe + p + } + + def infer_types(m: DefModule): DefModule = { + val types = new TypeMap + m map infer_types_p(types) map infer_types_s(types) + } + + c copy (modules = (c.modules map infer_types)) + } +} diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala new file mode 100644 index 00000000..5a81c268 --- /dev/null +++ b/src/main/scala/firrtl/passes/InferWidths.scala @@ -0,0 +1,328 @@ +/* +Copyright (c) 2014 - 2016 The Regents of the University of +California (Regents). All Rights Reserved. Redistribution and use in +source and binary forms, with or without modification, are permitted +provided that the following conditions are met: + * Redistributions of source code must retain the above + copyright notice, this list of conditions and the following + two paragraphs of disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + two paragraphs of disclaimer in the documentation and/or other materials + provided with the distribution. + * Neither the name of the Regents nor the names of its contributors + may be used to endorse or promote products derived from this + software without specific prior written permission. +IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, +SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, +ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF +REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF +ANY, PROVIDED HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION +TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR +MODIFICATIONS. +*/ + +package firrtl.passes + +// Datastructures +import scala.collection.mutable.{LinkedHashMap, HashMap, HashSet, ArrayBuffer} +import scala.collection.immutable.ListMap + +import firrtl._ +import firrtl.ir._ +import firrtl.Utils._ +import firrtl.Mappers._ +import firrtl.PrimOps._ +import firrtl.WrappedExpression._ + +object InferWidths extends Pass { + def name = "Infer Widths" + + def solve_constraints(l: Seq[WGeq]): LinkedHashMap[String, Width] = { + def unique(ls: Seq[Width]) : Seq[Width] = + (ls map (new WrappedWidth(_))).distinct map (_.w) + def make_unique(ls: Seq[WGeq]): ListMap[String,Width] = { + (ls foldLeft ListMap[String, Width]())((h, g) => g.loc match { + case w: VarWidth => h get w.name match { + case None => h + (w.name -> g.exp) + case Some(p) => h + (w.name -> MaxWidth(Seq(g.exp, p))) + } + case _ => h + }) + } + def simplify(w: Width): Width = w map simplify match { + case (w: MinWidth) => MinWidth(unique((w.args foldLeft Seq[Width]()){ + case (res, w: MinWidth) => res ++ w.args + case (res, w) => res :+ w + })) + case (w: MaxWidth) => MaxWidth(unique((w.args foldLeft Seq[Width]()){ + case (res, w: MaxWidth) => res ++ w.args + case (res, w) => res :+ w + })) + case (w: PlusWidth) => (w.arg1, w.arg2) match { + case (w1: IntWidth, w2 :IntWidth) => IntWidth(w1.width + w2.width) + case _ => w + } + case (w: MinusWidth) => (w.arg1, w.arg2) match { + case (w1: IntWidth, w2: IntWidth) => IntWidth(w1.width - w2.width) + case _ => w + } + case (w: ExpWidth) => w.arg1 match { + case (w1: IntWidth) => IntWidth(BigInt((math.pow(2, w1.width.toDouble) - 1).toLong)) + case (w1) => w + } + case _ => w + } + + def substitute(h: LinkedHashMap[String, Width])(w: Width): Width = { + //;println-all-debug(["Substituting for [" w "]"]) + val wx = simplify(w) + //;println-all-debug(["After Simplify: [" wx "]"]) + (wx map substitute(h)) match { + //;("matched println-debugvarwidth!") + case w: VarWidth => h get w.name match { + case None => w + case Some(p) => + //;println-debug("Contained!") + //;println-all-debug(["Width: " w]) + //;println-all-debug(["Accessed: " h[name(w)]]) + val t = simplify(substitute(h)(p)) + h(w.name) = t + t + } + case w => w + //;println-all-debug(["not varwidth!" w]) + } + } + + def b_sub(h: LinkedHashMap[String, Width])(w: Width): Width = { + w map b_sub(h) match { + case w: VarWidth => h getOrElse (w.name, w) + case w => w + } + } + + def remove_cycle(n: String)(w: Width): Width = { + //;println-all-debug(["Removing cycle for " n " inside " w]) + (w map remove_cycle(n)) match { + case w: MaxWidth => MaxWidth(w.args filter { + case w: VarWidth => !(n equals w.name) + case w => true + }) + case w: MinusWidth => w.arg1 match { + case v: VarWidth if n == v.name => v + case v => w + } + case w => w + } + //;println-all-debug(["After removing cycle for " n ", returning " wx]) + } + + def hasVarWidth(n: String)(w: Width): Boolean = { + var has = false + def rec(w: Width): Width = { + w match { + case w: VarWidth if w.name == n => has = true + case w => + } + w map rec + } + rec(w) + has + } + + //; Forward solve + //; Returns a solved list where each constraint undergoes: + //; 1) Continuous Solving (using triangular solving) + //; 2) Remove Cycles + //; 3) Move to solved if not self-recursive + val u = make_unique(l) + + //println("======== UNIQUE CONSTRAINTS ========") + //for (x <- u) { println(x) } + //println("====================================") + + val f = LinkedHashMap[String, Width]() + val o = ArrayBuffer[String]() + for ((n, e) <- u) { + //println("==== SOLUTIONS TABLE ====") + //for (x <- f) println(x) + //println("=========================") + + val e_sub = substitute(f)(e) + + //println("Solving " + n + " => " + e) + //println("After Substitute: " + n + " => " + e_sub) + //println("==== SOLUTIONS TABLE (Post Substitute) ====") + //for (x <- f) println(x) + //println("=========================") + + val ex = remove_cycle(n)(e_sub) + + //println("After Remove Cycle: " + n + " => " + ex) + if (!hasVarWidth(n)(ex)) { + //println("Not rec!: " + n + " => " + ex) + //println("Adding [" + n + "=>" + ex + "] to Solutions Table") + f(n) = ex + o += n + } + } + + //println("Forward Solved Constraints") + //for (x <- f) println(x) + + //; Backwards Solve + val b = LinkedHashMap[String, Width]() + for (i <- (o.size - 1) to 0 by -1) { + val n = o(i) // Should visit `o` backward + /* + println("SOLVE BACK: [" + n + " => " + f(n) + "]") + println("==== SOLUTIONS TABLE ====") + for (x <- b) println(x) + println("=========================") + */ + val ex = simplify(b_sub(b)(f(n))) + /* + println("BACK RETURN: [" + n + " => " + ex + "]") + */ + b(n) = ex + /* + println("==== SOLUTIONS TABLE (Post backsolve) ====") + for (x <- b) println(x) + println("=========================") + */ + } + b + } + + def run (c: Circuit): Circuit = { + val v = ArrayBuffer[WGeq]() + + def get_constraints_t(t1: Type, t2: Type, f: Orientation): Seq[WGeq] = (t1,t2) match { + case (t1: UIntType, t2: UIntType) => Seq(WGeq(t1.width, t2.width)) + case (t1: SIntType, t2: SIntType) => Seq(WGeq(t1.width, t2.width)) + case (t1: BundleType, t2: BundleType) => + (t1.fields zip t2.fields foldLeft Seq[WGeq]()){case (res, (f1, f2)) => + res ++ get_constraints_t(f1.tpe, f2.tpe, times(f1.flip, f)) + } + case (t1: VectorType, t2: VectorType) => get_constraints_t(t1.tpe, t2.tpe, f) + } + + def get_constraints_e(e: Expression): Expression = { + e match { + case (e: Mux) => v ++= Seq( + WGeq(width_BANG(e.cond), IntWidth(1)), + WGeq(IntWidth(1), width_BANG(e.cond)) + ) + case _ => + } + e map get_constraints_e + } + + def get_constraints_s(s: Statement): Statement = { + s match { + case (s: Connect) => + val n = get_size(s.loc.tpe) + val locs = create_exps(s.loc) + val exps = create_exps(s.expr) + v ++= ((locs zip exps).zipWithIndex map {case ((locx, expx), i) => + get_flip(s.loc.tpe, i, Default) match { + case Default => WGeq(width_BANG(locx), width_BANG(expx)) + case Flip => WGeq(width_BANG(expx), width_BANG(locx)) + } + }) + case (s: PartialConnect) => + val ls = get_valid_points(s.loc.tpe, s.expr.tpe, Default, Default) + val locs = create_exps(s.loc) + val exps = create_exps(s.expr) + v ++= (ls map {case (x, y) => + val locx = locs(x) + val expx = exps(y) + get_flip(s.loc.tpe, x, Default) match { + case Default => WGeq(width_BANG(locx), width_BANG(expx)) + case Flip => WGeq(width_BANG(expx), width_BANG(locx)) + } + }) + case (s:DefRegister) => v ++= (Seq( + WGeq(width_BANG(s.reset), IntWidth(1)), + WGeq(IntWidth(1), width_BANG(s.reset)) + ) ++ get_constraints_t(s.tpe, s.init.tpe, Default)) + case (s:Conditionally) => v ++= Seq( + WGeq(width_BANG(s.pred), IntWidth(1)), + WGeq(IntWidth(1), width_BANG(s.pred)) + ) + case _ => + } + s map get_constraints_e map get_constraints_s + } + + c.modules foreach (_ map get_constraints_s) + + //println-debug("======== ALL CONSTRAINTS ========") + //for x in v do : println-debug(x) + //println-debug("=================================") + val h = solve_constraints(v) + //println-debug("======== SOLVED CONSTRAINTS ========") + //for x in h do : println-debug(x) + //println-debug("====================================") + + def evaluate(w: Width): Width = { + def map2(a: Option[BigInt], b: Option[BigInt], f: (BigInt,BigInt) => BigInt): Option[BigInt] = + for (a_num <- a; b_num <- b) yield f(a_num, b_num) + def reduceOptions(l: Seq[Option[BigInt]], f: (BigInt,BigInt) => BigInt): Option[BigInt] = + l.reduce(map2(_, _, f)) + + // This function shouldn't be necessary + // Added as protection in case a constraint accidentally uses MinWidth/MaxWidth + // without any actual Widths. This should be elevated to an earlier error + def forceNonEmpty(in: Seq[Option[BigInt]], default: Option[BigInt]): Seq[Option[BigInt]] = + if (in.isEmpty) Seq(default) + else in + + def solve(w: Width): Option[BigInt] = w match { + case (w: VarWidth) => + for{ + v <- h.get(w.name) if !v.isInstanceOf[VarWidth] + result <- solve(v) + } yield result + case (w: MaxWidth) => reduceOptions(forceNonEmpty(w.args.map(solve _), Some(BigInt(0))), max) + case (w: MinWidth) => reduceOptions(forceNonEmpty(w.args.map(solve _), None), min) + case (w: PlusWidth) => map2(solve(w.arg1), solve(w.arg2), {_ + _}) + case (w: MinusWidth) => map2(solve(w.arg1), solve(w.arg2), {_ - _}) + case (w: ExpWidth) => map2(Some(BigInt(2)), solve(w.arg1), pow_minus_one) + case (w: IntWidth) => Some(w.width) + case (w) => println(w); error("Shouldn't be here"); None; + } + + solve(w) match { + case None => w + case Some(s) => IntWidth(s) + } + } + + def reduce_var_widths_w(w: Width): Width = { + //println-all-debug(["REPLACE: " w]) + evaluate(w) + //println-all-debug(["WITH: " wx]) + } + + def reduce_var_widths_t(t: Type): Type = { + t map reduce_var_widths_t map reduce_var_widths_w + } + + def reduce_var_widths_s(s: Statement): Statement = { + s map reduce_var_widths_s map reduce_var_widths_t + } + + def reduce_var_widths_p(p: Port): Port = { + Port(p.info, p.name, p.direction, reduce_var_widths_t(p.tpe)) + } + + InferTypes.run(c.copy(modules = c.modules map (_ + map reduce_var_widths_p + map reduce_var_widths_s))) + } +} diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala index c143212e..b9808485 100644 --- a/src/main/scala/firrtl/passes/Passes.scala +++ b/src/main/scala/firrtl/passes/Passes.scala @@ -94,513 +94,6 @@ object ToWorkingIR extends Pass { } } -object ResolveKinds extends Pass { - private var mname = "" - def name = "Resolve Kinds" - def run (c:Circuit): Circuit = { - def resolve_kinds (m:DefModule, c:Circuit):DefModule = { - val kinds = LinkedHashMap[String,Kind]() - def resolve (body:Statement) = { - def resolve_expr (e:Expression):Expression = { - e match { - case e:WRef => WRef(e.name,e.tpe,kinds(e.name),e.gender) - case e => e map (resolve_expr) - } - } - def resolve_stmt (s:Statement):Statement = s map (resolve_stmt) map (resolve_expr) - resolve_stmt(body) - } - - def find (m:DefModule) = { - def find_stmt (s:Statement):Statement = { - s match { - case s:DefWire => kinds(s.name) = WireKind() - case s:DefNode => kinds(s.name) = NodeKind() - case s:DefRegister => kinds(s.name) = RegKind() - case s:WDefInstance => kinds(s.name) = InstanceKind() - case s:DefMemory => kinds(s.name) = MemKind(s.readers ++ s.writers ++ s.readwriters) - case s => false - } - s map (find_stmt) - } - m.ports.foreach { p => kinds(p.name) = PortKind() } - m match { - case m:Module => find_stmt(m.body) - case m:ExtModule => false - } - } - - mname = m.name - find(m) - m match { - case m:Module => { - val bodyx = resolve(m.body) - Module(m.info,m.name,m.ports,bodyx) - } - case m:ExtModule => ExtModule(m.info,m.name,m.ports) - } - } - val modulesx = c.modules.map(m => resolve_kinds(m,c)) - Circuit(c.info,modulesx,c.main) - } -} - -object InferTypes extends Pass { - private var mname = "" - def name = "Infer Types" - def set_type (s:Statement, t:Type) : Statement = { - s match { - case s:DefWire => DefWire(s.info,s.name,t) - case s:DefRegister => DefRegister(s.info,s.name,t,s.clock,s.reset,s.init) - case s:DefMemory => DefMemory(s.info,s.name,t,s.depth,s.writeLatency,s.readLatency,s.readers,s.writers,s.readwriters) - case s:DefNode => s - } - } - def remove_unknowns_w (w:Width)(implicit namespace: Namespace):Width = { - w match { - case UnknownWidth => VarWidth(namespace.newName("w")) - case w => w - } - } - def remove_unknowns (t:Type)(implicit n: Namespace): Type = mapr(remove_unknowns_w _,t) - def run (c:Circuit): Circuit = { - val module_types = LinkedHashMap[String,Type]() - implicit val wnamespace = Namespace() - def infer_types (m:DefModule) : DefModule = { - val types = LinkedHashMap[String,Type]() - def infer_types_e (e:Expression) : Expression = { - e map (infer_types_e) match { - case e:ValidIf => ValidIf(e.cond,e.value,e.value.tpe) - case e:WRef => WRef(e.name, types(e.name),e.kind,e.gender) - case e:WSubField => WSubField(e.exp,e.name,field_type(e.exp.tpe,e.name),e.gender) - case e:WSubIndex => WSubIndex(e.exp,e.value,sub_type(e.exp.tpe),e.gender) - case e:WSubAccess => WSubAccess(e.exp,e.index,sub_type(e.exp.tpe),e.gender) - case e:DoPrim => set_primop_type(e) - case e:Mux => Mux(e.cond,e.tval,e.fval,mux_type_and_widths(e.tval,e.fval)) - case e:UIntLiteral => e - case e:SIntLiteral => e - } - } - def infer_types_s (s:Statement) : Statement = { - s match { - case s:DefRegister => { - val t = remove_unknowns(get_type(s)) - types(s.name) = t - set_type(s,t) map (infer_types_e) - } - case s:DefWire => { - val sx = s map(infer_types_e) - val t = remove_unknowns(get_type(sx)) - types(s.name) = t - set_type(sx,t) - } - case s:DefNode => { - val sx = s map (infer_types_e) - val t = remove_unknowns(get_type(sx)) - types(s.name) = t - set_type(sx,t) - } - case s:DefMemory => { - val t = remove_unknowns(get_type(s)) - types(s.name) = t - val dt = remove_unknowns(s.dataType) - set_type(s,dt) - } - case s:WDefInstance => { - types(s.name) = module_types(s.module) - WDefInstance(s.info,s.name,s.module,module_types(s.module)) - } - case s => s map (infer_types_s) map (infer_types_e) - } - } - - mname = m.name - m.ports.foreach(p => types(p.name) = p.tpe) - m match { - case m:Module => Module(m.info,m.name,m.ports,infer_types_s(m.body)) - case m:ExtModule => m - } - } - - val modulesx = c.modules.map { - m => { - mname = m.name - val portsx = m.ports.map(p => Port(p.info,p.name,p.direction,remove_unknowns(p.tpe))) - m match { - case m:Module => Module(m.info,m.name,portsx,m.body) - case m:ExtModule => ExtModule(m.info,m.name,portsx) - } - } - } - modulesx.foreach(m => module_types(m.name) = module_type(m)) - Circuit(c.info,modulesx.map({m => mname = m.name; infer_types(m)}) , c.main ) - } -} - -object ResolveGenders extends Pass { - private var mname = "" - def name = "Resolve Genders" - def run (c:Circuit): Circuit = { - def resolve_e (g:Gender)(e:Expression) : Expression = { - e match { - case e:WRef => WRef(e.name,e.tpe,e.kind,g) - case e:WSubField => { - val expx = - field_flip(e.exp.tpe,e.name) match { - case Default => resolve_e(g)(e.exp) - case Flip => resolve_e(swap(g))(e.exp) - } - WSubField(expx,e.name,e.tpe,g) - } - case e:WSubIndex => { - val expx = resolve_e(g)(e.exp) - WSubIndex(expx,e.value,e.tpe,g) - } - case e:WSubAccess => { - val expx = resolve_e(g)(e.exp) - val indexx = resolve_e(MALE)(e.index) - WSubAccess(expx,indexx,e.tpe,g) - } - case e => e map (resolve_e(g)) - } - } - - def resolve_s (s:Statement) : Statement = { - s match { - case s:IsInvalid => { - val expx = resolve_e(FEMALE)(s.expr) - IsInvalid(s.info,expx) - } - case s:Connect => { - val locx = resolve_e(FEMALE)(s.loc) - val expx = resolve_e(MALE)(s.expr) - Connect(s.info,locx,expx) - } - case s:PartialConnect => { - val locx = resolve_e(FEMALE)(s.loc) - val expx = resolve_e(MALE)(s.expr) - PartialConnect(s.info,locx,expx) - } - case s => s map (resolve_e(MALE)) map (resolve_s) - } - } - val modulesx = c.modules.map { - m => { - mname = m.name - m match { - case m:Module => { - val bodyx = resolve_s(m.body) - Module(m.info,m.name,m.ports,bodyx) - } - case m:ExtModule => m - } - } - } - Circuit(c.info,modulesx,c.main) - } -} - -object InferWidths extends Pass { - def name = "Infer Widths" - var mname = "" - def solve_constraints (l:Seq[WGeq]) : LinkedHashMap[String,Width] = { - def unique (ls:Seq[Width]) : Seq[Width] = ls.map(w => new WrappedWidth(w)).distinct.map(_.w) - def make_unique (ls:Seq[WGeq]) : LinkedHashMap[String,Width] = { - val h = LinkedHashMap[String,Width]() - for (g <- ls) { - (g.loc) match { - case (w:VarWidth) => { - val n = w.name - if (h.contains(n)) h(n) = MaxWidth(Seq(g.exp,h(n))) else h(n) = g.exp - } - case (w) => w - } - } - h - } - def simplify (w:Width) : Width = { - (w map (simplify)) match { - case (w:MinWidth) => { - val v = ArrayBuffer[Width]() - for (wx <- w.args) { - (wx) match { - case (wx:MinWidth) => for (x <- wx.args) { v += x } - case (wx) => v += wx } } - MinWidth(unique(v)) } - case (w:MaxWidth) => { - val v = ArrayBuffer[Width]() - for (wx <- w.args) { - (wx) match { - case (wx:MaxWidth) => for (x <- wx.args) { v += x } - case (wx) => v += wx } } - MaxWidth(unique(v)) } - case (w:PlusWidth) => { - (w.arg1,w.arg2) match { - case (w1:IntWidth,w2:IntWidth) => IntWidth(w1.width + w2.width) - case (w1,w2) => w }} - case (w:MinusWidth) => { - (w.arg1,w.arg2) match { - case (w1:IntWidth,w2:IntWidth) => IntWidth(w1.width - w2.width) - case (w1,w2) => w }} - case (w:ExpWidth) => { - (w.arg1) match { - case (w1:IntWidth) => IntWidth(BigInt((scala.math.pow(2,w1.width.toDouble) - 1).toLong)) - case (w1) => w }} - case (w) => w } } - def substitute (h:LinkedHashMap[String,Width])(w:Width) : Width = { - //;println-all-debug(["Substituting for [" w "]"]) - val wx = simplify(w) - //;println-all-debug(["After Simplify: [" wx "]"]) - (simplify(w) map (substitute(h))) match { - case (w:VarWidth) => { - //;("matched println-debugvarwidth!") - if (h.contains(w.name)) { - //;println-debug("Contained!") - //;println-all-debug(["Width: " w]) - //;println-all-debug(["Accessed: " h[name(w)]]) - val t = simplify(substitute(h)(h(w.name))) - //;val t = h[name(w)] - //;println-all-debug(["Width after sub: " t]) - h(w.name) = t - t - } else w - } - case (w) => w - //;println-all-debug(["not varwidth!" w]) - } - } - def b_sub (h:LinkedHashMap[String,Width])(w:Width) : Width = { - (w map (b_sub(h))) match { - case (w:VarWidth) => if (h.contains(w.name)) h(w.name) else w - case (w) => w - } - } - def remove_cycle (n:String)(w:Width) : Width = { - //;println-all-debug(["Removing cycle for " n " inside " w]) - val wx = (w map (remove_cycle(n))) match { - case (w:MaxWidth) => MaxWidth(w.args.filter{ w => { - w match { - case (w:VarWidth) => !(n equals w.name) - case (w) => true - }}}) - case (w:MinusWidth) => { - w.arg1 match { - case (v:VarWidth) => if (n == v.name) v else w - case (v) => w }} - case (w) => w - } - //;println-all-debug(["After removing cycle for " n ", returning " wx]) - wx - } - def self_rec (n:String,w:Width) : Boolean = { - var has = false - def look (w:Width) : Width = { - (w map (look)) match { - case (w:VarWidth) => if (w.name == n) has = true - case (w) => w } - w } - look(w) - has } - - //; Forward solve - //; Returns a solved list where each constraint undergoes: - //; 1) Continuous Solving (using triangular solving) - //; 2) Remove Cycles - //; 3) Move to solved if not self-recursive - val u = make_unique(l) - - //println("======== UNIQUE CONSTRAINTS ========") - //for (x <- u) { println(x) } - //println("====================================") - - - val f = LinkedHashMap[String,Width]() - val o = ArrayBuffer[String]() - for (x <- u) { - //println("==== SOLUTIONS TABLE ====") - //for (x <- f) println(x) - //println("=========================") - - val (n, e) = (x._1, x._2) - val e_sub = substitute(f)(e) - - //println("Solving " + n + " => " + e) - //println("After Substitute: " + n + " => " + e_sub) - //println("==== SOLUTIONS TABLE (Post Substitute) ====") - //for (x <- f) println(x) - //println("=========================") - - val ex = remove_cycle(n)(e_sub) - - //println("After Remove Cycle: " + n + " => " + ex) - if (!self_rec(n,ex)) { - //println("Not rec!: " + n + " => " + ex) - //println("Adding [" + n + "=>" + ex + "] to Solutions Table") - o += n - f(n) = ex - } - } - - //println("Forward Solved Constraints") - //for (x <- f) println(x) - - //; Backwards Solve - val b = LinkedHashMap[String,Width]() - for (i <- 0 until o.size) { - val n = o(o.size - 1 - i) - /* - println("SOLVE BACK: [" + n + " => " + f(n) + "]") - println("==== SOLUTIONS TABLE ====") - for (x <- b) println(x) - println("=========================") - */ - val ex = simplify(b_sub(b)(f(n))) - /* - println("BACK RETURN: [" + n + " => " + ex + "]") - */ - b(n) = ex - /* - println("==== SOLUTIONS TABLE (Post backsolve) ====") - for (x <- b) println(x) - println("=========================") - */ - } - b - } - - def width_BANG (t:Type) : Width = { - (t) match { - case (t:UIntType) => t.width - case (t:SIntType) => t.width - case ClockType => IntWidth(1) - case (t) => error("No width!"); IntWidth(-1) } } - def width_BANG (e:Expression) : Width = width_BANG(e.tpe) - - def reduce_var_widths(c: Circuit, h: LinkedHashMap[String,Width]): Circuit = { - def evaluate(w: Width): Width = { - def map2(a: Option[BigInt], b: Option[BigInt], f: (BigInt,BigInt) => BigInt): Option[BigInt] = - for (a_num <- a; b_num <- b) yield f(a_num, b_num) - def reduceOptions(l: Seq[Option[BigInt]], f: (BigInt,BigInt) => BigInt): Option[BigInt] = - l.reduce(map2(_, _, f)) - - // This function shouldn't be necessary - // Added as protection in case a constraint accidentally uses MinWidth/MaxWidth - // without any actual Widths. This should be elevated to an earlier error - def forceNonEmpty(in: Seq[Option[BigInt]], default: Option[BigInt]): Seq[Option[BigInt]] = - if(in.isEmpty) Seq(default) - else in - - - def solve(w: Width): Option[BigInt] = w match { - case (w: VarWidth) => - for{ - v <- h.get(w.name) if !v.isInstanceOf[VarWidth] - result <- solve(v) - } yield result - case (w: MaxWidth) => reduceOptions(forceNonEmpty(w.args.map(solve _), Some(BigInt(0))), max) - case (w: MinWidth) => reduceOptions(forceNonEmpty(w.args.map(solve _), None), min) - case (w: PlusWidth) => map2(solve(w.arg1), solve(w.arg2), {_ + _}) - case (w: MinusWidth) => map2(solve(w.arg1), solve(w.arg2), {_ - _}) - case (w: ExpWidth) => map2(Some(BigInt(2)), solve(w.arg1), pow_minus_one) - case (w: IntWidth) => Some(w.width) - case (w) => println(w); error("Shouldn't be here"); None; - } - - val s = solve(w) - (s) match { - case Some(s) => IntWidth(s) - case (s) => w - } - } - - def reduce_var_widths_w (w:Width) : Width = { - //println-all-debug(["REPLACE: " w]) - val wx = evaluate(w) - //println-all-debug(["WITH: " wx]) - wx - } - def reduce_var_widths_s (s: Statement): Statement = { - def onType(t: Type): Type = t map onType map reduce_var_widths_w - s map reduce_var_widths_s map onType - } - - val modulesx = c.modules.map{ m => { - val portsx = m.ports.map{ p => { - Port(p.info,p.name,p.direction,mapr(reduce_var_widths_w _,p.tpe)) }} - (m) match { - case (m:ExtModule) => ExtModule(m.info,m.name,portsx) - case (m:Module) => - mname = m.name - Module(m.info,m.name,portsx,m.body map reduce_var_widths_s _) }}} - InferTypes.run(Circuit(c.info,modulesx,c.main)) - } - - def run (c:Circuit): Circuit = { - val v = ArrayBuffer[WGeq]() - def constrain (w1:Width,w2:Width) : Unit = v += WGeq(w1,w2) - def get_constraints_t (t1:Type,t2:Type,f:Orientation) : Unit = { - (t1,t2) match { - case (t1:UIntType,t2:UIntType) => constrain(t1.width,t2.width) - case (t1:SIntType,t2:SIntType) => constrain(t1.width,t2.width) - case (t1:BundleType,t2:BundleType) => { - (t1.fields,t2.fields).zipped.foreach{ (f1,f2) => { - get_constraints_t(f1.tpe,f2.tpe,times(f1.flip,f)) }}} - case (t1:VectorType,t2:VectorType) => get_constraints_t(t1.tpe,t2.tpe,f) }} - def get_constraints_e (e:Expression) : Expression = { - (e map (get_constraints_e)) match { - case (e:Mux) => { - constrain(width_BANG(e.cond),IntWidth(1)) - constrain(IntWidth(1),width_BANG(e.cond)) - e } - case (e) => e }} - def get_constraints (s:Statement) : Statement = { - (s map (get_constraints_e)) match { - case (s:Connect) => { - val n = get_size(s.loc.tpe) - val ce_loc = create_exps(s.loc) - val ce_exp = create_exps(s.expr) - for (i <- 0 until n) { - val locx = ce_loc(i) - val expx = ce_exp(i) - get_flip(s.loc.tpe,i,Default) match { - case Default => constrain(width_BANG(locx),width_BANG(expx)) - case Flip => constrain(width_BANG(expx),width_BANG(locx)) }} - s } - case (s:PartialConnect) => { - val ls = get_valid_points(s.loc.tpe,s.expr.tpe,Default,Default) - for (x <- ls) { - val locx = create_exps(s.loc)(x._1) - val expx = create_exps(s.expr)(x._2) - get_flip(s.loc.tpe,x._1,Default) match { - case Default => constrain(width_BANG(locx),width_BANG(expx)) - case Flip => constrain(width_BANG(expx),width_BANG(locx)) }} - s } - case (s:DefRegister) => { - constrain(width_BANG(s.reset),IntWidth(1)) - constrain(IntWidth(1),width_BANG(s.reset)) - get_constraints_t(s.tpe,s.init.tpe,Default) - s } - case (s:Conditionally) => { - v += WGeq(width_BANG(s.pred),IntWidth(1)) - v += WGeq(IntWidth(1),width_BANG(s.pred)) - s map (get_constraints) } - case (s) => s map (get_constraints) }} - - for (m <- c.modules) { - (m) match { - case (m:Module) => mname = m.name; get_constraints(m.body) - case (m) => false }} - //println-debug("======== ALL CONSTRAINTS ========") - //for x in v do : println-debug(x) - //println-debug("=================================") - val h = solve_constraints(v) - //println-debug("======== SOLVED CONSTRAINTS ========") - //for x in h do : println-debug(x) - //println-debug("====================================") - reduce_var_widths(Circuit(c.info,c.modules,c.main),h) - } -} - object PullMuxes extends Pass { def name = "Pull Muxes" def run(c: Circuit): Circuit = { @@ -837,437 +330,3 @@ object VerilogRename extends Pass { Circuit(c.info,modulesx,c.main) } } - -object CInferTypes extends Pass { - def name = "CInfer Types" - var mname = "" - def set_type (s:Statement, t:Type) : Statement = { - (s) match { - case (s:DefWire) => DefWire(s.info,s.name,t) - case (s:DefRegister) => DefRegister(s.info,s.name,t,s.clock,s.reset,s.init) - case (s:CDefMemory) => CDefMemory(s.info,s.name,t,s.size,s.seq) - case (s:CDefMPort) => CDefMPort(s.info,s.name,t,s.mem,s.exps,s.direction) - case (s:DefNode) => s - } - } - - def to_field (p:Port) : Field = { - if (p.direction == Output) Field(p.name,Default,p.tpe) - else if (p.direction == Input) Field(p.name,Flip,p.tpe) - else error("Shouldn't be here"); Field(p.name,Flip,p.tpe) - } - def module_type (m:DefModule) : Type = BundleType(m.ports.map(p => to_field(p))) - def field_type (v:Type,s:String) : Type = { - (v) match { - case (v:BundleType) => { - val ft = v.fields.find(p => p.name == s) - if (ft != None) ft.get.tpe - else UnknownType - } - case (v) => UnknownType - } - } - def sub_type (v:Type) : Type = - (v) match { - case (v:VectorType) => v.tpe - case (v) => UnknownType - } - def run (c:Circuit) : Circuit = { - val module_types = LinkedHashMap[String,Type]() - def infer_types (m:DefModule) : DefModule = { - val types = LinkedHashMap[String,Type]() - def infer_types_e (e:Expression) : Expression = { - e map infer_types_e match { - case (e:Reference) => Reference(e.name, types.getOrElse(e.name,UnknownType)) - case (e:SubField) => SubField(e.expr,e.name,field_type(e.expr.tpe,e.name)) - case (e:SubIndex) => SubIndex(e.expr,e.value,sub_type(e.expr.tpe)) - case (e:SubAccess) => SubAccess(e.expr,e.index,sub_type(e.expr.tpe)) - case (e:DoPrim) => set_primop_type(e) - case (e:Mux) => Mux(e.cond,e.tval,e.fval,mux_type(e.tval,e.tval)) - case (e:ValidIf) => ValidIf(e.cond,e.value,e.value.tpe) - case (_:UIntLiteral | _:SIntLiteral) => e - } - } - def infer_types_s (s:Statement) : Statement = { - s match { - case (s:DefRegister) => { - types(s.name) = s.tpe - s map infer_types_e - s - } - case (s:DefWire) => { - types(s.name) = s.tpe - s - } - case (s:DefNode) => { - val sx = s map infer_types_e - val t = get_type(sx) - types(s.name) = t - sx - } - case (s:DefMemory) => { - types(s.name) = get_type(s) - s - } - case (s:CDefMPort) => { - val t = types.getOrElse(s.mem,UnknownType) - types(s.name) = t - CDefMPort(s.info,s.name,t,s.mem,s.exps,s.direction) - } - case (s:CDefMemory) => { - types(s.name) = s.tpe - s - } - case (s:DefInstance) => { - types(s.name) = module_types.getOrElse(s.module,UnknownType) - s - } - case (s) => s map infer_types_s map infer_types_e - } - } - for (p <- m.ports) { - types(p.name) = p.tpe - } - m match { - case (m:Module) => Module(m.info,m.name,m.ports,infer_types_s(m.body)) - case (m:ExtModule) => m - } - } - - //; MAIN - for (m <- c.modules) { - module_types(m.name) = module_type(m) - } - val modulesx = c.modules.map(m => infer_types(m)) - Circuit(c.info, modulesx, c.main) - } -} - -object CInferMDir extends Pass { - def name = "CInfer MDir" - var mname = "" - def run (c:Circuit) : Circuit = { - def infer_mdir (m:DefModule) : DefModule = { - val mports = LinkedHashMap[String,MPortDir]() - def infer_mdir_e (dir:MPortDir)(e:Expression) : Expression = { - (e map (infer_mdir_e(dir))) match { - case (e:Reference) => { - if (mports.contains(e.name)) { - val new_mport_dir = { - (mports(e.name),dir) match { - case (MInfer,MInfer) => error("Shouldn't be here") - case (MInfer,MWrite) => MWrite - case (MInfer,MRead) => MRead - case (MInfer,MReadWrite) => MReadWrite - case (MWrite,MInfer) => error("Shouldn't be here") - case (MWrite,MWrite) => MWrite - case (MWrite,MRead) => MReadWrite - case (MWrite,MReadWrite) => MReadWrite - case (MRead,MInfer) => error("Shouldn't be here") - case (MRead,MWrite) => MReadWrite - case (MRead,MRead) => MRead - case (MRead,MReadWrite) => MReadWrite - case (MReadWrite,MInfer) => error("Shouldn't be here") - case (MReadWrite,MWrite) => MReadWrite - case (MReadWrite,MRead) => MReadWrite - case (MReadWrite,MReadWrite) => MReadWrite - } - } - mports(e.name) = new_mport_dir - } - e - } - case (e) => e - } - } - def infer_mdir_s (s:Statement) : Statement = { - (s) match { - case (s:CDefMPort) => { - mports(s.name) = s.direction - s map (infer_mdir_e(MRead)) - } - case (s:Connect) => { - infer_mdir_e(MRead)(s.expr) - infer_mdir_e(MWrite)(s.loc) - s - } - case (s:PartialConnect) => { - infer_mdir_e(MRead)(s.expr) - infer_mdir_e(MWrite)(s.loc) - s - } - case (s) => s map (infer_mdir_s) map (infer_mdir_e(MRead)) - } - } - def set_mdir_s (s:Statement) : Statement = { - (s) match { - case (s:CDefMPort) => - CDefMPort(s.info,s.name,s.tpe,s.mem,s.exps,mports(s.name)) - case (s) => s map (set_mdir_s) - } - } - (m) match { - case (m:Module) => { - infer_mdir_s(m.body) - Module(m.info,m.name,m.ports,set_mdir_s(m.body)) - } - case (m:ExtModule) => m - } - } - - //; MAIN - Circuit(c.info, c.modules.map(m => infer_mdir(m)), c.main) - } -} - -case class MPort( val name : String, val clk : Expression) -case class MPorts( val readers : ArrayBuffer[MPort], val writers : ArrayBuffer[MPort], val readwriters : ArrayBuffer[MPort]) -case class DataRef( val exp : Expression, val male : String, val female : String, val mask : String, val rdwrite : Boolean) - -object RemoveCHIRRTL extends Pass { - def name = "Remove CHIRRTL" - var mname = "" - def create_exps (e:Expression) : Seq[Expression] = e match { - case (e:Mux) => - val e1s = create_exps(e.tval) - val e2s = create_exps(e.fval) - (e1s,e2s).zipped map ((e1,e2) => Mux(e.cond,e1,e2,mux_type(e1,e2))) - case (e:ValidIf) => - create_exps(e.value) map (e1 => ValidIf(e.cond,e1,e1.tpe)) - case (e) => (e.tpe) match { - case (_:GroundType) => Seq(e) - case (t:BundleType) => (t.fields foldLeft Seq[Expression]())((exps, f) => - exps ++ create_exps(SubField(e,f.name,f.tpe))) - case (t:VectorType) => ((0 until t.size) foldLeft Seq[Expression]())((exps, i) => - exps ++ create_exps(SubIndex(e,i,t.tpe))) - case UnknownType => Seq(e) - } - } - def run (c:Circuit) : Circuit = { - def remove_chirrtl_m (m:Module) : Module = { - val hash = LinkedHashMap[String,MPorts]() - val repl = LinkedHashMap[String,DataRef]() - val raddrs = HashMap[String, Expression]() - val ut = UnknownType - val mport_types = LinkedHashMap[String,Type]() - val smems = HashSet[String]() - def EMPs () : MPorts = MPorts(ArrayBuffer[MPort](),ArrayBuffer[MPort](),ArrayBuffer[MPort]()) - def collect_smems_and_mports (s:Statement) : Statement = { - (s) match { - case (s:CDefMemory) if s.seq => - smems += s.name - s - case (s:CDefMPort) => { - val mports = hash.getOrElse(s.mem,EMPs()) - s.direction match { - case MRead => mports.readers += MPort(s.name,s.exps(1)) - case MWrite => mports.writers += MPort(s.name,s.exps(1)) - case MReadWrite => mports.readwriters += MPort(s.name,s.exps(1)) - } - hash(s.mem) = mports - s - } - case (s) => s map (collect_smems_and_mports) - } - } - def collect_refs (s:Statement) : Statement = { - (s) match { - case (s:CDefMemory) => { - mport_types(s.name) = s.tpe - val stmts = ArrayBuffer[Statement]() - val taddr = UIntType(IntWidth(scala.math.max(1,ceil_log2(s.size)))) - val tdata = s.tpe - def set_poison (vec:Seq[MPort],addr:String) : Unit = { - for (r <- vec ) { - stmts += IsInvalid(s.info,SubField(SubField(Reference(s.name,ut),r.name,ut),addr,taddr)) - stmts += IsInvalid(s.info,SubField(SubField(Reference(s.name,ut),r.name,ut),"clk",taddr)) - } - } - def set_enable (vec:Seq[MPort],en:String) : Unit = { - for (r <- vec ) { - stmts += Connect(s.info,SubField(SubField(Reference(s.name,ut),r.name,ut),en,taddr),zero) - }} - def set_wmode (vec:Seq[MPort],wmode:String) : Unit = { - for (r <- vec) { - stmts += Connect(s.info,SubField(SubField(Reference(s.name,ut),r.name,ut),wmode,taddr),zero) - }} - def set_write (vec:Seq[MPort],data:String,mask:String) : Unit = { - val tmask = create_mask(s.tpe) - for (r <- vec ) { - stmts += IsInvalid(s.info,SubField(SubField(Reference(s.name,ut),r.name,ut),data,tdata)) - for (x <- create_exps(SubField(SubField(Reference(s.name,ut),r.name,ut),mask,tmask)) ) { - stmts += Connect(s.info,x,zero) - }}} - val rds = (hash.getOrElse(s.name,EMPs())).readers - set_poison(rds,"addr") - set_enable(rds,"en") - val wrs = (hash.getOrElse(s.name,EMPs())).writers - set_poison(wrs,"addr") - set_enable(wrs,"en") - set_write(wrs,"data","mask") - val rws = (hash.getOrElse(s.name,EMPs())).readwriters - set_poison(rws,"addr") - set_wmode(rws,"wmode") - set_enable(rws,"en") - set_write(rws,"wdata","wmask") - val read_l = if (s.seq) 1 else 0 - val mem = DefMemory(s.info,s.name,s.tpe,s.size,1,read_l,rds.map(_.name),wrs.map(_.name),rws.map(_.name)) - Block(Seq(mem,Block(stmts))) - } - case (s:CDefMPort) => { - mport_types(s.name) = mport_types(s.mem) - val addrs = ArrayBuffer[String]() - val clks = ArrayBuffer[String]() - val ens = ArrayBuffer[String]() - val masks = ArrayBuffer[String]() - s.direction match { - case MReadWrite => { - repl(s.name) = DataRef(SubField(Reference(s.mem,ut),s.name,ut),"rdata","wdata","wmask",true) - addrs += "addr" - clks += "clk" - ens += "en" - masks += "wmask" - } - case MWrite => { - repl(s.name) = DataRef(SubField(Reference(s.mem,ut),s.name,ut),"data","data","mask",false) - addrs += "addr" - clks += "clk" - ens += "en" - masks += "mask" - } - case MRead => { - repl(s.name) = DataRef(SubField(Reference(s.mem,ut),s.name,ut),"data","data","blah",false) - addrs += "addr" - clks += "clk" - s.exps(0) match { - case e: Reference if smems(s.mem) => - raddrs(e.name) = SubField(SubField(Reference(s.mem,ut),s.name,ut),"en",ut) - case _ => ens += "en" - } - } - } - val stmts = ArrayBuffer[Statement]() - for (x <- addrs ) { - stmts += Connect(s.info,SubField(SubField(Reference(s.mem,ut),s.name,ut),x,ut),s.exps(0)) - } - for (x <- clks ) { - stmts += Connect(s.info,SubField(SubField(Reference(s.mem,ut),s.name,ut),x,ut),s.exps(1)) - } - for (x <- ens ) { - stmts += Connect(s.info,SubField(SubField(Reference(s.mem,ut),s.name,ut),x,ut),one) - } - Block(stmts) - } - case (s) => s map (collect_refs) - } - } - def remove_chirrtl_s (s:Statement) : Statement = { - var has_write_mport = false - var has_read_mport: Option[Expression] = None - var has_readwrite_mport: Option[Expression] = None - def remove_chirrtl_e (g:Gender)(e:Expression) : Expression = { - (e) match { - case (e:Reference) if repl contains e.name => - val vt = repl(e.name) - g match { - case MALE => SubField(vt.exp,vt.male,e.tpe) - case FEMALE => { - has_write_mport = true - if (vt.rdwrite) - has_readwrite_mport = Some(SubField(vt.exp,"wmode",UIntType(IntWidth(1)))) - SubField(vt.exp,vt.female,e.tpe) - } - } - case (e:Reference) if g == FEMALE && (raddrs contains e.name) => - has_read_mport = Some(raddrs(e.name)) - e - case (e:Reference) => e - case (e:SubAccess) => SubAccess(remove_chirrtl_e(g)(e.expr),remove_chirrtl_e(MALE)(e.index),e.tpe) - case (e) => e map (remove_chirrtl_e(g)) - } - } - def get_mask (e:Expression) : Expression = { - (e map (get_mask)) match { - case (e:Reference) => { - if (repl.contains(e.name)) { - val vt = repl(e.name) - val t = create_mask(e.tpe) - SubField(vt.exp,vt.mask,t) - } else e - } - case (e) => e - } - } - (s) match { - case (s:DefNode) => { - val stmts = ArrayBuffer[Statement]() - val valuex = remove_chirrtl_e(MALE)(s.value) - stmts += DefNode(s.info,s.name,valuex) - has_read_mport match { - case None => - case Some(en) => stmts += Connect(s.info,en,one) - } - if (stmts.size > 1) Block(stmts) - else stmts(0) - } - case (s:Connect) => { - val stmts = ArrayBuffer[Statement]() - val rocx = remove_chirrtl_e(MALE)(s.expr) - val locx = remove_chirrtl_e(FEMALE)(s.loc) - stmts += Connect(s.info,locx,rocx) - has_read_mport match { - case None => - case Some(en) => stmts += Connect(s.info,en,one) - } - if (has_write_mport) { - val e = get_mask(s.loc) - for (x <- create_exps(e) ) { - stmts += Connect(s.info,x,one) - } - has_readwrite_mport match { - case None => - case Some(wmode) => stmts += Connect(s.info,wmode,one) - } - } - if (stmts.size > 1) Block(stmts) - else stmts(0) - } - case (s:PartialConnect) => { - val stmts = ArrayBuffer[Statement]() - val locx = remove_chirrtl_e(FEMALE)(s.loc) - val rocx = remove_chirrtl_e(MALE)(s.expr) - stmts += PartialConnect(s.info,locx,rocx) - has_read_mport match { - case None => - case Some(en) => stmts += Connect(s.info,en,one) - } - if (has_write_mport) { - val ls = get_valid_points(s.loc.tpe,s.expr.tpe,Default,Default) - val locs = create_exps(get_mask(s.loc)) - for (x <- ls ) { - val locx = locs(x._1) - stmts += Connect(s.info,locx,one) - } - has_readwrite_mport match { - case None => - case Some(wmode) => stmts += Connect(s.info,wmode,one) - } - } - if (stmts.size > 1) Block(stmts) - else stmts(0) - } - case (s) => s map (remove_chirrtl_s) map (remove_chirrtl_e(MALE)) - } - } - collect_smems_and_mports(m.body) - val sx = collect_refs(m.body) - Module(m.info,m.name, m.ports, remove_chirrtl_s(sx)) - } - val modulesx = c.modules.map{ m => { - (m) match { - case (m:Module) => remove_chirrtl_m(m) - case (m:ExtModule) => m - }}} - Circuit(c.info,modulesx, c.main) - } -} diff --git a/src/main/scala/firrtl/passes/RemoveAccesses.scala b/src/main/scala/firrtl/passes/RemoveAccesses.scala index 880d6b1c..08f08eac 100644 --- a/src/main/scala/firrtl/passes/RemoveAccesses.scala +++ b/src/main/scala/firrtl/passes/RemoveAccesses.scala @@ -1,11 +1,11 @@ package firrtl.passes +import firrtl.{WRef, WSubAccess, WSubIndex, WSubField, Namespace} +import firrtl.PrimOps.{And, Eq} import firrtl.ir._ -import firrtl.{WRef, WSubAccess, WSubIndex, WSubField} import firrtl.Mappers._ import firrtl.Utils._ import firrtl.WrappedExpression._ -import firrtl.Namespace import scala.collection.mutable @@ -13,6 +13,13 @@ import scala.collection.mutable */ object RemoveAccesses extends Pass { def name = "Remove Accesses" + + private def AND(e1: Expression, e2: Expression) = + DoPrim(And, Seq(e1, e2), Nil, UIntType(IntWidth(1))) + + private def EQV(e1: Expression, e2: Expression): Expression = + DoPrim(Eq, Seq(e1, e2), Nil, e1.tpe) + /** Container for a base expression and its corresponding guard */ private case class Location(base: Expression, guard: Expression) @@ -53,13 +60,13 @@ object RemoveAccesses extends Pass { /** Returns true if e contains a [[firrtl.WSubAccess]] */ private def hasAccess(e: Expression): Boolean = { - var ret: Boolean = false - def rec_has_access(e: Expression): Expression = { - e match { - case e : WSubAccess => ret = true - case e => - } - e map rec_has_access + var ret: Boolean = false + def rec_has_access(e: Expression): Expression = { + e match { + case e : WSubAccess => ret = true + case e => + } + e map rec_has_access } rec_has_access(e) ret @@ -150,10 +157,9 @@ object RemoveAccesses extends Pass { Module(m.info, m.name, m.ports, squashEmpty(onStmt(m.body))) } - val newModules = c.modules.map { + c copy (modules = (c.modules map { case m: ExtModule => m case m: Module => remove_m(m) - } - Circuit(c.info, newModules, c.main) + })) } } diff --git a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala new file mode 100644 index 00000000..2bae92a7 --- /dev/null +++ b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala @@ -0,0 +1,256 @@ +/* +Copyright (c) 2014 - 2016 The Regents of the University of +California (Regents). All Rights Reserved. Redistribution and use in +source and binary forms, with or without modification, are permitted +provided that the following conditions are met: + * Redistributions of source code must retain the above + copyright notice, this list of conditions and the following + two paragraphs of disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + two paragraphs of disclaimer in the documentation and/or other materials + provided with the distribution. + * Neither the name of the Regents nor the names of its contributors + may be used to endorse or promote products derived from this + software without specific prior written permission. +IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, +SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, +ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF +REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF +ANY, PROVIDED HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION +TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR +MODIFICATIONS. +*/ + +package firrtl.passes + +// Datastructures +import scala.collection.mutable.ArrayBuffer + +import firrtl._ +import firrtl.ir._ +import firrtl.Utils._ +import firrtl.Mappers._ + +case class MPort(name: String, clk: Expression) +case class MPorts(readers: ArrayBuffer[MPort], writers: ArrayBuffer[MPort], readwriters: ArrayBuffer[MPort]) +case class DataRef(exp: Expression, male: String, female: String, mask: String, rdwrite: Boolean) + +object RemoveCHIRRTL extends Pass { + def name = "Remove CHIRRTL" + + val ut = UnknownType + type MPortMap = collection.mutable.LinkedHashMap[String, MPorts] + type SeqMemSet = collection.mutable.HashSet[String] + type MPortTypeMap = collection.mutable.LinkedHashMap[String, Type] + type DataRefMap = collection.mutable.LinkedHashMap[String, DataRef] + type AddrMap = collection.mutable.HashMap[String, Expression] + + def create_exps(e: Expression): Seq[Expression] = e match { + case (e: Mux) => + val e1s = create_exps(e.tval) + val e2s = create_exps(e.fval) + (e1s zip e2s) map { case (e1, e2) => Mux(e.cond, e1, e2, mux_type(e1, e2)) } + case (e: ValidIf) => + create_exps(e.value) map (e1 => ValidIf(e.cond, e1, e1.tpe)) + case (e) => (e.tpe) match { + case (_: GroundType) => Seq(e) + case (t: BundleType) => (t.fields foldLeft Seq[Expression]())((exps, f) => + exps ++ create_exps(SubField(e, f.name, f.tpe))) + case (t: VectorType) => ((0 until t.size) foldLeft Seq[Expression]())((exps, i) => + exps ++ create_exps(SubIndex(e, i, t.tpe))) + case UnknownType => Seq(e) + } + } + + private def EMPs: MPorts = MPorts(ArrayBuffer[MPort](), ArrayBuffer[MPort](), ArrayBuffer[MPort]()) + + def collect_smems_and_mports(mports: MPortMap, smems: SeqMemSet)(s: Statement): Statement = { + s match { + case (s:CDefMemory) if s.seq => smems += s.name + case (s:CDefMPort) => + val p = mports getOrElse (s.mem, EMPs) + s.direction match { + case MRead => p.readers += MPort(s.name,s.exps(1)) + case MWrite => p.writers += MPort(s.name,s.exps(1)) + case MReadWrite => p.readwriters += MPort(s.name,s.exps(1)) + } + mports(s.mem) = p + case s => + } + s map collect_smems_and_mports(mports, smems) + } + + def collect_refs(mports: MPortMap, smems: SeqMemSet, types: MPortTypeMap, + refs: DataRefMap, raddrs: AddrMap)(s: Statement): Statement = s match { + case (s: CDefMemory) => + types(s.name) = s.tpe + val taddr = UIntType(IntWidth(math.max(1, ceil_log2(s.size)))) + val tdata = s.tpe + def set_poison(vec: Seq[MPort], addr: String) = vec flatMap (r => Seq( + IsInvalid(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), addr, taddr)), + IsInvalid(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), "clk", taddr)) + )) + def set_enable(vec: Seq[MPort], en: String) = vec map (r => + Connect(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), en, taddr), zero) + ) + def set_wmode (vec: Seq[MPort], wmode: String) = vec map (r => + Connect(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), wmode, taddr), zero) + ) + def set_write (vec: Seq[MPort], data: String, mask: String) = vec flatMap {r => + val tmask = create_mask(s.tpe) + IsInvalid(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), data, tdata)) +: + (create_exps(SubField(SubField(Reference(s.name, ut), r.name, ut), mask, tmask)) + map (Connect(s.info, _, zero)) + ) + } + val rds = (mports getOrElse (s.name, EMPs)).readers + val wrs = (mports getOrElse (s.name, EMPs)).writers + val rws = (mports getOrElse (s.name, EMPs)).readwriters + val stmts = set_poison(rds, "addr") ++ + set_enable(rds, "en") ++ + set_poison(wrs, "addr") ++ + set_enable(wrs, "en") ++ + set_write(wrs, "data", "mask") ++ + set_poison(rws, "addr") ++ + set_wmode(rws, "wmode") ++ + set_enable(rws, "en") ++ + set_write(rws, "wdata", "wmask") + val mem = DefMemory(s.info, s.name, s.tpe, s.size, 1, if (s.seq) 1 else 0, + rds map (_.name), wrs map (_.name), rws map (_.name)) + Block(mem +: stmts) + case (s: CDefMPort) => { + types(s.name) = types(s.mem) + val addrs = ArrayBuffer[String]() + val clks = ArrayBuffer[String]() + val ens = ArrayBuffer[String]() + s.direction match { + case MReadWrite => + refs(s.name) = DataRef(SubField(Reference(s.mem, ut), s.name, ut), "rdata", "wdata", "wmask", true) + addrs += "addr" + clks += "clk" + ens += "en" + case MWrite => + refs(s.name) = DataRef(SubField(Reference(s.mem, ut), s.name, ut), "data", "data", "mask", false) + addrs += "addr" + clks += "clk" + ens += "en" + case MRead => + refs(s.name) = DataRef(SubField(Reference(s.mem, ut), s.name, ut), "data", "data", "blah", false) + addrs += "addr" + clks += "clk" + s.exps.head match { + case e: Reference if smems(s.mem) => + raddrs(e.name) = SubField(SubField(Reference(s.mem, ut), s.name, ut), "en", ut) + case _ => ens += "en" + } + } + Block( + (addrs map (x => Connect(s.info, SubField(SubField(Reference(s.mem, ut), s.name, ut), x, ut), s.exps(0)))) ++ + (clks map (x => Connect(s.info, SubField(SubField(Reference(s.mem, ut), s.name, ut), x, ut), s.exps(1)))) ++ + (ens map (x => Connect(s.info,SubField(SubField(Reference(s.mem,ut), s.name, ut), x, ut), one)))) + } + case (s) => s map collect_refs(mports, smems, types, refs, raddrs) + } + + def get_mask(refs: DataRefMap)(e: Expression): Expression = + e map get_mask(refs) match { + case e: Reference => refs get e.name match { + case None => e + case Some(p) => SubField(p.exp, p.mask, create_mask(e.tpe)) + } + case e => e + } + + def remove_chirrtl_s(refs: DataRefMap, raddrs: AddrMap)(s: Statement): Statement = { + var has_write_mport = false + var has_readwrite_mport: Option[Expression] = None + var has_read_mport: Option[Expression] = None + def remove_chirrtl_e(g: Gender)(e: Expression): Expression = e match { + case Reference(name, tpe) => refs get name match { + case Some(p) => g match { + case FEMALE => + has_write_mport = true + if (p.rdwrite) has_readwrite_mport = Some(SubField(p.exp, "wmode", UIntType(IntWidth(1)))) + SubField(p.exp, p.female, tpe) + case MALE => + SubField(p.exp, p.male, tpe) + } + case None => g match { + case FEMALE => raddrs get name match { + case Some(en) => has_read_mport = Some(en) ; e + case None => e + } + case MALE => e + } + } + case SubAccess(expr, index, tpe) => SubAccess( + remove_chirrtl_e(g)(expr), remove_chirrtl_e(MALE)(index), tpe) + case e => e map remove_chirrtl_e(g) + } + (s) match { + case DefNode(info, name, value) => + val valuex = remove_chirrtl_e(MALE)(value) + val sx = DefNode(info, name, valuex) + has_read_mport match { + case None => sx + case Some(en) => Block(Seq(sx, Connect(info, en, one))) + } + case Connect(info, loc, expr) => + val rocx = remove_chirrtl_e(MALE)(expr) + val locx = remove_chirrtl_e(FEMALE)(loc) + val sx = Connect(info, locx, rocx) + val stmts = ArrayBuffer[Statement]() + has_read_mport match { + case None => + case Some(en) => stmts += Connect(info, en, one) + } + if (has_write_mport) { + val locs = create_exps(get_mask(refs)(loc)) + stmts ++= (locs map (x => Connect(info, x, one))) + has_readwrite_mport match { + case None => + case Some(wmode) => stmts += Connect(info, wmode, one) + } + } + if (stmts.isEmpty) sx else Block(sx +: stmts) + case PartialConnect(info, loc, expr) => + val locx = remove_chirrtl_e(FEMALE)(loc) + val rocx = remove_chirrtl_e(MALE)(expr) + val sx = PartialConnect(info, locx, rocx) + val stmts = ArrayBuffer[Statement]() + has_read_mport match { + case None => + case Some(en) => stmts += Connect(info, en, one) + } + if (has_write_mport) { + val ls = get_valid_points(loc.tpe, expr.tpe, Default, Default) + val locs = create_exps(get_mask(refs)(loc)) + stmts ++= (ls map { case (x, _) => Connect(info, locs(x), one) }) + has_readwrite_mport match { + case None => + case Some(wmode) => stmts += Connect(info, wmode, one) + } + } + if (stmts.isEmpty) sx else Block(sx +: stmts) + case s => s map remove_chirrtl_s(refs, raddrs) map remove_chirrtl_e(MALE) + } + } + + def remove_chirrtl_m(m: DefModule): DefModule = { + val mports = new MPortMap + val smems = new SeqMemSet + val types = new MPortTypeMap + val refs = new DataRefMap + val raddrs = new AddrMap + (m map collect_smems_and_mports(mports, smems) + map collect_refs(mports, smems, types, refs, raddrs) + map remove_chirrtl_s(refs, raddrs)) + } + + def run(c: Circuit): Circuit = + c copy (modules = (c.modules map remove_chirrtl_m)) +} diff --git a/src/main/scala/firrtl/passes/Resolves.scala b/src/main/scala/firrtl/passes/Resolves.scala new file mode 100644 index 00000000..3100f0c3 --- /dev/null +++ b/src/main/scala/firrtl/passes/Resolves.scala @@ -0,0 +1,163 @@ +/* +Copyright (c) 2014 - 2016 The Regents of the University of +California (Regents). All Rights Reserved. Redistribution and use in +source and binary forms, with or without modification, are permitted +provided that the following conditions are met: + * Redistributions of source code must retain the above + copyright notice, this list of conditions and the following + two paragraphs of disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + two paragraphs of disclaimer in the documentation and/or other materials + provided with the distribution. + * Neither the name of the Regents nor the names of its contributors + may be used to endorse or promote products derived from this + software without specific prior written permission. +IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, +SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, +ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF +REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF +ANY, PROVIDED HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION +TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR +MODIFICATIONS. +*/ + +package firrtl.passes + +import firrtl._ +import firrtl.ir._ +import firrtl.Mappers._ + +object ResolveKinds extends Pass { + def name = "Resolve Kinds" + type KindMap = collection.mutable.LinkedHashMap[String, Kind] + + def find_port(kinds: KindMap)(p: Port): Port = { + kinds(p.name) = PortKind() ; p + } + + def find_stmt(kinds: KindMap)(s: Statement):Statement = { + s match { + case s: DefWire => kinds(s.name) = WireKind() + case s: DefNode => kinds(s.name) = NodeKind() + case s: DefRegister => kinds(s.name) = RegKind() + case s: WDefInstance => kinds(s.name) = InstanceKind() + case s: DefMemory => kinds(s.name) = MemKind(s.readers ++ s.writers ++ s.readwriters) + case s => + } + s map find_stmt(kinds) + } + + def resolve_expr(kinds: KindMap)(e: Expression): Expression = e match { + case e: WRef => e copy (kind = kinds(e.name)) + case e => e map resolve_expr(kinds) + } + + def resolve_stmt(kinds: KindMap)(s: Statement): Statement = + s map resolve_stmt(kinds) map resolve_expr(kinds) + + def resolve_kinds(m: DefModule): DefModule = { + val kinds = new KindMap + (m map find_port(kinds) + map find_stmt(kinds) + map resolve_stmt(kinds)) + } + + def run(c: Circuit): Circuit = + c copy (modules = (c.modules map resolve_kinds)) +} + +object ResolveGenders extends Pass { + def name = "Resolve Genders" + def resolve_e(g: Gender)(e: Expression): Expression = e match { + case e: WRef => e copy (gender = g) + case WSubField(exp, name, tpe, _) => WSubField( + Utils.field_flip(exp.tpe, name) match { + case Default => resolve_e(g)(exp) + case Flip => resolve_e(Utils.swap(g))(exp) + }, name, tpe, g) + case WSubIndex(exp, value, tpe, _) => + WSubIndex(resolve_e(g)(exp), value, tpe, g) + case WSubAccess(exp, index, tpe, _) => + WSubAccess(resolve_e(g)(exp), resolve_e(MALE)(index), tpe, g) + case e => e map resolve_e(g) + } + + def resolve_s(s: Statement): Statement = s match { + case IsInvalid(info, expr) => + IsInvalid(info, resolve_e(FEMALE)(expr)) + case Connect(info, loc, expr) => + Connect(info, resolve_e(FEMALE)(loc), resolve_e(MALE)(expr)) + case PartialConnect(info, loc, expr) => + PartialConnect(info, resolve_e(FEMALE)(loc), resolve_e(MALE)(expr)) + case s => s map resolve_e(MALE) map resolve_s + } + + def resolve_gender(m: DefModule): DefModule = m map resolve_s + + def run(c: Circuit): Circuit = + c copy (modules = (c.modules map resolve_gender)) +} + +object CInferMDir extends Pass { + def name = "CInfer MDir" + type MPortDirMap = collection.mutable.LinkedHashMap[String, MPortDir] + + def infer_mdir_e(mports: MPortDirMap, dir: MPortDir)(e: Expression): Expression = { + (e map infer_mdir_e(mports, dir)) match { + case e: Reference => mports get e.name match { + case Some(p) => mports(e.name) = (p, dir) match { + case (MInfer, MInfer) => Utils.error("Shouldn't be here") + case (MInfer, MWrite) => MWrite + case (MInfer, MRead) => MRead + case (MInfer, MReadWrite) => MReadWrite + case (MWrite, MInfer) => Utils.error("Shouldn't be here") + case (MWrite, MWrite) => MWrite + case (MWrite, MRead) => MReadWrite + case (MWrite, MReadWrite) => MReadWrite + case (MRead, MInfer) => Utils.error("Shouldn't be here") + case (MRead, MWrite) => MReadWrite + case (MRead, MRead) => MRead + case (MRead, MReadWrite) => MReadWrite + case (MReadWrite, MInfer) => Utils.error("Shouldn't be here") + case (MReadWrite, MWrite) => MReadWrite + case (MReadWrite, MRead) => MReadWrite + case (MReadWrite, MReadWrite) => MReadWrite + } ; e + case None => e + } + case _ => e + } + } + + def infer_mdir_s(mports: MPortDirMap)(s: Statement): Statement = s match { + case s: CDefMPort => + mports(s.name) = s.direction + s map infer_mdir_e(mports, MRead) + case s: Connect => + infer_mdir_e(mports, MRead)(s.expr) + infer_mdir_e(mports, MWrite)(s.loc) + s + case s: PartialConnect => + infer_mdir_e(mports, MRead)(s.expr) + infer_mdir_e(mports, MWrite)(s.loc) + s + case s => s map infer_mdir_s(mports) map infer_mdir_e(mports, MRead) + } + + def set_mdir_s(mports: MPortDirMap)(s: Statement): Statement = s match { + case s: CDefMPort => s copy (direction = mports(s.name)) + case s => s map set_mdir_s(mports) + } + + def infer_mdir(m: DefModule): DefModule = { + val mports = new MPortDirMap + m map infer_mdir_s(mports) map set_mdir_s(mports) + } + + def run(c: Circuit): Circuit = + c copy (modules = (c.modules map infer_mdir)) +} diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala index 54ef6003..118e547c 100644 --- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala +++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala @@ -5,7 +5,8 @@ import firrtl.passes._ import Annotations._ class ReplSeqMemSpec extends SimpleTransformSpec { - + val passSeq = Seq( + ConstProp, CommonSubexpressionElimination, DeadCodeElimination, RemoveEmpty) def transforms (writer: java.io.Writer) = Seq( new Chisel3ToHighFirrtl(), new IRToWorkingIR(), @@ -14,6 +15,8 @@ class ReplSeqMemSpec extends SimpleTransformSpec { new passes.InferReadWrite(TransID(-1)), new passes.ReplSeqMem(TransID(-2)), new MiddleFirrtlToLowFirrtl(), + (new Transform with SimpleRun { + def execute(c: ir.Circuit, a: AnnotationMap) = run(c, passSeq) }), new EmitFirrtl(writer) ) @@ -97,27 +100,24 @@ circuit sram6t : input io_wdata : UInt<32> input io_raddr : UInt<8> output io_rdata : UInt<32> - + inst mem of mem node T_0 = eq(io_wen, UInt<1>("h0")) node T_1 = and(io_en, T_0) wire T_2 : UInt<8> node GEN_0 = validif(T_1, io_raddr) - node GEN_1 = mux(T_1, UInt<1>("h1"), UInt<1>("h0")) node T_4 = and(io_en, io_wen) + node GEN_4 = validif(T_4, io_wdata) node GEN_2 = validif(T_4, io_waddr) - node GEN_3 = validif(T_4, clk) - node GEN_4 = mux(T_4, UInt<1>("h1"), UInt<1>("h0")) - node GEN_5 = validif(T_4, io_wdata) - node GEN_6 = mux(T_4, UInt<1>("h1"), UInt<1>("h0")) + node GEN_5 = validif(T_4, clk) io_rdata <= mem.R0_data mem.R0_addr <= bits(T_2, 6, 0) mem.R0_clk <= clk - mem.R0_en <= GEN_1 + mem.R0_en <= T_1 mem.W0_addr <= bits(GEN_2, 6, 0) - mem.W0_clk <= GEN_3 - mem.W0_en <= GEN_4 - mem.W0_data <= GEN_5 + mem.W0_clk <= GEN_5 + mem.W0_en <= T_4 + mem.W0_data <= GEN_4 T_2 <= GEN_0 extmodule mem_ext : @@ -140,16 +140,16 @@ circuit sram6t : input W0_en : UInt<1> input W0_clk : Clock input W0_data : UInt<32> - + inst mem_ext of mem_ext mem_ext.R0_addr <= R0_addr mem_ext.R0_en <= R0_en mem_ext.R0_clk <= R0_clk - R0_data <= bits(mem_ext.R0_data, 31, 0) + R0_data <= mem_ext.R0_data mem_ext.W0_addr <= W0_addr mem_ext.W0_en <= W0_en mem_ext.W0_clk <= W0_clk - mem_ext.W0_data <= W0_data + mem_ext.W0_data <= W0_data """.stripMargin val checkConf = """name mem_ext depth 128 width 32 ports write,read """ @@ -170,4 +170,4 @@ circuit sram6t : // readwrite vs. no readwrite // redundant memories (multiple instances of the same type of memory) // mask + no mask -// conf
\ No newline at end of file +// conf diff --git a/src/test/scala/firrtlTests/UnitTests.scala b/src/test/scala/firrtlTests/UnitTests.scala index 2d1bbdc1..7feb4a00 100644 --- a/src/test/scala/firrtlTests/UnitTests.scala +++ b/src/test/scala/firrtlTests/UnitTests.scala @@ -203,7 +203,8 @@ class UnitTests extends FirrtlFlatSpec { InferWidths, PullMuxes, ExpandConnects, - RemoveAccesses + RemoveAccesses, + ConstProp ) val input = """circuit AssignViaDeref : @@ -221,7 +222,7 @@ class UnitTests extends FirrtlFlatSpec { val check = Seq( """wire GEN_0 : { a : UInt<8>}""", """GEN_0.a <= table[0].a""", - """when eq(UInt<1>("h1"), UInt<1>("h1")) :""", + """when UInt<1>("h1") :""", """GEN_0.a <= table[1].a""", """wire GEN_1 : UInt<8>""", """when eq(UInt<1>("h0"), GEN_0.a) :""", |
