aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorDonggyu2016-09-08 15:48:22 -0700
committerGitHub2016-09-08 15:48:22 -0700
commit765b880d4a56875c1ed07f4a0e8904c74a92dc0b (patch)
tree6f11b15ed7516bc8816ec0d45f505dd0e4014613 /src
parent303bad7db4354429c1992233fe0bfd1e8ce7f93e (diff)
parent864a3978cf94f336187831773dfc2c9f9ea064c8 (diff)
Merge pull request #283 from ucb-bar/refactor_expand_whens
Refactor Passes
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/Emitter.scala20
-rw-r--r--src/main/scala/firrtl/Utils.scala22
-rw-r--r--src/main/scala/firrtl/passes/ExpandWhens.scala208
-rw-r--r--src/main/scala/firrtl/passes/InferTypes.scala161
-rw-r--r--src/main/scala/firrtl/passes/InferWidths.scala328
-rw-r--r--src/main/scala/firrtl/passes/Passes.scala941
-rw-r--r--src/main/scala/firrtl/passes/RemoveAccesses.scala30
-rw-r--r--src/main/scala/firrtl/passes/RemoveCHIRRTL.scala256
-rw-r--r--src/main/scala/firrtl/passes/Resolves.scala163
-rw-r--r--src/test/scala/firrtlTests/ReplSeqMemTests.scala30
-rw-r--r--src/test/scala/firrtlTests/UnitTests.scala5
11 files changed, 1061 insertions, 1103 deletions
diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala
index 378eac6d..b5d212e4 100644
--- a/src/main/scala/firrtl/Emitter.scala
+++ b/src/main/scala/firrtl/Emitter.scala
@@ -62,6 +62,26 @@ case class VRandom(width: BigInt) extends Expression {
}
class VerilogEmitter extends Emitter {
val tab = " "
+ def AND(e1: WrappedExpression, e2: WrappedExpression): Expression = {
+ if (e1 == e2) e1.e1
+ else if ((e1 == we(zero)) | (e2 == we(zero))) zero
+ else if (e1 == we(one)) e2.e1
+ else if (e2 == we(one)) e1.e1
+ else DoPrim(And, Seq(e1.e1, e2.e1), Nil, UIntType(IntWidth(1)))
+ }
+ def OR(e1: WrappedExpression, e2: WrappedExpression): Expression = {
+ if (e1 == e2) e1.e1
+ else if ((e1 == we(one)) | (e2 == we(one))) one
+ else if (e1 == we(zero)) e2.e1
+ else if (e2 == we(zero)) e1.e1
+ else DoPrim(Or, Seq(e1.e1, e2.e1), Nil, UIntType(IntWidth(1)))
+ }
+ def NOT(e: WrappedExpression): Expression = {
+ if (e == we(one)) zero
+ else if (e == we(zero)) one
+ else DoPrim(Eq, Seq(e.e1, zero), Nil, UIntType(IntWidth(1)))
+ }
+
def wref(n: String, t: Type) = WRef(n, t, ExpKind(), UNKNOWNGENDER)
def remove_root(ex: Expression): Expression = ex match {
case ex: WSubField => ex.exp match {
diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala
index 29c37294..572d1ccc 100644
--- a/src/main/scala/firrtl/Utils.scala
+++ b/src/main/scala/firrtl/Utils.scala
@@ -98,28 +98,6 @@ object Utils extends LazyLogging {
val ix = if (i < 0) ((-1 * i) - 1) else i
ceil_log2(ix + 1) + 1
}
- def EQV (e1:Expression,e2:Expression) : Expression =
- DoPrim(Eq, Seq(e1, e2), Nil, e1.tpe)
- // TODO: these should be fixed
- def AND (e1:WrappedExpression,e2:WrappedExpression) : Expression = {
- if (e1 == e2) e1.e1
- else if ((e1 == we(zero)) | (e2 == we(zero))) zero
- else if (e1 == we(one)) e2.e1
- else if (e2 == we(one)) e1.e1
- else DoPrim(And,Seq(e1.e1,e2.e1),Seq(),UIntType(IntWidth(1)))
- }
- def OR (e1:WrappedExpression,e2:WrappedExpression) : Expression = {
- if (e1 == e2) e1.e1
- else if ((e1 == we(one)) | (e2 == we(one))) one
- else if (e1 == we(zero)) e2.e1
- else if (e2 == we(zero)) e1.e1
- else DoPrim(Or,Seq(e1.e1,e2.e1),Seq(),UIntType(IntWidth(1)))
- }
- def NOT (e1:WrappedExpression) : Expression = {
- if (e1 == we(one)) zero
- else if (e1 == we(zero)) one
- else DoPrim(Eq,Seq(e1.e1,zero),Seq(),UIntType(IntWidth(1)))
- }
def create_mask(dt: Type): Type = dt match {
case t: VectorType => VectorType(create_mask(t.tpe),t.size)
diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala
index 3d26298a..7c013b51 100644
--- a/src/main/scala/firrtl/passes/ExpandWhens.scala
+++ b/src/main/scala/firrtl/passes/ExpandWhens.scala
@@ -34,11 +34,6 @@ import firrtl.Mappers._
import firrtl.PrimOps._
import firrtl.WrappedExpression._
-// Datastructures
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.LinkedHashMap
-import scala.collection.mutable.ArrayBuffer
-
import annotation.tailrec
/** Expand Whens
@@ -50,138 +45,129 @@ import annotation.tailrec
*/
object ExpandWhens extends Pass {
def name = "Expand Whens"
+ type NodeMap = collection.mutable.HashMap[MemoizedHash[Expression], String]
+ type Netlist = collection.mutable.LinkedHashMap[WrappedExpression, Expression]
+ type Simlist = collection.mutable.ArrayBuffer[Statement]
+ type Defaults = Seq[collection.mutable.Map[WrappedExpression, Expression]]
// ========== Expand When Utilz ==========
- private def getEntries(
- hash: LinkedHashMap[WrappedExpression, Expression],
- exps: Seq[Expression]): LinkedHashMap[WrappedExpression, Expression] = {
- val hashx = LinkedHashMap[WrappedExpression, Expression]()
- exps foreach (e => if (hash.contains(e)) hashx(e) = hash(e))
- hashx
- }
private def getFemaleRefs(n: String, t: Type, g: Gender): Seq[Expression] = {
def getGender(t: Type, i: Int, g: Gender): Gender = times(g, get_flip(t, i, Default))
val exps = create_exps(WRef(n, t, ExpKind(), g))
- val expsx = ArrayBuffer[Expression]()
- for (j <- 0 until exps.size) {
- getGender(t, j, g) match {
- case (BIGENDER | FEMALE) => expsx += exps(j)
- case _ =>
+ (exps.zipWithIndex foldLeft Seq[Expression]()){
+ case (expsx, (exp, j)) => getGender(t, j, g) match {
+ case (BIGENDER | FEMALE) => expsx :+ exp
+ case _ => expsx
}
}
- expsx
}
- private def expandNetlist(netlist: LinkedHashMap[WrappedExpression, Expression]) =
- netlist map { case (k, v) =>
- v match {
- case WInvalid() => IsInvalid(NoInfo, k.e1)
- case _ => Connect(NoInfo, k.e1, v)
- }
+ private def expandNetlist(netlist: Netlist) =
+ netlist map {
+ case (k, WInvalid()) => IsInvalid(NoInfo, k.e1)
+ case (k, v) => Connect(NoInfo, k.e1, v)
}
// Searches nested scopes of defaults for lvalue
// defaults uses mutable Map because we are searching LinkedHashMaps and conversion to immutable is VERY slow
@tailrec
- private def getDefault(
- lvalue: WrappedExpression,
- defaults: Seq[collection.mutable.Map[WrappedExpression, Expression]]): Option[Expression] = {
- if (defaults.isEmpty) None
- else if (defaults.head.contains(lvalue)) defaults.head.get(lvalue)
- else getDefault(lvalue, defaults.tail)
+ private def getDefault(lvalue: WrappedExpression, defaults: Defaults): Option[Expression] = {
+ defaults match {
+ case Nil => None
+ case head :: tail => head get lvalue match {
+ case Some(p) => Some(p)
+ case None => getDefault(lvalue, tail)
+ }
+ }
}
+ private def AND(e1: Expression, e2: Expression) =
+ DoPrim(And, Seq(e1, e2), Nil, UIntType(IntWidth(1)))
+ private def NOT(e: Expression) =
+ DoPrim(Eq, Seq(e, zero), Nil, UIntType(IntWidth(1)))
+
// ------------ Pass -------------------
def run(c: Circuit): Circuit = {
- def expandWhens(m: Module): (LinkedHashMap[WrappedExpression, Expression], ArrayBuffer[Statement], Statement) = {
+ def expandWhens(m: Module): (Netlist, Simlist, Statement) = {
val namespace = Namespace(m)
- val simlist = ArrayBuffer[Statement]()
+ val simlist = new Simlist
+ val nodes = new NodeMap
// defaults ideally would be immutable.Map but conversion from mutable.LinkedHashMap to mutable.Map is VERY slow
- def expandWhens(
- netlist: LinkedHashMap[WrappedExpression, Expression],
- defaults: Seq[collection.mutable.Map[WrappedExpression, Expression]],
- p: Expression)
- (s: Statement): Statement = {
- s match {
- case w: DefWire =>
- getFemaleRefs(w.name, w.tpe, BIGENDER) foreach (ref => netlist(ref) = WVoid())
- w
- case r: DefRegister =>
- getFemaleRefs(r.name, r.tpe, BIGENDER) foreach (ref => netlist(ref) = ref)
- r
- case c: Connect =>
- netlist(c.loc) = c.expr
- EmptyStmt
- case c: IsInvalid =>
- netlist(c.expr) = WInvalid()
- EmptyStmt
- case s: Conditionally =>
- val memos = ArrayBuffer[Statement]()
-
- val conseqNetlist = LinkedHashMap[WrappedExpression, Expression]()
- val altNetlist = LinkedHashMap[WrappedExpression, Expression]()
- val conseqStmt = expandWhens(conseqNetlist, netlist +: defaults, AND(p, s.pred))(s.conseq)
- val altStmt = expandWhens(altNetlist, netlist +: defaults, AND(p, NOT(s.pred)))(s.alt)
+ def expandWhens(netlist: Netlist,
+ defaults: Defaults,
+ p: Expression)
+ (s: Statement): Statement = s match {
+ case w: DefWire =>
+ netlist ++= (getFemaleRefs(w.name, w.tpe, BIGENDER) map (ref => we(ref) -> WVoid()))
+ w
+ case r: DefRegister =>
+ netlist ++= (getFemaleRefs(r.name, r.tpe, BIGENDER) map (ref => we(ref) -> ref))
+ r
+ case c: Connect =>
+ netlist(c.loc) = c.expr
+ EmptyStmt
+ case c: IsInvalid =>
+ netlist(c.expr) = WInvalid()
+ EmptyStmt
+ case s: Conditionally =>
+ val conseqNetlist = new Netlist
+ val altNetlist = new Netlist
+ val conseqStmt = expandWhens(conseqNetlist, netlist +: defaults, AND(p, s.pred))(s.conseq)
+ val altStmt = expandWhens(altNetlist, netlist +: defaults, AND(p, NOT(s.pred)))(s.alt)
- (conseqNetlist.keySet ++ altNetlist.keySet) foreach { lvalue =>
- // Defaults in netlist get priority over those in defaults
- val default = if (netlist.contains(lvalue)) netlist.get(lvalue) else getDefault(lvalue, defaults)
- val res = default match {
- case Some(defaultValue) =>
- val trueValue = conseqNetlist.getOrElse(lvalue, defaultValue)
- val falseValue = altNetlist.getOrElse(lvalue, defaultValue)
- (trueValue, falseValue) match {
- case (WInvalid(), WInvalid()) => WInvalid()
- case (WInvalid(), fv) => ValidIf(NOT(s.pred), fv, fv.tpe)
- case (tv, WInvalid()) => ValidIf(s.pred, tv, tv.tpe)
- case (tv, fv) => Mux(s.pred, tv, fv, mux_type_and_widths(tv, fv))
- }
- case None =>
- // Since not in netlist, lvalue must be declared in EXACTLY one of conseq or alt
- conseqNetlist.getOrElse(lvalue, altNetlist(lvalue))
- }
-
- val memoNode = DefNode(s.info, namespace.newTemp, res)
- val memoExpr = WRef(memoNode.name, res.tpe, NodeKind(), MALE)
- memos += memoNode
- netlist(lvalue) = memoExpr
+ val memos = (conseqNetlist.keys ++ altNetlist.keys) map { lvalue =>
+ // Defaults in netlist get priority over those in defaults
+ val default = netlist get lvalue match {
+ case Some(v) => Some(v)
+ case None => getDefault(lvalue, defaults)
}
- Block(Seq(conseqStmt, altStmt) ++ memos)
-
- case s: Print =>
- if(weq(p, one)) {
- simlist += s
- } else {
- simlist += Print(s.info, s.string, s.args, s.clk, AND(p, s.en))
+ val res = default match {
+ case Some(defaultValue) =>
+ val trueValue = conseqNetlist getOrElse (lvalue, defaultValue)
+ val falseValue = altNetlist getOrElse (lvalue, defaultValue)
+ (trueValue, falseValue) match {
+ case (WInvalid(), WInvalid()) => WInvalid()
+ case (WInvalid(), fv) => ValidIf(NOT(s.pred), fv, fv.tpe)
+ case (tv, WInvalid()) => ValidIf(s.pred, tv, tv.tpe)
+ case (tv, fv) => Mux(s.pred, tv, fv, mux_type_and_widths(tv, fv))
+ }
+ case None =>
+ // Since not in netlist, lvalue must be declared in EXACTLY one of conseq or alt
+ conseqNetlist getOrElse (lvalue, altNetlist(lvalue))
}
- EmptyStmt
- case s: Stop =>
- if (weq(p, one)) {
- simlist += s
- } else {
- simlist += Stop(s.info, s.ret, s.clk, AND(p, s.en))
+
+ nodes get res match {
+ case Some(name) =>
+ netlist(lvalue) = WRef(name, res.tpe, NodeKind(), MALE)
+ EmptyStmt
+ case None =>
+ val name = namespace.newTemp
+ nodes(res) = name
+ netlist(lvalue) = WRef(name, res.tpe, NodeKind(), MALE)
+ DefNode(s.info, name, res)
}
- EmptyStmt
- case s => s map expandWhens(netlist, defaults, p)
- }
+ }
+ Block(Seq(conseqStmt, altStmt) ++ memos)
+ case s: Print =>
+ simlist += (if (weq(p, one)) s else Print(s.info, s.string, s.args, s.clk, AND(p, s.en)))
+ EmptyStmt
+ case s: Stop =>
+ simlist += (if (weq(p, one)) s else Stop(s.info, s.ret, s.clk, AND(p, s.en)))
+ EmptyStmt
+ case s => s map expandWhens(netlist, defaults, p)
}
- val netlist = LinkedHashMap[WrappedExpression, Expression]()
-
+ val netlist = new Netlist
// Add ports to netlist
- m.ports foreach { port =>
- getFemaleRefs(port.name, port.tpe, to_gender(port.direction)) foreach (ref => netlist(ref) = WVoid())
- }
- val bodyx = expandWhens(netlist, Seq(netlist), one)(m.body)
-
- (netlist, simlist, bodyx)
+ netlist ++= (m.ports flatMap { case Port(_, name, dir, tpe) =>
+ getFemaleRefs(name, tpe, to_gender(dir)) map (ref => we(ref) -> WVoid())
+ })
+ (netlist, simlist, expandWhens(netlist, Seq(netlist), one)(m.body))
}
- val modulesx = c.modules map { m =>
- m match {
- case m: ExtModule => m
- case m: Module =>
- val (netlist, simlist, bodyx) = expandWhens(m)
- val newBody = Block(Seq(squashEmpty(bodyx)) ++ expandNetlist(netlist) ++ simlist)
- Module(m.info, m.name, m.ports, newBody)
- }
+ val modulesx = c.modules map {
+ case m: ExtModule => m
+ case m: Module =>
+ val (netlist, simlist, bodyx) = expandWhens(m)
+ val newBody = Block(Seq(squashEmpty(bodyx)) ++ expandNetlist(netlist) ++ simlist)
+ Module(m.info, m.name, m.ports, newBody)
}
Circuit(c.info, modulesx, c.main)
}
diff --git a/src/main/scala/firrtl/passes/InferTypes.scala b/src/main/scala/firrtl/passes/InferTypes.scala
new file mode 100644
index 00000000..b36298e8
--- /dev/null
+++ b/src/main/scala/firrtl/passes/InferTypes.scala
@@ -0,0 +1,161 @@
+/*
+Copyright (c) 2014 - 2016 The Regents of the University of
+California (Regents). All Rights Reserved. Redistribution and use in
+source and binary forms, with or without modification, are permitted
+provided that the following conditions are met:
+ * Redistributions of source code must retain the above
+ copyright notice, this list of conditions and the following
+ two paragraphs of disclaimer.
+ * Redistributions in binary form must reproduce the above
+ copyright notice, this list of conditions and the following
+ two paragraphs of disclaimer in the documentation and/or other materials
+ provided with the distribution.
+ * Neither the name of the Regents nor the names of its contributors
+ may be used to endorse or promote products derived from this
+ software without specific prior written permission.
+IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT,
+SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS,
+ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF
+REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF
+ANY, PROVIDED HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION
+TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR
+MODIFICATIONS.
+*/
+
+package firrtl.passes
+
+import firrtl._
+import firrtl.ir._
+import firrtl.Utils._
+import firrtl.Mappers._
+
+object InferTypes extends Pass {
+ def name = "Infer Types"
+ type TypeMap = collection.mutable.LinkedHashMap[String, Type]
+
+ def run(c: Circuit): Circuit = {
+ val namespace = Namespace()
+ val mtypes = (c.modules map (m => m.name -> module_type(m))).toMap
+
+ def remove_unknowns_w(w: Width): Width = w match {
+ case UnknownWidth => VarWidth(namespace.newName("w"))
+ case w => w
+ }
+
+ def remove_unknowns(t: Type): Type =
+ t map remove_unknowns map remove_unknowns_w
+
+ def infer_types_e(types: TypeMap)(e: Expression): Expression =
+ e map infer_types_e(types) match {
+ case e: WRef => e copy (tpe = types(e.name))
+ case e: WSubField => e copy (tpe = field_type(e.exp.tpe, e.name))
+ case e: WSubIndex => e copy (tpe = sub_type(e.exp.tpe))
+ case e: WSubAccess => e copy (tpe = sub_type(e.exp.tpe))
+ case e: DoPrim => PrimOps.set_primop_type(e)
+ case e: Mux => e copy (tpe = mux_type_and_widths(e.tval, e.fval))
+ case e: ValidIf => e copy (tpe = e.value.tpe)
+ case e @ (_: UIntLiteral | _: SIntLiteral) => e
+ }
+
+ def infer_types_s(types: TypeMap)(s: Statement): Statement = s match {
+ case s: WDefInstance =>
+ val t = mtypes(s.module)
+ types(s.name) = t
+ s copy (tpe = t)
+ case s: DefWire =>
+ val t = remove_unknowns(get_type(s))
+ types(s.name) = t
+ s copy (tpe = t)
+ case s: DefNode =>
+ val sx = s map infer_types_e(types)
+ val t = remove_unknowns(get_type(sx))
+ types(s.name) = t
+ sx map infer_types_e(types)
+ case s: DefRegister =>
+ val t = remove_unknowns(get_type(s))
+ types(s.name) = t
+ s copy (tpe = t) map infer_types_e(types)
+ case s: DefMemory =>
+ val t = remove_unknowns(get_type(s))
+ types(s.name) = t
+ s copy (dataType = remove_unknowns(s.dataType))
+ case s => s map infer_types_s(types) map infer_types_e(types)
+ }
+
+ def infer_types_p(types: TypeMap)(p: Port): Port = {
+ val t = remove_unknowns(p.tpe)
+ types(p.name) = t
+ p copy (tpe = t)
+ }
+
+ def infer_types(m: DefModule): DefModule = {
+ val types = new TypeMap
+ m map infer_types_p(types) map infer_types_s(types)
+ }
+
+ c copy (modules = (c.modules map infer_types))
+ }
+}
+
+object CInferTypes extends Pass {
+ def name = "CInfer Types"
+ type TypeMap = collection.mutable.LinkedHashMap[String, Type]
+
+ def run(c: Circuit): Circuit = {
+ val namespace = Namespace()
+ val mtypes = (c.modules map (m => m.name -> module_type(m))).toMap
+
+ def infer_types_e(types: TypeMap)(e: Expression) : Expression =
+ e map infer_types_e(types) match {
+ case (e: Reference) => e copy (tpe = (types getOrElse (e.name, UnknownType)))
+ case (e: SubField) => e copy (tpe = field_type(e.expr.tpe, e.name))
+ case (e: SubIndex) => e copy (tpe = sub_type(e.expr.tpe))
+ case (e: SubAccess) => e copy (tpe = sub_type(e.expr.tpe))
+ case (e: DoPrim) => PrimOps.set_primop_type(e)
+ case (e: Mux) => e copy (tpe = mux_type(e.tval,e.tval))
+ case (e: ValidIf) => e copy (tpe = e.value.tpe)
+ case e @ (_: UIntLiteral | _: SIntLiteral) => e
+ }
+
+ def infer_types_s(types: TypeMap)(s: Statement): Statement = s match {
+ case (s: DefRegister) =>
+ types(s.name) = s.tpe
+ s map infer_types_e(types)
+ case (s: DefWire) =>
+ types(s.name) = s.tpe
+ s
+ case (s: DefNode) =>
+ types(s.name) = get_type(s)
+ s
+ case (s: DefMemory) =>
+ types(s.name) = get_type(s)
+ s
+ case (s: CDefMPort) =>
+ val t = types getOrElse(s.mem, UnknownType)
+ types(s.name) = t
+ s copy (tpe = t)
+ case (s: CDefMemory) =>
+ types(s.name) = s.tpe
+ s
+ case (s: DefInstance) =>
+ types(s.name) = mtypes(s.module)
+ s
+ case (s) => s map infer_types_s(types) map infer_types_e(types)
+ }
+
+ def infer_types_p(types: TypeMap)(p: Port): Port = {
+ types(p.name) = p.tpe
+ p
+ }
+
+ def infer_types(m: DefModule): DefModule = {
+ val types = new TypeMap
+ m map infer_types_p(types) map infer_types_s(types)
+ }
+
+ c copy (modules = (c.modules map infer_types))
+ }
+}
diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala
new file mode 100644
index 00000000..5a81c268
--- /dev/null
+++ b/src/main/scala/firrtl/passes/InferWidths.scala
@@ -0,0 +1,328 @@
+/*
+Copyright (c) 2014 - 2016 The Regents of the University of
+California (Regents). All Rights Reserved. Redistribution and use in
+source and binary forms, with or without modification, are permitted
+provided that the following conditions are met:
+ * Redistributions of source code must retain the above
+ copyright notice, this list of conditions and the following
+ two paragraphs of disclaimer.
+ * Redistributions in binary form must reproduce the above
+ copyright notice, this list of conditions and the following
+ two paragraphs of disclaimer in the documentation and/or other materials
+ provided with the distribution.
+ * Neither the name of the Regents nor the names of its contributors
+ may be used to endorse or promote products derived from this
+ software without specific prior written permission.
+IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT,
+SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS,
+ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF
+REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF
+ANY, PROVIDED HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION
+TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR
+MODIFICATIONS.
+*/
+
+package firrtl.passes
+
+// Datastructures
+import scala.collection.mutable.{LinkedHashMap, HashMap, HashSet, ArrayBuffer}
+import scala.collection.immutable.ListMap
+
+import firrtl._
+import firrtl.ir._
+import firrtl.Utils._
+import firrtl.Mappers._
+import firrtl.PrimOps._
+import firrtl.WrappedExpression._
+
+object InferWidths extends Pass {
+ def name = "Infer Widths"
+
+ def solve_constraints(l: Seq[WGeq]): LinkedHashMap[String, Width] = {
+ def unique(ls: Seq[Width]) : Seq[Width] =
+ (ls map (new WrappedWidth(_))).distinct map (_.w)
+ def make_unique(ls: Seq[WGeq]): ListMap[String,Width] = {
+ (ls foldLeft ListMap[String, Width]())((h, g) => g.loc match {
+ case w: VarWidth => h get w.name match {
+ case None => h + (w.name -> g.exp)
+ case Some(p) => h + (w.name -> MaxWidth(Seq(g.exp, p)))
+ }
+ case _ => h
+ })
+ }
+ def simplify(w: Width): Width = w map simplify match {
+ case (w: MinWidth) => MinWidth(unique((w.args foldLeft Seq[Width]()){
+ case (res, w: MinWidth) => res ++ w.args
+ case (res, w) => res :+ w
+ }))
+ case (w: MaxWidth) => MaxWidth(unique((w.args foldLeft Seq[Width]()){
+ case (res, w: MaxWidth) => res ++ w.args
+ case (res, w) => res :+ w
+ }))
+ case (w: PlusWidth) => (w.arg1, w.arg2) match {
+ case (w1: IntWidth, w2 :IntWidth) => IntWidth(w1.width + w2.width)
+ case _ => w
+ }
+ case (w: MinusWidth) => (w.arg1, w.arg2) match {
+ case (w1: IntWidth, w2: IntWidth) => IntWidth(w1.width - w2.width)
+ case _ => w
+ }
+ case (w: ExpWidth) => w.arg1 match {
+ case (w1: IntWidth) => IntWidth(BigInt((math.pow(2, w1.width.toDouble) - 1).toLong))
+ case (w1) => w
+ }
+ case _ => w
+ }
+
+ def substitute(h: LinkedHashMap[String, Width])(w: Width): Width = {
+ //;println-all-debug(["Substituting for [" w "]"])
+ val wx = simplify(w)
+ //;println-all-debug(["After Simplify: [" wx "]"])
+ (wx map substitute(h)) match {
+ //;("matched println-debugvarwidth!")
+ case w: VarWidth => h get w.name match {
+ case None => w
+ case Some(p) =>
+ //;println-debug("Contained!")
+ //;println-all-debug(["Width: " w])
+ //;println-all-debug(["Accessed: " h[name(w)]])
+ val t = simplify(substitute(h)(p))
+ h(w.name) = t
+ t
+ }
+ case w => w
+ //;println-all-debug(["not varwidth!" w])
+ }
+ }
+
+ def b_sub(h: LinkedHashMap[String, Width])(w: Width): Width = {
+ w map b_sub(h) match {
+ case w: VarWidth => h getOrElse (w.name, w)
+ case w => w
+ }
+ }
+
+ def remove_cycle(n: String)(w: Width): Width = {
+ //;println-all-debug(["Removing cycle for " n " inside " w])
+ (w map remove_cycle(n)) match {
+ case w: MaxWidth => MaxWidth(w.args filter {
+ case w: VarWidth => !(n equals w.name)
+ case w => true
+ })
+ case w: MinusWidth => w.arg1 match {
+ case v: VarWidth if n == v.name => v
+ case v => w
+ }
+ case w => w
+ }
+ //;println-all-debug(["After removing cycle for " n ", returning " wx])
+ }
+
+ def hasVarWidth(n: String)(w: Width): Boolean = {
+ var has = false
+ def rec(w: Width): Width = {
+ w match {
+ case w: VarWidth if w.name == n => has = true
+ case w =>
+ }
+ w map rec
+ }
+ rec(w)
+ has
+ }
+
+ //; Forward solve
+ //; Returns a solved list where each constraint undergoes:
+ //; 1) Continuous Solving (using triangular solving)
+ //; 2) Remove Cycles
+ //; 3) Move to solved if not self-recursive
+ val u = make_unique(l)
+
+ //println("======== UNIQUE CONSTRAINTS ========")
+ //for (x <- u) { println(x) }
+ //println("====================================")
+
+ val f = LinkedHashMap[String, Width]()
+ val o = ArrayBuffer[String]()
+ for ((n, e) <- u) {
+ //println("==== SOLUTIONS TABLE ====")
+ //for (x <- f) println(x)
+ //println("=========================")
+
+ val e_sub = substitute(f)(e)
+
+ //println("Solving " + n + " => " + e)
+ //println("After Substitute: " + n + " => " + e_sub)
+ //println("==== SOLUTIONS TABLE (Post Substitute) ====")
+ //for (x <- f) println(x)
+ //println("=========================")
+
+ val ex = remove_cycle(n)(e_sub)
+
+ //println("After Remove Cycle: " + n + " => " + ex)
+ if (!hasVarWidth(n)(ex)) {
+ //println("Not rec!: " + n + " => " + ex)
+ //println("Adding [" + n + "=>" + ex + "] to Solutions Table")
+ f(n) = ex
+ o += n
+ }
+ }
+
+ //println("Forward Solved Constraints")
+ //for (x <- f) println(x)
+
+ //; Backwards Solve
+ val b = LinkedHashMap[String, Width]()
+ for (i <- (o.size - 1) to 0 by -1) {
+ val n = o(i) // Should visit `o` backward
+ /*
+ println("SOLVE BACK: [" + n + " => " + f(n) + "]")
+ println("==== SOLUTIONS TABLE ====")
+ for (x <- b) println(x)
+ println("=========================")
+ */
+ val ex = simplify(b_sub(b)(f(n)))
+ /*
+ println("BACK RETURN: [" + n + " => " + ex + "]")
+ */
+ b(n) = ex
+ /*
+ println("==== SOLUTIONS TABLE (Post backsolve) ====")
+ for (x <- b) println(x)
+ println("=========================")
+ */
+ }
+ b
+ }
+
+ def run (c: Circuit): Circuit = {
+ val v = ArrayBuffer[WGeq]()
+
+ def get_constraints_t(t1: Type, t2: Type, f: Orientation): Seq[WGeq] = (t1,t2) match {
+ case (t1: UIntType, t2: UIntType) => Seq(WGeq(t1.width, t2.width))
+ case (t1: SIntType, t2: SIntType) => Seq(WGeq(t1.width, t2.width))
+ case (t1: BundleType, t2: BundleType) =>
+ (t1.fields zip t2.fields foldLeft Seq[WGeq]()){case (res, (f1, f2)) =>
+ res ++ get_constraints_t(f1.tpe, f2.tpe, times(f1.flip, f))
+ }
+ case (t1: VectorType, t2: VectorType) => get_constraints_t(t1.tpe, t2.tpe, f)
+ }
+
+ def get_constraints_e(e: Expression): Expression = {
+ e match {
+ case (e: Mux) => v ++= Seq(
+ WGeq(width_BANG(e.cond), IntWidth(1)),
+ WGeq(IntWidth(1), width_BANG(e.cond))
+ )
+ case _ =>
+ }
+ e map get_constraints_e
+ }
+
+ def get_constraints_s(s: Statement): Statement = {
+ s match {
+ case (s: Connect) =>
+ val n = get_size(s.loc.tpe)
+ val locs = create_exps(s.loc)
+ val exps = create_exps(s.expr)
+ v ++= ((locs zip exps).zipWithIndex map {case ((locx, expx), i) =>
+ get_flip(s.loc.tpe, i, Default) match {
+ case Default => WGeq(width_BANG(locx), width_BANG(expx))
+ case Flip => WGeq(width_BANG(expx), width_BANG(locx))
+ }
+ })
+ case (s: PartialConnect) =>
+ val ls = get_valid_points(s.loc.tpe, s.expr.tpe, Default, Default)
+ val locs = create_exps(s.loc)
+ val exps = create_exps(s.expr)
+ v ++= (ls map {case (x, y) =>
+ val locx = locs(x)
+ val expx = exps(y)
+ get_flip(s.loc.tpe, x, Default) match {
+ case Default => WGeq(width_BANG(locx), width_BANG(expx))
+ case Flip => WGeq(width_BANG(expx), width_BANG(locx))
+ }
+ })
+ case (s:DefRegister) => v ++= (Seq(
+ WGeq(width_BANG(s.reset), IntWidth(1)),
+ WGeq(IntWidth(1), width_BANG(s.reset))
+ ) ++ get_constraints_t(s.tpe, s.init.tpe, Default))
+ case (s:Conditionally) => v ++= Seq(
+ WGeq(width_BANG(s.pred), IntWidth(1)),
+ WGeq(IntWidth(1), width_BANG(s.pred))
+ )
+ case _ =>
+ }
+ s map get_constraints_e map get_constraints_s
+ }
+
+ c.modules foreach (_ map get_constraints_s)
+
+ //println-debug("======== ALL CONSTRAINTS ========")
+ //for x in v do : println-debug(x)
+ //println-debug("=================================")
+ val h = solve_constraints(v)
+ //println-debug("======== SOLVED CONSTRAINTS ========")
+ //for x in h do : println-debug(x)
+ //println-debug("====================================")
+
+ def evaluate(w: Width): Width = {
+ def map2(a: Option[BigInt], b: Option[BigInt], f: (BigInt,BigInt) => BigInt): Option[BigInt] =
+ for (a_num <- a; b_num <- b) yield f(a_num, b_num)
+ def reduceOptions(l: Seq[Option[BigInt]], f: (BigInt,BigInt) => BigInt): Option[BigInt] =
+ l.reduce(map2(_, _, f))
+
+ // This function shouldn't be necessary
+ // Added as protection in case a constraint accidentally uses MinWidth/MaxWidth
+ // without any actual Widths. This should be elevated to an earlier error
+ def forceNonEmpty(in: Seq[Option[BigInt]], default: Option[BigInt]): Seq[Option[BigInt]] =
+ if (in.isEmpty) Seq(default)
+ else in
+
+ def solve(w: Width): Option[BigInt] = w match {
+ case (w: VarWidth) =>
+ for{
+ v <- h.get(w.name) if !v.isInstanceOf[VarWidth]
+ result <- solve(v)
+ } yield result
+ case (w: MaxWidth) => reduceOptions(forceNonEmpty(w.args.map(solve _), Some(BigInt(0))), max)
+ case (w: MinWidth) => reduceOptions(forceNonEmpty(w.args.map(solve _), None), min)
+ case (w: PlusWidth) => map2(solve(w.arg1), solve(w.arg2), {_ + _})
+ case (w: MinusWidth) => map2(solve(w.arg1), solve(w.arg2), {_ - _})
+ case (w: ExpWidth) => map2(Some(BigInt(2)), solve(w.arg1), pow_minus_one)
+ case (w: IntWidth) => Some(w.width)
+ case (w) => println(w); error("Shouldn't be here"); None;
+ }
+
+ solve(w) match {
+ case None => w
+ case Some(s) => IntWidth(s)
+ }
+ }
+
+ def reduce_var_widths_w(w: Width): Width = {
+ //println-all-debug(["REPLACE: " w])
+ evaluate(w)
+ //println-all-debug(["WITH: " wx])
+ }
+
+ def reduce_var_widths_t(t: Type): Type = {
+ t map reduce_var_widths_t map reduce_var_widths_w
+ }
+
+ def reduce_var_widths_s(s: Statement): Statement = {
+ s map reduce_var_widths_s map reduce_var_widths_t
+ }
+
+ def reduce_var_widths_p(p: Port): Port = {
+ Port(p.info, p.name, p.direction, reduce_var_widths_t(p.tpe))
+ }
+
+ InferTypes.run(c.copy(modules = c.modules map (_
+ map reduce_var_widths_p
+ map reduce_var_widths_s)))
+ }
+}
diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala
index c143212e..b9808485 100644
--- a/src/main/scala/firrtl/passes/Passes.scala
+++ b/src/main/scala/firrtl/passes/Passes.scala
@@ -94,513 +94,6 @@ object ToWorkingIR extends Pass {
}
}
-object ResolveKinds extends Pass {
- private var mname = ""
- def name = "Resolve Kinds"
- def run (c:Circuit): Circuit = {
- def resolve_kinds (m:DefModule, c:Circuit):DefModule = {
- val kinds = LinkedHashMap[String,Kind]()
- def resolve (body:Statement) = {
- def resolve_expr (e:Expression):Expression = {
- e match {
- case e:WRef => WRef(e.name,e.tpe,kinds(e.name),e.gender)
- case e => e map (resolve_expr)
- }
- }
- def resolve_stmt (s:Statement):Statement = s map (resolve_stmt) map (resolve_expr)
- resolve_stmt(body)
- }
-
- def find (m:DefModule) = {
- def find_stmt (s:Statement):Statement = {
- s match {
- case s:DefWire => kinds(s.name) = WireKind()
- case s:DefNode => kinds(s.name) = NodeKind()
- case s:DefRegister => kinds(s.name) = RegKind()
- case s:WDefInstance => kinds(s.name) = InstanceKind()
- case s:DefMemory => kinds(s.name) = MemKind(s.readers ++ s.writers ++ s.readwriters)
- case s => false
- }
- s map (find_stmt)
- }
- m.ports.foreach { p => kinds(p.name) = PortKind() }
- m match {
- case m:Module => find_stmt(m.body)
- case m:ExtModule => false
- }
- }
-
- mname = m.name
- find(m)
- m match {
- case m:Module => {
- val bodyx = resolve(m.body)
- Module(m.info,m.name,m.ports,bodyx)
- }
- case m:ExtModule => ExtModule(m.info,m.name,m.ports)
- }
- }
- val modulesx = c.modules.map(m => resolve_kinds(m,c))
- Circuit(c.info,modulesx,c.main)
- }
-}
-
-object InferTypes extends Pass {
- private var mname = ""
- def name = "Infer Types"
- def set_type (s:Statement, t:Type) : Statement = {
- s match {
- case s:DefWire => DefWire(s.info,s.name,t)
- case s:DefRegister => DefRegister(s.info,s.name,t,s.clock,s.reset,s.init)
- case s:DefMemory => DefMemory(s.info,s.name,t,s.depth,s.writeLatency,s.readLatency,s.readers,s.writers,s.readwriters)
- case s:DefNode => s
- }
- }
- def remove_unknowns_w (w:Width)(implicit namespace: Namespace):Width = {
- w match {
- case UnknownWidth => VarWidth(namespace.newName("w"))
- case w => w
- }
- }
- def remove_unknowns (t:Type)(implicit n: Namespace): Type = mapr(remove_unknowns_w _,t)
- def run (c:Circuit): Circuit = {
- val module_types = LinkedHashMap[String,Type]()
- implicit val wnamespace = Namespace()
- def infer_types (m:DefModule) : DefModule = {
- val types = LinkedHashMap[String,Type]()
- def infer_types_e (e:Expression) : Expression = {
- e map (infer_types_e) match {
- case e:ValidIf => ValidIf(e.cond,e.value,e.value.tpe)
- case e:WRef => WRef(e.name, types(e.name),e.kind,e.gender)
- case e:WSubField => WSubField(e.exp,e.name,field_type(e.exp.tpe,e.name),e.gender)
- case e:WSubIndex => WSubIndex(e.exp,e.value,sub_type(e.exp.tpe),e.gender)
- case e:WSubAccess => WSubAccess(e.exp,e.index,sub_type(e.exp.tpe),e.gender)
- case e:DoPrim => set_primop_type(e)
- case e:Mux => Mux(e.cond,e.tval,e.fval,mux_type_and_widths(e.tval,e.fval))
- case e:UIntLiteral => e
- case e:SIntLiteral => e
- }
- }
- def infer_types_s (s:Statement) : Statement = {
- s match {
- case s:DefRegister => {
- val t = remove_unknowns(get_type(s))
- types(s.name) = t
- set_type(s,t) map (infer_types_e)
- }
- case s:DefWire => {
- val sx = s map(infer_types_e)
- val t = remove_unknowns(get_type(sx))
- types(s.name) = t
- set_type(sx,t)
- }
- case s:DefNode => {
- val sx = s map (infer_types_e)
- val t = remove_unknowns(get_type(sx))
- types(s.name) = t
- set_type(sx,t)
- }
- case s:DefMemory => {
- val t = remove_unknowns(get_type(s))
- types(s.name) = t
- val dt = remove_unknowns(s.dataType)
- set_type(s,dt)
- }
- case s:WDefInstance => {
- types(s.name) = module_types(s.module)
- WDefInstance(s.info,s.name,s.module,module_types(s.module))
- }
- case s => s map (infer_types_s) map (infer_types_e)
- }
- }
-
- mname = m.name
- m.ports.foreach(p => types(p.name) = p.tpe)
- m match {
- case m:Module => Module(m.info,m.name,m.ports,infer_types_s(m.body))
- case m:ExtModule => m
- }
- }
-
- val modulesx = c.modules.map {
- m => {
- mname = m.name
- val portsx = m.ports.map(p => Port(p.info,p.name,p.direction,remove_unknowns(p.tpe)))
- m match {
- case m:Module => Module(m.info,m.name,portsx,m.body)
- case m:ExtModule => ExtModule(m.info,m.name,portsx)
- }
- }
- }
- modulesx.foreach(m => module_types(m.name) = module_type(m))
- Circuit(c.info,modulesx.map({m => mname = m.name; infer_types(m)}) , c.main )
- }
-}
-
-object ResolveGenders extends Pass {
- private var mname = ""
- def name = "Resolve Genders"
- def run (c:Circuit): Circuit = {
- def resolve_e (g:Gender)(e:Expression) : Expression = {
- e match {
- case e:WRef => WRef(e.name,e.tpe,e.kind,g)
- case e:WSubField => {
- val expx =
- field_flip(e.exp.tpe,e.name) match {
- case Default => resolve_e(g)(e.exp)
- case Flip => resolve_e(swap(g))(e.exp)
- }
- WSubField(expx,e.name,e.tpe,g)
- }
- case e:WSubIndex => {
- val expx = resolve_e(g)(e.exp)
- WSubIndex(expx,e.value,e.tpe,g)
- }
- case e:WSubAccess => {
- val expx = resolve_e(g)(e.exp)
- val indexx = resolve_e(MALE)(e.index)
- WSubAccess(expx,indexx,e.tpe,g)
- }
- case e => e map (resolve_e(g))
- }
- }
-
- def resolve_s (s:Statement) : Statement = {
- s match {
- case s:IsInvalid => {
- val expx = resolve_e(FEMALE)(s.expr)
- IsInvalid(s.info,expx)
- }
- case s:Connect => {
- val locx = resolve_e(FEMALE)(s.loc)
- val expx = resolve_e(MALE)(s.expr)
- Connect(s.info,locx,expx)
- }
- case s:PartialConnect => {
- val locx = resolve_e(FEMALE)(s.loc)
- val expx = resolve_e(MALE)(s.expr)
- PartialConnect(s.info,locx,expx)
- }
- case s => s map (resolve_e(MALE)) map (resolve_s)
- }
- }
- val modulesx = c.modules.map {
- m => {
- mname = m.name
- m match {
- case m:Module => {
- val bodyx = resolve_s(m.body)
- Module(m.info,m.name,m.ports,bodyx)
- }
- case m:ExtModule => m
- }
- }
- }
- Circuit(c.info,modulesx,c.main)
- }
-}
-
-object InferWidths extends Pass {
- def name = "Infer Widths"
- var mname = ""
- def solve_constraints (l:Seq[WGeq]) : LinkedHashMap[String,Width] = {
- def unique (ls:Seq[Width]) : Seq[Width] = ls.map(w => new WrappedWidth(w)).distinct.map(_.w)
- def make_unique (ls:Seq[WGeq]) : LinkedHashMap[String,Width] = {
- val h = LinkedHashMap[String,Width]()
- for (g <- ls) {
- (g.loc) match {
- case (w:VarWidth) => {
- val n = w.name
- if (h.contains(n)) h(n) = MaxWidth(Seq(g.exp,h(n))) else h(n) = g.exp
- }
- case (w) => w
- }
- }
- h
- }
- def simplify (w:Width) : Width = {
- (w map (simplify)) match {
- case (w:MinWidth) => {
- val v = ArrayBuffer[Width]()
- for (wx <- w.args) {
- (wx) match {
- case (wx:MinWidth) => for (x <- wx.args) { v += x }
- case (wx) => v += wx } }
- MinWidth(unique(v)) }
- case (w:MaxWidth) => {
- val v = ArrayBuffer[Width]()
- for (wx <- w.args) {
- (wx) match {
- case (wx:MaxWidth) => for (x <- wx.args) { v += x }
- case (wx) => v += wx } }
- MaxWidth(unique(v)) }
- case (w:PlusWidth) => {
- (w.arg1,w.arg2) match {
- case (w1:IntWidth,w2:IntWidth) => IntWidth(w1.width + w2.width)
- case (w1,w2) => w }}
- case (w:MinusWidth) => {
- (w.arg1,w.arg2) match {
- case (w1:IntWidth,w2:IntWidth) => IntWidth(w1.width - w2.width)
- case (w1,w2) => w }}
- case (w:ExpWidth) => {
- (w.arg1) match {
- case (w1:IntWidth) => IntWidth(BigInt((scala.math.pow(2,w1.width.toDouble) - 1).toLong))
- case (w1) => w }}
- case (w) => w } }
- def substitute (h:LinkedHashMap[String,Width])(w:Width) : Width = {
- //;println-all-debug(["Substituting for [" w "]"])
- val wx = simplify(w)
- //;println-all-debug(["After Simplify: [" wx "]"])
- (simplify(w) map (substitute(h))) match {
- case (w:VarWidth) => {
- //;("matched println-debugvarwidth!")
- if (h.contains(w.name)) {
- //;println-debug("Contained!")
- //;println-all-debug(["Width: " w])
- //;println-all-debug(["Accessed: " h[name(w)]])
- val t = simplify(substitute(h)(h(w.name)))
- //;val t = h[name(w)]
- //;println-all-debug(["Width after sub: " t])
- h(w.name) = t
- t
- } else w
- }
- case (w) => w
- //;println-all-debug(["not varwidth!" w])
- }
- }
- def b_sub (h:LinkedHashMap[String,Width])(w:Width) : Width = {
- (w map (b_sub(h))) match {
- case (w:VarWidth) => if (h.contains(w.name)) h(w.name) else w
- case (w) => w
- }
- }
- def remove_cycle (n:String)(w:Width) : Width = {
- //;println-all-debug(["Removing cycle for " n " inside " w])
- val wx = (w map (remove_cycle(n))) match {
- case (w:MaxWidth) => MaxWidth(w.args.filter{ w => {
- w match {
- case (w:VarWidth) => !(n equals w.name)
- case (w) => true
- }}})
- case (w:MinusWidth) => {
- w.arg1 match {
- case (v:VarWidth) => if (n == v.name) v else w
- case (v) => w }}
- case (w) => w
- }
- //;println-all-debug(["After removing cycle for " n ", returning " wx])
- wx
- }
- def self_rec (n:String,w:Width) : Boolean = {
- var has = false
- def look (w:Width) : Width = {
- (w map (look)) match {
- case (w:VarWidth) => if (w.name == n) has = true
- case (w) => w }
- w }
- look(w)
- has }
-
- //; Forward solve
- //; Returns a solved list where each constraint undergoes:
- //; 1) Continuous Solving (using triangular solving)
- //; 2) Remove Cycles
- //; 3) Move to solved if not self-recursive
- val u = make_unique(l)
-
- //println("======== UNIQUE CONSTRAINTS ========")
- //for (x <- u) { println(x) }
- //println("====================================")
-
-
- val f = LinkedHashMap[String,Width]()
- val o = ArrayBuffer[String]()
- for (x <- u) {
- //println("==== SOLUTIONS TABLE ====")
- //for (x <- f) println(x)
- //println("=========================")
-
- val (n, e) = (x._1, x._2)
- val e_sub = substitute(f)(e)
-
- //println("Solving " + n + " => " + e)
- //println("After Substitute: " + n + " => " + e_sub)
- //println("==== SOLUTIONS TABLE (Post Substitute) ====")
- //for (x <- f) println(x)
- //println("=========================")
-
- val ex = remove_cycle(n)(e_sub)
-
- //println("After Remove Cycle: " + n + " => " + ex)
- if (!self_rec(n,ex)) {
- //println("Not rec!: " + n + " => " + ex)
- //println("Adding [" + n + "=>" + ex + "] to Solutions Table")
- o += n
- f(n) = ex
- }
- }
-
- //println("Forward Solved Constraints")
- //for (x <- f) println(x)
-
- //; Backwards Solve
- val b = LinkedHashMap[String,Width]()
- for (i <- 0 until o.size) {
- val n = o(o.size - 1 - i)
- /*
- println("SOLVE BACK: [" + n + " => " + f(n) + "]")
- println("==== SOLUTIONS TABLE ====")
- for (x <- b) println(x)
- println("=========================")
- */
- val ex = simplify(b_sub(b)(f(n)))
- /*
- println("BACK RETURN: [" + n + " => " + ex + "]")
- */
- b(n) = ex
- /*
- println("==== SOLUTIONS TABLE (Post backsolve) ====")
- for (x <- b) println(x)
- println("=========================")
- */
- }
- b
- }
-
- def width_BANG (t:Type) : Width = {
- (t) match {
- case (t:UIntType) => t.width
- case (t:SIntType) => t.width
- case ClockType => IntWidth(1)
- case (t) => error("No width!"); IntWidth(-1) } }
- def width_BANG (e:Expression) : Width = width_BANG(e.tpe)
-
- def reduce_var_widths(c: Circuit, h: LinkedHashMap[String,Width]): Circuit = {
- def evaluate(w: Width): Width = {
- def map2(a: Option[BigInt], b: Option[BigInt], f: (BigInt,BigInt) => BigInt): Option[BigInt] =
- for (a_num <- a; b_num <- b) yield f(a_num, b_num)
- def reduceOptions(l: Seq[Option[BigInt]], f: (BigInt,BigInt) => BigInt): Option[BigInt] =
- l.reduce(map2(_, _, f))
-
- // This function shouldn't be necessary
- // Added as protection in case a constraint accidentally uses MinWidth/MaxWidth
- // without any actual Widths. This should be elevated to an earlier error
- def forceNonEmpty(in: Seq[Option[BigInt]], default: Option[BigInt]): Seq[Option[BigInt]] =
- if(in.isEmpty) Seq(default)
- else in
-
-
- def solve(w: Width): Option[BigInt] = w match {
- case (w: VarWidth) =>
- for{
- v <- h.get(w.name) if !v.isInstanceOf[VarWidth]
- result <- solve(v)
- } yield result
- case (w: MaxWidth) => reduceOptions(forceNonEmpty(w.args.map(solve _), Some(BigInt(0))), max)
- case (w: MinWidth) => reduceOptions(forceNonEmpty(w.args.map(solve _), None), min)
- case (w: PlusWidth) => map2(solve(w.arg1), solve(w.arg2), {_ + _})
- case (w: MinusWidth) => map2(solve(w.arg1), solve(w.arg2), {_ - _})
- case (w: ExpWidth) => map2(Some(BigInt(2)), solve(w.arg1), pow_minus_one)
- case (w: IntWidth) => Some(w.width)
- case (w) => println(w); error("Shouldn't be here"); None;
- }
-
- val s = solve(w)
- (s) match {
- case Some(s) => IntWidth(s)
- case (s) => w
- }
- }
-
- def reduce_var_widths_w (w:Width) : Width = {
- //println-all-debug(["REPLACE: " w])
- val wx = evaluate(w)
- //println-all-debug(["WITH: " wx])
- wx
- }
- def reduce_var_widths_s (s: Statement): Statement = {
- def onType(t: Type): Type = t map onType map reduce_var_widths_w
- s map reduce_var_widths_s map onType
- }
-
- val modulesx = c.modules.map{ m => {
- val portsx = m.ports.map{ p => {
- Port(p.info,p.name,p.direction,mapr(reduce_var_widths_w _,p.tpe)) }}
- (m) match {
- case (m:ExtModule) => ExtModule(m.info,m.name,portsx)
- case (m:Module) =>
- mname = m.name
- Module(m.info,m.name,portsx,m.body map reduce_var_widths_s _) }}}
- InferTypes.run(Circuit(c.info,modulesx,c.main))
- }
-
- def run (c:Circuit): Circuit = {
- val v = ArrayBuffer[WGeq]()
- def constrain (w1:Width,w2:Width) : Unit = v += WGeq(w1,w2)
- def get_constraints_t (t1:Type,t2:Type,f:Orientation) : Unit = {
- (t1,t2) match {
- case (t1:UIntType,t2:UIntType) => constrain(t1.width,t2.width)
- case (t1:SIntType,t2:SIntType) => constrain(t1.width,t2.width)
- case (t1:BundleType,t2:BundleType) => {
- (t1.fields,t2.fields).zipped.foreach{ (f1,f2) => {
- get_constraints_t(f1.tpe,f2.tpe,times(f1.flip,f)) }}}
- case (t1:VectorType,t2:VectorType) => get_constraints_t(t1.tpe,t2.tpe,f) }}
- def get_constraints_e (e:Expression) : Expression = {
- (e map (get_constraints_e)) match {
- case (e:Mux) => {
- constrain(width_BANG(e.cond),IntWidth(1))
- constrain(IntWidth(1),width_BANG(e.cond))
- e }
- case (e) => e }}
- def get_constraints (s:Statement) : Statement = {
- (s map (get_constraints_e)) match {
- case (s:Connect) => {
- val n = get_size(s.loc.tpe)
- val ce_loc = create_exps(s.loc)
- val ce_exp = create_exps(s.expr)
- for (i <- 0 until n) {
- val locx = ce_loc(i)
- val expx = ce_exp(i)
- get_flip(s.loc.tpe,i,Default) match {
- case Default => constrain(width_BANG(locx),width_BANG(expx))
- case Flip => constrain(width_BANG(expx),width_BANG(locx)) }}
- s }
- case (s:PartialConnect) => {
- val ls = get_valid_points(s.loc.tpe,s.expr.tpe,Default,Default)
- for (x <- ls) {
- val locx = create_exps(s.loc)(x._1)
- val expx = create_exps(s.expr)(x._2)
- get_flip(s.loc.tpe,x._1,Default) match {
- case Default => constrain(width_BANG(locx),width_BANG(expx))
- case Flip => constrain(width_BANG(expx),width_BANG(locx)) }}
- s }
- case (s:DefRegister) => {
- constrain(width_BANG(s.reset),IntWidth(1))
- constrain(IntWidth(1),width_BANG(s.reset))
- get_constraints_t(s.tpe,s.init.tpe,Default)
- s }
- case (s:Conditionally) => {
- v += WGeq(width_BANG(s.pred),IntWidth(1))
- v += WGeq(IntWidth(1),width_BANG(s.pred))
- s map (get_constraints) }
- case (s) => s map (get_constraints) }}
-
- for (m <- c.modules) {
- (m) match {
- case (m:Module) => mname = m.name; get_constraints(m.body)
- case (m) => false }}
- //println-debug("======== ALL CONSTRAINTS ========")
- //for x in v do : println-debug(x)
- //println-debug("=================================")
- val h = solve_constraints(v)
- //println-debug("======== SOLVED CONSTRAINTS ========")
- //for x in h do : println-debug(x)
- //println-debug("====================================")
- reduce_var_widths(Circuit(c.info,c.modules,c.main),h)
- }
-}
-
object PullMuxes extends Pass {
def name = "Pull Muxes"
def run(c: Circuit): Circuit = {
@@ -837,437 +330,3 @@ object VerilogRename extends Pass {
Circuit(c.info,modulesx,c.main)
}
}
-
-object CInferTypes extends Pass {
- def name = "CInfer Types"
- var mname = ""
- def set_type (s:Statement, t:Type) : Statement = {
- (s) match {
- case (s:DefWire) => DefWire(s.info,s.name,t)
- case (s:DefRegister) => DefRegister(s.info,s.name,t,s.clock,s.reset,s.init)
- case (s:CDefMemory) => CDefMemory(s.info,s.name,t,s.size,s.seq)
- case (s:CDefMPort) => CDefMPort(s.info,s.name,t,s.mem,s.exps,s.direction)
- case (s:DefNode) => s
- }
- }
-
- def to_field (p:Port) : Field = {
- if (p.direction == Output) Field(p.name,Default,p.tpe)
- else if (p.direction == Input) Field(p.name,Flip,p.tpe)
- else error("Shouldn't be here"); Field(p.name,Flip,p.tpe)
- }
- def module_type (m:DefModule) : Type = BundleType(m.ports.map(p => to_field(p)))
- def field_type (v:Type,s:String) : Type = {
- (v) match {
- case (v:BundleType) => {
- val ft = v.fields.find(p => p.name == s)
- if (ft != None) ft.get.tpe
- else UnknownType
- }
- case (v) => UnknownType
- }
- }
- def sub_type (v:Type) : Type =
- (v) match {
- case (v:VectorType) => v.tpe
- case (v) => UnknownType
- }
- def run (c:Circuit) : Circuit = {
- val module_types = LinkedHashMap[String,Type]()
- def infer_types (m:DefModule) : DefModule = {
- val types = LinkedHashMap[String,Type]()
- def infer_types_e (e:Expression) : Expression = {
- e map infer_types_e match {
- case (e:Reference) => Reference(e.name, types.getOrElse(e.name,UnknownType))
- case (e:SubField) => SubField(e.expr,e.name,field_type(e.expr.tpe,e.name))
- case (e:SubIndex) => SubIndex(e.expr,e.value,sub_type(e.expr.tpe))
- case (e:SubAccess) => SubAccess(e.expr,e.index,sub_type(e.expr.tpe))
- case (e:DoPrim) => set_primop_type(e)
- case (e:Mux) => Mux(e.cond,e.tval,e.fval,mux_type(e.tval,e.tval))
- case (e:ValidIf) => ValidIf(e.cond,e.value,e.value.tpe)
- case (_:UIntLiteral | _:SIntLiteral) => e
- }
- }
- def infer_types_s (s:Statement) : Statement = {
- s match {
- case (s:DefRegister) => {
- types(s.name) = s.tpe
- s map infer_types_e
- s
- }
- case (s:DefWire) => {
- types(s.name) = s.tpe
- s
- }
- case (s:DefNode) => {
- val sx = s map infer_types_e
- val t = get_type(sx)
- types(s.name) = t
- sx
- }
- case (s:DefMemory) => {
- types(s.name) = get_type(s)
- s
- }
- case (s:CDefMPort) => {
- val t = types.getOrElse(s.mem,UnknownType)
- types(s.name) = t
- CDefMPort(s.info,s.name,t,s.mem,s.exps,s.direction)
- }
- case (s:CDefMemory) => {
- types(s.name) = s.tpe
- s
- }
- case (s:DefInstance) => {
- types(s.name) = module_types.getOrElse(s.module,UnknownType)
- s
- }
- case (s) => s map infer_types_s map infer_types_e
- }
- }
- for (p <- m.ports) {
- types(p.name) = p.tpe
- }
- m match {
- case (m:Module) => Module(m.info,m.name,m.ports,infer_types_s(m.body))
- case (m:ExtModule) => m
- }
- }
-
- //; MAIN
- for (m <- c.modules) {
- module_types(m.name) = module_type(m)
- }
- val modulesx = c.modules.map(m => infer_types(m))
- Circuit(c.info, modulesx, c.main)
- }
-}
-
-object CInferMDir extends Pass {
- def name = "CInfer MDir"
- var mname = ""
- def run (c:Circuit) : Circuit = {
- def infer_mdir (m:DefModule) : DefModule = {
- val mports = LinkedHashMap[String,MPortDir]()
- def infer_mdir_e (dir:MPortDir)(e:Expression) : Expression = {
- (e map (infer_mdir_e(dir))) match {
- case (e:Reference) => {
- if (mports.contains(e.name)) {
- val new_mport_dir = {
- (mports(e.name),dir) match {
- case (MInfer,MInfer) => error("Shouldn't be here")
- case (MInfer,MWrite) => MWrite
- case (MInfer,MRead) => MRead
- case (MInfer,MReadWrite) => MReadWrite
- case (MWrite,MInfer) => error("Shouldn't be here")
- case (MWrite,MWrite) => MWrite
- case (MWrite,MRead) => MReadWrite
- case (MWrite,MReadWrite) => MReadWrite
- case (MRead,MInfer) => error("Shouldn't be here")
- case (MRead,MWrite) => MReadWrite
- case (MRead,MRead) => MRead
- case (MRead,MReadWrite) => MReadWrite
- case (MReadWrite,MInfer) => error("Shouldn't be here")
- case (MReadWrite,MWrite) => MReadWrite
- case (MReadWrite,MRead) => MReadWrite
- case (MReadWrite,MReadWrite) => MReadWrite
- }
- }
- mports(e.name) = new_mport_dir
- }
- e
- }
- case (e) => e
- }
- }
- def infer_mdir_s (s:Statement) : Statement = {
- (s) match {
- case (s:CDefMPort) => {
- mports(s.name) = s.direction
- s map (infer_mdir_e(MRead))
- }
- case (s:Connect) => {
- infer_mdir_e(MRead)(s.expr)
- infer_mdir_e(MWrite)(s.loc)
- s
- }
- case (s:PartialConnect) => {
- infer_mdir_e(MRead)(s.expr)
- infer_mdir_e(MWrite)(s.loc)
- s
- }
- case (s) => s map (infer_mdir_s) map (infer_mdir_e(MRead))
- }
- }
- def set_mdir_s (s:Statement) : Statement = {
- (s) match {
- case (s:CDefMPort) =>
- CDefMPort(s.info,s.name,s.tpe,s.mem,s.exps,mports(s.name))
- case (s) => s map (set_mdir_s)
- }
- }
- (m) match {
- case (m:Module) => {
- infer_mdir_s(m.body)
- Module(m.info,m.name,m.ports,set_mdir_s(m.body))
- }
- case (m:ExtModule) => m
- }
- }
-
- //; MAIN
- Circuit(c.info, c.modules.map(m => infer_mdir(m)), c.main)
- }
-}
-
-case class MPort( val name : String, val clk : Expression)
-case class MPorts( val readers : ArrayBuffer[MPort], val writers : ArrayBuffer[MPort], val readwriters : ArrayBuffer[MPort])
-case class DataRef( val exp : Expression, val male : String, val female : String, val mask : String, val rdwrite : Boolean)
-
-object RemoveCHIRRTL extends Pass {
- def name = "Remove CHIRRTL"
- var mname = ""
- def create_exps (e:Expression) : Seq[Expression] = e match {
- case (e:Mux) =>
- val e1s = create_exps(e.tval)
- val e2s = create_exps(e.fval)
- (e1s,e2s).zipped map ((e1,e2) => Mux(e.cond,e1,e2,mux_type(e1,e2)))
- case (e:ValidIf) =>
- create_exps(e.value) map (e1 => ValidIf(e.cond,e1,e1.tpe))
- case (e) => (e.tpe) match {
- case (_:GroundType) => Seq(e)
- case (t:BundleType) => (t.fields foldLeft Seq[Expression]())((exps, f) =>
- exps ++ create_exps(SubField(e,f.name,f.tpe)))
- case (t:VectorType) => ((0 until t.size) foldLeft Seq[Expression]())((exps, i) =>
- exps ++ create_exps(SubIndex(e,i,t.tpe)))
- case UnknownType => Seq(e)
- }
- }
- def run (c:Circuit) : Circuit = {
- def remove_chirrtl_m (m:Module) : Module = {
- val hash = LinkedHashMap[String,MPorts]()
- val repl = LinkedHashMap[String,DataRef]()
- val raddrs = HashMap[String, Expression]()
- val ut = UnknownType
- val mport_types = LinkedHashMap[String,Type]()
- val smems = HashSet[String]()
- def EMPs () : MPorts = MPorts(ArrayBuffer[MPort](),ArrayBuffer[MPort](),ArrayBuffer[MPort]())
- def collect_smems_and_mports (s:Statement) : Statement = {
- (s) match {
- case (s:CDefMemory) if s.seq =>
- smems += s.name
- s
- case (s:CDefMPort) => {
- val mports = hash.getOrElse(s.mem,EMPs())
- s.direction match {
- case MRead => mports.readers += MPort(s.name,s.exps(1))
- case MWrite => mports.writers += MPort(s.name,s.exps(1))
- case MReadWrite => mports.readwriters += MPort(s.name,s.exps(1))
- }
- hash(s.mem) = mports
- s
- }
- case (s) => s map (collect_smems_and_mports)
- }
- }
- def collect_refs (s:Statement) : Statement = {
- (s) match {
- case (s:CDefMemory) => {
- mport_types(s.name) = s.tpe
- val stmts = ArrayBuffer[Statement]()
- val taddr = UIntType(IntWidth(scala.math.max(1,ceil_log2(s.size))))
- val tdata = s.tpe
- def set_poison (vec:Seq[MPort],addr:String) : Unit = {
- for (r <- vec ) {
- stmts += IsInvalid(s.info,SubField(SubField(Reference(s.name,ut),r.name,ut),addr,taddr))
- stmts += IsInvalid(s.info,SubField(SubField(Reference(s.name,ut),r.name,ut),"clk",taddr))
- }
- }
- def set_enable (vec:Seq[MPort],en:String) : Unit = {
- for (r <- vec ) {
- stmts += Connect(s.info,SubField(SubField(Reference(s.name,ut),r.name,ut),en,taddr),zero)
- }}
- def set_wmode (vec:Seq[MPort],wmode:String) : Unit = {
- for (r <- vec) {
- stmts += Connect(s.info,SubField(SubField(Reference(s.name,ut),r.name,ut),wmode,taddr),zero)
- }}
- def set_write (vec:Seq[MPort],data:String,mask:String) : Unit = {
- val tmask = create_mask(s.tpe)
- for (r <- vec ) {
- stmts += IsInvalid(s.info,SubField(SubField(Reference(s.name,ut),r.name,ut),data,tdata))
- for (x <- create_exps(SubField(SubField(Reference(s.name,ut),r.name,ut),mask,tmask)) ) {
- stmts += Connect(s.info,x,zero)
- }}}
- val rds = (hash.getOrElse(s.name,EMPs())).readers
- set_poison(rds,"addr")
- set_enable(rds,"en")
- val wrs = (hash.getOrElse(s.name,EMPs())).writers
- set_poison(wrs,"addr")
- set_enable(wrs,"en")
- set_write(wrs,"data","mask")
- val rws = (hash.getOrElse(s.name,EMPs())).readwriters
- set_poison(rws,"addr")
- set_wmode(rws,"wmode")
- set_enable(rws,"en")
- set_write(rws,"wdata","wmask")
- val read_l = if (s.seq) 1 else 0
- val mem = DefMemory(s.info,s.name,s.tpe,s.size,1,read_l,rds.map(_.name),wrs.map(_.name),rws.map(_.name))
- Block(Seq(mem,Block(stmts)))
- }
- case (s:CDefMPort) => {
- mport_types(s.name) = mport_types(s.mem)
- val addrs = ArrayBuffer[String]()
- val clks = ArrayBuffer[String]()
- val ens = ArrayBuffer[String]()
- val masks = ArrayBuffer[String]()
- s.direction match {
- case MReadWrite => {
- repl(s.name) = DataRef(SubField(Reference(s.mem,ut),s.name,ut),"rdata","wdata","wmask",true)
- addrs += "addr"
- clks += "clk"
- ens += "en"
- masks += "wmask"
- }
- case MWrite => {
- repl(s.name) = DataRef(SubField(Reference(s.mem,ut),s.name,ut),"data","data","mask",false)
- addrs += "addr"
- clks += "clk"
- ens += "en"
- masks += "mask"
- }
- case MRead => {
- repl(s.name) = DataRef(SubField(Reference(s.mem,ut),s.name,ut),"data","data","blah",false)
- addrs += "addr"
- clks += "clk"
- s.exps(0) match {
- case e: Reference if smems(s.mem) =>
- raddrs(e.name) = SubField(SubField(Reference(s.mem,ut),s.name,ut),"en",ut)
- case _ => ens += "en"
- }
- }
- }
- val stmts = ArrayBuffer[Statement]()
- for (x <- addrs ) {
- stmts += Connect(s.info,SubField(SubField(Reference(s.mem,ut),s.name,ut),x,ut),s.exps(0))
- }
- for (x <- clks ) {
- stmts += Connect(s.info,SubField(SubField(Reference(s.mem,ut),s.name,ut),x,ut),s.exps(1))
- }
- for (x <- ens ) {
- stmts += Connect(s.info,SubField(SubField(Reference(s.mem,ut),s.name,ut),x,ut),one)
- }
- Block(stmts)
- }
- case (s) => s map (collect_refs)
- }
- }
- def remove_chirrtl_s (s:Statement) : Statement = {
- var has_write_mport = false
- var has_read_mport: Option[Expression] = None
- var has_readwrite_mport: Option[Expression] = None
- def remove_chirrtl_e (g:Gender)(e:Expression) : Expression = {
- (e) match {
- case (e:Reference) if repl contains e.name =>
- val vt = repl(e.name)
- g match {
- case MALE => SubField(vt.exp,vt.male,e.tpe)
- case FEMALE => {
- has_write_mport = true
- if (vt.rdwrite)
- has_readwrite_mport = Some(SubField(vt.exp,"wmode",UIntType(IntWidth(1))))
- SubField(vt.exp,vt.female,e.tpe)
- }
- }
- case (e:Reference) if g == FEMALE && (raddrs contains e.name) =>
- has_read_mport = Some(raddrs(e.name))
- e
- case (e:Reference) => e
- case (e:SubAccess) => SubAccess(remove_chirrtl_e(g)(e.expr),remove_chirrtl_e(MALE)(e.index),e.tpe)
- case (e) => e map (remove_chirrtl_e(g))
- }
- }
- def get_mask (e:Expression) : Expression = {
- (e map (get_mask)) match {
- case (e:Reference) => {
- if (repl.contains(e.name)) {
- val vt = repl(e.name)
- val t = create_mask(e.tpe)
- SubField(vt.exp,vt.mask,t)
- } else e
- }
- case (e) => e
- }
- }
- (s) match {
- case (s:DefNode) => {
- val stmts = ArrayBuffer[Statement]()
- val valuex = remove_chirrtl_e(MALE)(s.value)
- stmts += DefNode(s.info,s.name,valuex)
- has_read_mport match {
- case None =>
- case Some(en) => stmts += Connect(s.info,en,one)
- }
- if (stmts.size > 1) Block(stmts)
- else stmts(0)
- }
- case (s:Connect) => {
- val stmts = ArrayBuffer[Statement]()
- val rocx = remove_chirrtl_e(MALE)(s.expr)
- val locx = remove_chirrtl_e(FEMALE)(s.loc)
- stmts += Connect(s.info,locx,rocx)
- has_read_mport match {
- case None =>
- case Some(en) => stmts += Connect(s.info,en,one)
- }
- if (has_write_mport) {
- val e = get_mask(s.loc)
- for (x <- create_exps(e) ) {
- stmts += Connect(s.info,x,one)
- }
- has_readwrite_mport match {
- case None =>
- case Some(wmode) => stmts += Connect(s.info,wmode,one)
- }
- }
- if (stmts.size > 1) Block(stmts)
- else stmts(0)
- }
- case (s:PartialConnect) => {
- val stmts = ArrayBuffer[Statement]()
- val locx = remove_chirrtl_e(FEMALE)(s.loc)
- val rocx = remove_chirrtl_e(MALE)(s.expr)
- stmts += PartialConnect(s.info,locx,rocx)
- has_read_mport match {
- case None =>
- case Some(en) => stmts += Connect(s.info,en,one)
- }
- if (has_write_mport) {
- val ls = get_valid_points(s.loc.tpe,s.expr.tpe,Default,Default)
- val locs = create_exps(get_mask(s.loc))
- for (x <- ls ) {
- val locx = locs(x._1)
- stmts += Connect(s.info,locx,one)
- }
- has_readwrite_mport match {
- case None =>
- case Some(wmode) => stmts += Connect(s.info,wmode,one)
- }
- }
- if (stmts.size > 1) Block(stmts)
- else stmts(0)
- }
- case (s) => s map (remove_chirrtl_s) map (remove_chirrtl_e(MALE))
- }
- }
- collect_smems_and_mports(m.body)
- val sx = collect_refs(m.body)
- Module(m.info,m.name, m.ports, remove_chirrtl_s(sx))
- }
- val modulesx = c.modules.map{ m => {
- (m) match {
- case (m:Module) => remove_chirrtl_m(m)
- case (m:ExtModule) => m
- }}}
- Circuit(c.info,modulesx, c.main)
- }
-}
diff --git a/src/main/scala/firrtl/passes/RemoveAccesses.scala b/src/main/scala/firrtl/passes/RemoveAccesses.scala
index 880d6b1c..08f08eac 100644
--- a/src/main/scala/firrtl/passes/RemoveAccesses.scala
+++ b/src/main/scala/firrtl/passes/RemoveAccesses.scala
@@ -1,11 +1,11 @@
package firrtl.passes
+import firrtl.{WRef, WSubAccess, WSubIndex, WSubField, Namespace}
+import firrtl.PrimOps.{And, Eq}
import firrtl.ir._
-import firrtl.{WRef, WSubAccess, WSubIndex, WSubField}
import firrtl.Mappers._
import firrtl.Utils._
import firrtl.WrappedExpression._
-import firrtl.Namespace
import scala.collection.mutable
@@ -13,6 +13,13 @@ import scala.collection.mutable
*/
object RemoveAccesses extends Pass {
def name = "Remove Accesses"
+
+ private def AND(e1: Expression, e2: Expression) =
+ DoPrim(And, Seq(e1, e2), Nil, UIntType(IntWidth(1)))
+
+ private def EQV(e1: Expression, e2: Expression): Expression =
+ DoPrim(Eq, Seq(e1, e2), Nil, e1.tpe)
+
/** Container for a base expression and its corresponding guard
*/
private case class Location(base: Expression, guard: Expression)
@@ -53,13 +60,13 @@ object RemoveAccesses extends Pass {
/** Returns true if e contains a [[firrtl.WSubAccess]]
*/
private def hasAccess(e: Expression): Boolean = {
- var ret: Boolean = false
- def rec_has_access(e: Expression): Expression = {
- e match {
- case e : WSubAccess => ret = true
- case e =>
- }
- e map rec_has_access
+ var ret: Boolean = false
+ def rec_has_access(e: Expression): Expression = {
+ e match {
+ case e : WSubAccess => ret = true
+ case e =>
+ }
+ e map rec_has_access
}
rec_has_access(e)
ret
@@ -150,10 +157,9 @@ object RemoveAccesses extends Pass {
Module(m.info, m.name, m.ports, squashEmpty(onStmt(m.body)))
}
- val newModules = c.modules.map {
+ c copy (modules = (c.modules map {
case m: ExtModule => m
case m: Module => remove_m(m)
- }
- Circuit(c.info, newModules, c.main)
+ }))
}
}
diff --git a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
new file mode 100644
index 00000000..2bae92a7
--- /dev/null
+++ b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
@@ -0,0 +1,256 @@
+/*
+Copyright (c) 2014 - 2016 The Regents of the University of
+California (Regents). All Rights Reserved. Redistribution and use in
+source and binary forms, with or without modification, are permitted
+provided that the following conditions are met:
+ * Redistributions of source code must retain the above
+ copyright notice, this list of conditions and the following
+ two paragraphs of disclaimer.
+ * Redistributions in binary form must reproduce the above
+ copyright notice, this list of conditions and the following
+ two paragraphs of disclaimer in the documentation and/or other materials
+ provided with the distribution.
+ * Neither the name of the Regents nor the names of its contributors
+ may be used to endorse or promote products derived from this
+ software without specific prior written permission.
+IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT,
+SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS,
+ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF
+REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF
+ANY, PROVIDED HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION
+TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR
+MODIFICATIONS.
+*/
+
+package firrtl.passes
+
+// Datastructures
+import scala.collection.mutable.ArrayBuffer
+
+import firrtl._
+import firrtl.ir._
+import firrtl.Utils._
+import firrtl.Mappers._
+
+case class MPort(name: String, clk: Expression)
+case class MPorts(readers: ArrayBuffer[MPort], writers: ArrayBuffer[MPort], readwriters: ArrayBuffer[MPort])
+case class DataRef(exp: Expression, male: String, female: String, mask: String, rdwrite: Boolean)
+
+object RemoveCHIRRTL extends Pass {
+ def name = "Remove CHIRRTL"
+
+ val ut = UnknownType
+ type MPortMap = collection.mutable.LinkedHashMap[String, MPorts]
+ type SeqMemSet = collection.mutable.HashSet[String]
+ type MPortTypeMap = collection.mutable.LinkedHashMap[String, Type]
+ type DataRefMap = collection.mutable.LinkedHashMap[String, DataRef]
+ type AddrMap = collection.mutable.HashMap[String, Expression]
+
+ def create_exps(e: Expression): Seq[Expression] = e match {
+ case (e: Mux) =>
+ val e1s = create_exps(e.tval)
+ val e2s = create_exps(e.fval)
+ (e1s zip e2s) map { case (e1, e2) => Mux(e.cond, e1, e2, mux_type(e1, e2)) }
+ case (e: ValidIf) =>
+ create_exps(e.value) map (e1 => ValidIf(e.cond, e1, e1.tpe))
+ case (e) => (e.tpe) match {
+ case (_: GroundType) => Seq(e)
+ case (t: BundleType) => (t.fields foldLeft Seq[Expression]())((exps, f) =>
+ exps ++ create_exps(SubField(e, f.name, f.tpe)))
+ case (t: VectorType) => ((0 until t.size) foldLeft Seq[Expression]())((exps, i) =>
+ exps ++ create_exps(SubIndex(e, i, t.tpe)))
+ case UnknownType => Seq(e)
+ }
+ }
+
+ private def EMPs: MPorts = MPorts(ArrayBuffer[MPort](), ArrayBuffer[MPort](), ArrayBuffer[MPort]())
+
+ def collect_smems_and_mports(mports: MPortMap, smems: SeqMemSet)(s: Statement): Statement = {
+ s match {
+ case (s:CDefMemory) if s.seq => smems += s.name
+ case (s:CDefMPort) =>
+ val p = mports getOrElse (s.mem, EMPs)
+ s.direction match {
+ case MRead => p.readers += MPort(s.name,s.exps(1))
+ case MWrite => p.writers += MPort(s.name,s.exps(1))
+ case MReadWrite => p.readwriters += MPort(s.name,s.exps(1))
+ }
+ mports(s.mem) = p
+ case s =>
+ }
+ s map collect_smems_and_mports(mports, smems)
+ }
+
+ def collect_refs(mports: MPortMap, smems: SeqMemSet, types: MPortTypeMap,
+ refs: DataRefMap, raddrs: AddrMap)(s: Statement): Statement = s match {
+ case (s: CDefMemory) =>
+ types(s.name) = s.tpe
+ val taddr = UIntType(IntWidth(math.max(1, ceil_log2(s.size))))
+ val tdata = s.tpe
+ def set_poison(vec: Seq[MPort], addr: String) = vec flatMap (r => Seq(
+ IsInvalid(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), addr, taddr)),
+ IsInvalid(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), "clk", taddr))
+ ))
+ def set_enable(vec: Seq[MPort], en: String) = vec map (r =>
+ Connect(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), en, taddr), zero)
+ )
+ def set_wmode (vec: Seq[MPort], wmode: String) = vec map (r =>
+ Connect(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), wmode, taddr), zero)
+ )
+ def set_write (vec: Seq[MPort], data: String, mask: String) = vec flatMap {r =>
+ val tmask = create_mask(s.tpe)
+ IsInvalid(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), data, tdata)) +:
+ (create_exps(SubField(SubField(Reference(s.name, ut), r.name, ut), mask, tmask))
+ map (Connect(s.info, _, zero))
+ )
+ }
+ val rds = (mports getOrElse (s.name, EMPs)).readers
+ val wrs = (mports getOrElse (s.name, EMPs)).writers
+ val rws = (mports getOrElse (s.name, EMPs)).readwriters
+ val stmts = set_poison(rds, "addr") ++
+ set_enable(rds, "en") ++
+ set_poison(wrs, "addr") ++
+ set_enable(wrs, "en") ++
+ set_write(wrs, "data", "mask") ++
+ set_poison(rws, "addr") ++
+ set_wmode(rws, "wmode") ++
+ set_enable(rws, "en") ++
+ set_write(rws, "wdata", "wmask")
+ val mem = DefMemory(s.info, s.name, s.tpe, s.size, 1, if (s.seq) 1 else 0,
+ rds map (_.name), wrs map (_.name), rws map (_.name))
+ Block(mem +: stmts)
+ case (s: CDefMPort) => {
+ types(s.name) = types(s.mem)
+ val addrs = ArrayBuffer[String]()
+ val clks = ArrayBuffer[String]()
+ val ens = ArrayBuffer[String]()
+ s.direction match {
+ case MReadWrite =>
+ refs(s.name) = DataRef(SubField(Reference(s.mem, ut), s.name, ut), "rdata", "wdata", "wmask", true)
+ addrs += "addr"
+ clks += "clk"
+ ens += "en"
+ case MWrite =>
+ refs(s.name) = DataRef(SubField(Reference(s.mem, ut), s.name, ut), "data", "data", "mask", false)
+ addrs += "addr"
+ clks += "clk"
+ ens += "en"
+ case MRead =>
+ refs(s.name) = DataRef(SubField(Reference(s.mem, ut), s.name, ut), "data", "data", "blah", false)
+ addrs += "addr"
+ clks += "clk"
+ s.exps.head match {
+ case e: Reference if smems(s.mem) =>
+ raddrs(e.name) = SubField(SubField(Reference(s.mem, ut), s.name, ut), "en", ut)
+ case _ => ens += "en"
+ }
+ }
+ Block(
+ (addrs map (x => Connect(s.info, SubField(SubField(Reference(s.mem, ut), s.name, ut), x, ut), s.exps(0)))) ++
+ (clks map (x => Connect(s.info, SubField(SubField(Reference(s.mem, ut), s.name, ut), x, ut), s.exps(1)))) ++
+ (ens map (x => Connect(s.info,SubField(SubField(Reference(s.mem,ut), s.name, ut), x, ut), one))))
+ }
+ case (s) => s map collect_refs(mports, smems, types, refs, raddrs)
+ }
+
+ def get_mask(refs: DataRefMap)(e: Expression): Expression =
+ e map get_mask(refs) match {
+ case e: Reference => refs get e.name match {
+ case None => e
+ case Some(p) => SubField(p.exp, p.mask, create_mask(e.tpe))
+ }
+ case e => e
+ }
+
+ def remove_chirrtl_s(refs: DataRefMap, raddrs: AddrMap)(s: Statement): Statement = {
+ var has_write_mport = false
+ var has_readwrite_mport: Option[Expression] = None
+ var has_read_mport: Option[Expression] = None
+ def remove_chirrtl_e(g: Gender)(e: Expression): Expression = e match {
+ case Reference(name, tpe) => refs get name match {
+ case Some(p) => g match {
+ case FEMALE =>
+ has_write_mport = true
+ if (p.rdwrite) has_readwrite_mport = Some(SubField(p.exp, "wmode", UIntType(IntWidth(1))))
+ SubField(p.exp, p.female, tpe)
+ case MALE =>
+ SubField(p.exp, p.male, tpe)
+ }
+ case None => g match {
+ case FEMALE => raddrs get name match {
+ case Some(en) => has_read_mport = Some(en) ; e
+ case None => e
+ }
+ case MALE => e
+ }
+ }
+ case SubAccess(expr, index, tpe) => SubAccess(
+ remove_chirrtl_e(g)(expr), remove_chirrtl_e(MALE)(index), tpe)
+ case e => e map remove_chirrtl_e(g)
+ }
+ (s) match {
+ case DefNode(info, name, value) =>
+ val valuex = remove_chirrtl_e(MALE)(value)
+ val sx = DefNode(info, name, valuex)
+ has_read_mport match {
+ case None => sx
+ case Some(en) => Block(Seq(sx, Connect(info, en, one)))
+ }
+ case Connect(info, loc, expr) =>
+ val rocx = remove_chirrtl_e(MALE)(expr)
+ val locx = remove_chirrtl_e(FEMALE)(loc)
+ val sx = Connect(info, locx, rocx)
+ val stmts = ArrayBuffer[Statement]()
+ has_read_mport match {
+ case None =>
+ case Some(en) => stmts += Connect(info, en, one)
+ }
+ if (has_write_mport) {
+ val locs = create_exps(get_mask(refs)(loc))
+ stmts ++= (locs map (x => Connect(info, x, one)))
+ has_readwrite_mport match {
+ case None =>
+ case Some(wmode) => stmts += Connect(info, wmode, one)
+ }
+ }
+ if (stmts.isEmpty) sx else Block(sx +: stmts)
+ case PartialConnect(info, loc, expr) =>
+ val locx = remove_chirrtl_e(FEMALE)(loc)
+ val rocx = remove_chirrtl_e(MALE)(expr)
+ val sx = PartialConnect(info, locx, rocx)
+ val stmts = ArrayBuffer[Statement]()
+ has_read_mport match {
+ case None =>
+ case Some(en) => stmts += Connect(info, en, one)
+ }
+ if (has_write_mport) {
+ val ls = get_valid_points(loc.tpe, expr.tpe, Default, Default)
+ val locs = create_exps(get_mask(refs)(loc))
+ stmts ++= (ls map { case (x, _) => Connect(info, locs(x), one) })
+ has_readwrite_mport match {
+ case None =>
+ case Some(wmode) => stmts += Connect(info, wmode, one)
+ }
+ }
+ if (stmts.isEmpty) sx else Block(sx +: stmts)
+ case s => s map remove_chirrtl_s(refs, raddrs) map remove_chirrtl_e(MALE)
+ }
+ }
+
+ def remove_chirrtl_m(m: DefModule): DefModule = {
+ val mports = new MPortMap
+ val smems = new SeqMemSet
+ val types = new MPortTypeMap
+ val refs = new DataRefMap
+ val raddrs = new AddrMap
+ (m map collect_smems_and_mports(mports, smems)
+ map collect_refs(mports, smems, types, refs, raddrs)
+ map remove_chirrtl_s(refs, raddrs))
+ }
+
+ def run(c: Circuit): Circuit =
+ c copy (modules = (c.modules map remove_chirrtl_m))
+}
diff --git a/src/main/scala/firrtl/passes/Resolves.scala b/src/main/scala/firrtl/passes/Resolves.scala
new file mode 100644
index 00000000..3100f0c3
--- /dev/null
+++ b/src/main/scala/firrtl/passes/Resolves.scala
@@ -0,0 +1,163 @@
+/*
+Copyright (c) 2014 - 2016 The Regents of the University of
+California (Regents). All Rights Reserved. Redistribution and use in
+source and binary forms, with or without modification, are permitted
+provided that the following conditions are met:
+ * Redistributions of source code must retain the above
+ copyright notice, this list of conditions and the following
+ two paragraphs of disclaimer.
+ * Redistributions in binary form must reproduce the above
+ copyright notice, this list of conditions and the following
+ two paragraphs of disclaimer in the documentation and/or other materials
+ provided with the distribution.
+ * Neither the name of the Regents nor the names of its contributors
+ may be used to endorse or promote products derived from this
+ software without specific prior written permission.
+IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT,
+SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS,
+ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF
+REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF
+ANY, PROVIDED HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION
+TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR
+MODIFICATIONS.
+*/
+
+package firrtl.passes
+
+import firrtl._
+import firrtl.ir._
+import firrtl.Mappers._
+
+object ResolveKinds extends Pass {
+ def name = "Resolve Kinds"
+ type KindMap = collection.mutable.LinkedHashMap[String, Kind]
+
+ def find_port(kinds: KindMap)(p: Port): Port = {
+ kinds(p.name) = PortKind() ; p
+ }
+
+ def find_stmt(kinds: KindMap)(s: Statement):Statement = {
+ s match {
+ case s: DefWire => kinds(s.name) = WireKind()
+ case s: DefNode => kinds(s.name) = NodeKind()
+ case s: DefRegister => kinds(s.name) = RegKind()
+ case s: WDefInstance => kinds(s.name) = InstanceKind()
+ case s: DefMemory => kinds(s.name) = MemKind(s.readers ++ s.writers ++ s.readwriters)
+ case s =>
+ }
+ s map find_stmt(kinds)
+ }
+
+ def resolve_expr(kinds: KindMap)(e: Expression): Expression = e match {
+ case e: WRef => e copy (kind = kinds(e.name))
+ case e => e map resolve_expr(kinds)
+ }
+
+ def resolve_stmt(kinds: KindMap)(s: Statement): Statement =
+ s map resolve_stmt(kinds) map resolve_expr(kinds)
+
+ def resolve_kinds(m: DefModule): DefModule = {
+ val kinds = new KindMap
+ (m map find_port(kinds)
+ map find_stmt(kinds)
+ map resolve_stmt(kinds))
+ }
+
+ def run(c: Circuit): Circuit =
+ c copy (modules = (c.modules map resolve_kinds))
+}
+
+object ResolveGenders extends Pass {
+ def name = "Resolve Genders"
+ def resolve_e(g: Gender)(e: Expression): Expression = e match {
+ case e: WRef => e copy (gender = g)
+ case WSubField(exp, name, tpe, _) => WSubField(
+ Utils.field_flip(exp.tpe, name) match {
+ case Default => resolve_e(g)(exp)
+ case Flip => resolve_e(Utils.swap(g))(exp)
+ }, name, tpe, g)
+ case WSubIndex(exp, value, tpe, _) =>
+ WSubIndex(resolve_e(g)(exp), value, tpe, g)
+ case WSubAccess(exp, index, tpe, _) =>
+ WSubAccess(resolve_e(g)(exp), resolve_e(MALE)(index), tpe, g)
+ case e => e map resolve_e(g)
+ }
+
+ def resolve_s(s: Statement): Statement = s match {
+ case IsInvalid(info, expr) =>
+ IsInvalid(info, resolve_e(FEMALE)(expr))
+ case Connect(info, loc, expr) =>
+ Connect(info, resolve_e(FEMALE)(loc), resolve_e(MALE)(expr))
+ case PartialConnect(info, loc, expr) =>
+ PartialConnect(info, resolve_e(FEMALE)(loc), resolve_e(MALE)(expr))
+ case s => s map resolve_e(MALE) map resolve_s
+ }
+
+ def resolve_gender(m: DefModule): DefModule = m map resolve_s
+
+ def run(c: Circuit): Circuit =
+ c copy (modules = (c.modules map resolve_gender))
+}
+
+object CInferMDir extends Pass {
+ def name = "CInfer MDir"
+ type MPortDirMap = collection.mutable.LinkedHashMap[String, MPortDir]
+
+ def infer_mdir_e(mports: MPortDirMap, dir: MPortDir)(e: Expression): Expression = {
+ (e map infer_mdir_e(mports, dir)) match {
+ case e: Reference => mports get e.name match {
+ case Some(p) => mports(e.name) = (p, dir) match {
+ case (MInfer, MInfer) => Utils.error("Shouldn't be here")
+ case (MInfer, MWrite) => MWrite
+ case (MInfer, MRead) => MRead
+ case (MInfer, MReadWrite) => MReadWrite
+ case (MWrite, MInfer) => Utils.error("Shouldn't be here")
+ case (MWrite, MWrite) => MWrite
+ case (MWrite, MRead) => MReadWrite
+ case (MWrite, MReadWrite) => MReadWrite
+ case (MRead, MInfer) => Utils.error("Shouldn't be here")
+ case (MRead, MWrite) => MReadWrite
+ case (MRead, MRead) => MRead
+ case (MRead, MReadWrite) => MReadWrite
+ case (MReadWrite, MInfer) => Utils.error("Shouldn't be here")
+ case (MReadWrite, MWrite) => MReadWrite
+ case (MReadWrite, MRead) => MReadWrite
+ case (MReadWrite, MReadWrite) => MReadWrite
+ } ; e
+ case None => e
+ }
+ case _ => e
+ }
+ }
+
+ def infer_mdir_s(mports: MPortDirMap)(s: Statement): Statement = s match {
+ case s: CDefMPort =>
+ mports(s.name) = s.direction
+ s map infer_mdir_e(mports, MRead)
+ case s: Connect =>
+ infer_mdir_e(mports, MRead)(s.expr)
+ infer_mdir_e(mports, MWrite)(s.loc)
+ s
+ case s: PartialConnect =>
+ infer_mdir_e(mports, MRead)(s.expr)
+ infer_mdir_e(mports, MWrite)(s.loc)
+ s
+ case s => s map infer_mdir_s(mports) map infer_mdir_e(mports, MRead)
+ }
+
+ def set_mdir_s(mports: MPortDirMap)(s: Statement): Statement = s match {
+ case s: CDefMPort => s copy (direction = mports(s.name))
+ case s => s map set_mdir_s(mports)
+ }
+
+ def infer_mdir(m: DefModule): DefModule = {
+ val mports = new MPortDirMap
+ m map infer_mdir_s(mports) map set_mdir_s(mports)
+ }
+
+ def run(c: Circuit): Circuit =
+ c copy (modules = (c.modules map infer_mdir))
+}
diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala
index 54ef6003..118e547c 100644
--- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala
+++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala
@@ -5,7 +5,8 @@ import firrtl.passes._
import Annotations._
class ReplSeqMemSpec extends SimpleTransformSpec {
-
+ val passSeq = Seq(
+ ConstProp, CommonSubexpressionElimination, DeadCodeElimination, RemoveEmpty)
def transforms (writer: java.io.Writer) = Seq(
new Chisel3ToHighFirrtl(),
new IRToWorkingIR(),
@@ -14,6 +15,8 @@ class ReplSeqMemSpec extends SimpleTransformSpec {
new passes.InferReadWrite(TransID(-1)),
new passes.ReplSeqMem(TransID(-2)),
new MiddleFirrtlToLowFirrtl(),
+ (new Transform with SimpleRun {
+ def execute(c: ir.Circuit, a: AnnotationMap) = run(c, passSeq) }),
new EmitFirrtl(writer)
)
@@ -97,27 +100,24 @@ circuit sram6t :
input io_wdata : UInt<32>
input io_raddr : UInt<8>
output io_rdata : UInt<32>
-
+
inst mem of mem
node T_0 = eq(io_wen, UInt<1>("h0"))
node T_1 = and(io_en, T_0)
wire T_2 : UInt<8>
node GEN_0 = validif(T_1, io_raddr)
- node GEN_1 = mux(T_1, UInt<1>("h1"), UInt<1>("h0"))
node T_4 = and(io_en, io_wen)
+ node GEN_4 = validif(T_4, io_wdata)
node GEN_2 = validif(T_4, io_waddr)
- node GEN_3 = validif(T_4, clk)
- node GEN_4 = mux(T_4, UInt<1>("h1"), UInt<1>("h0"))
- node GEN_5 = validif(T_4, io_wdata)
- node GEN_6 = mux(T_4, UInt<1>("h1"), UInt<1>("h0"))
+ node GEN_5 = validif(T_4, clk)
io_rdata <= mem.R0_data
mem.R0_addr <= bits(T_2, 6, 0)
mem.R0_clk <= clk
- mem.R0_en <= GEN_1
+ mem.R0_en <= T_1
mem.W0_addr <= bits(GEN_2, 6, 0)
- mem.W0_clk <= GEN_3
- mem.W0_en <= GEN_4
- mem.W0_data <= GEN_5
+ mem.W0_clk <= GEN_5
+ mem.W0_en <= T_4
+ mem.W0_data <= GEN_4
T_2 <= GEN_0
extmodule mem_ext :
@@ -140,16 +140,16 @@ circuit sram6t :
input W0_en : UInt<1>
input W0_clk : Clock
input W0_data : UInt<32>
-
+
inst mem_ext of mem_ext
mem_ext.R0_addr <= R0_addr
mem_ext.R0_en <= R0_en
mem_ext.R0_clk <= R0_clk
- R0_data <= bits(mem_ext.R0_data, 31, 0)
+ R0_data <= mem_ext.R0_data
mem_ext.W0_addr <= W0_addr
mem_ext.W0_en <= W0_en
mem_ext.W0_clk <= W0_clk
- mem_ext.W0_data <= W0_data
+ mem_ext.W0_data <= W0_data
""".stripMargin
val checkConf = """name mem_ext depth 128 width 32 ports write,read """
@@ -170,4 +170,4 @@ circuit sram6t :
// readwrite vs. no readwrite
// redundant memories (multiple instances of the same type of memory)
// mask + no mask
-// conf \ No newline at end of file
+// conf
diff --git a/src/test/scala/firrtlTests/UnitTests.scala b/src/test/scala/firrtlTests/UnitTests.scala
index 2d1bbdc1..7feb4a00 100644
--- a/src/test/scala/firrtlTests/UnitTests.scala
+++ b/src/test/scala/firrtlTests/UnitTests.scala
@@ -203,7 +203,8 @@ class UnitTests extends FirrtlFlatSpec {
InferWidths,
PullMuxes,
ExpandConnects,
- RemoveAccesses
+ RemoveAccesses,
+ ConstProp
)
val input =
"""circuit AssignViaDeref :
@@ -221,7 +222,7 @@ class UnitTests extends FirrtlFlatSpec {
val check = Seq(
"""wire GEN_0 : { a : UInt<8>}""",
"""GEN_0.a <= table[0].a""",
- """when eq(UInt<1>("h1"), UInt<1>("h1")) :""",
+ """when UInt<1>("h1") :""",
"""GEN_0.a <= table[1].a""",
"""wire GEN_1 : UInt<8>""",
"""when eq(UInt<1>("h0"), GEN_0.a) :""",