aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl/passes')
-rw-r--r--src/main/scala/firrtl/passes/CheckWidths.scala41
-rw-r--r--src/main/scala/firrtl/passes/Checks.scala86
-rw-r--r--src/main/scala/firrtl/passes/memlib/MemIR.scala5
3 files changed, 68 insertions, 64 deletions
diff --git a/src/main/scala/firrtl/passes/CheckWidths.scala b/src/main/scala/firrtl/passes/CheckWidths.scala
index 4a72b18c..061c6b16 100644
--- a/src/main/scala/firrtl/passes/CheckWidths.scala
+++ b/src/main/scala/firrtl/passes/CheckWidths.scala
@@ -5,9 +5,9 @@ package firrtl.passes
import firrtl._
import firrtl.ir._
import firrtl.PrimOps._
-import firrtl.Mappers._
+import firrtl.traversals.Foreachers._
import firrtl.Utils._
-import firrtl.annotations.{Target, TargetToken, CircuitTarget, ModuleTarget}
+import firrtl.annotations.{CircuitTarget, ModuleTarget, Target, TargetToken}
object CheckWidths extends Pass {
/** The maximum allowed width for any circuit element */
@@ -36,7 +36,7 @@ object CheckWidths extends Pass {
def run(c: Circuit): Circuit = {
val errors = new Errors()
- def check_width_w(info: Info, target: Target)(w: Width): Width = {
+ def check_width_w(info: Info, target: Target)(w: Width): Unit = {
w match {
case IntWidth(width) if width >= MaxWidth =>
errors.append(new WidthTooBig(info, target.serialize, width))
@@ -46,7 +46,6 @@ object CheckWidths extends Pass {
case _ =>
errors append new UninferredWidth(info, target.prettyPrint(" "))
}
- w
}
def hasWidth(tpe: Type): Boolean = tpe match {
@@ -55,18 +54,18 @@ object CheckWidths extends Pass {
case _ => throwInternalError(s"hasWidth - $tpe")
}
- def check_width_t(info: Info, target: Target)(t: Type): Type = {
- val tx = t match {
- case tt: BundleType => BundleType(tt.fields.map(check_width_f(info, target)))
- case tt => tt map check_width_t(info, target)
+ def check_width_t(info: Info, target: Target)(t: Type): Unit = {
+ t match {
+ case tt: BundleType => tt.fields.foreach(check_width_f(info, target))
+ case tt => tt foreach check_width_t(info, target)
}
- tx map check_width_w(info, target)
+ t foreach check_width_w(info, target)
}
- def check_width_f(info: Info, target: Target)(f: Field): Field = f
- .copy(tpe = check_width_t(info, target.modify(tokens = target.tokens :+ TargetToken.Field(f.name)))(f.tpe))
+ def check_width_f(info: Info, target: Target)(f: Field): Unit =
+ check_width_t(info, target.modify(tokens = target.tokens :+ TargetToken.Field(f.name)))(f.tpe)
- def check_width_e(info: Info, target: Target)(e: Expression): Expression = {
+ def check_width_e(info: Info, target: Target)(e: Expression): Unit = {
e match {
case e: UIntLiteral => e.width match {
case w: IntWidth if math.max(1, e.value.bitLength) > w.width =>
@@ -89,34 +88,36 @@ object CheckWidths extends Pass {
case _ =>
}
//e map check_width_t(info, mname) map check_width_e(info, mname)
- e map check_width_e(info, target)
+ e foreach check_width_e(info, target)
}
- def check_width_s(minfo: Info, target: ModuleTarget)(s: Statement): Statement = {
+ def check_width_s(minfo: Info, target: ModuleTarget)(s: Statement): Unit = {
val info = get_info(s) match { case NoInfo => minfo case x => x }
val subRef = s match { case sx: HasName => target.ref(sx.name) case _ => target }
- s map check_width_e(info, target) map check_width_s(info, target) map check_width_t(info, subRef) match {
+ s foreach check_width_e(info, target)
+ s foreach check_width_s(info, target)
+ s foreach check_width_t(info, subRef)
+ s match {
case Attach(infox, exprs) =>
exprs.tail.foreach ( e =>
if (bitWidth(e.tpe) != bitWidth(exprs.head.tpe))
errors.append(new AttachWidthsNotEqual(infox, target.serialize, e.serialize, exprs.head.serialize))
)
- s
case sx: DefRegister =>
sx.reset.tpe match {
case UIntType(IntWidth(w)) if w == 1 =>
case _ => errors.append(new CheckTypes.IllegalResetType(info, target.serialize, sx.name))
}
- s
- case _ => s
+ case _ =>
}
}
- def check_width_p(minfo: Info, target: ModuleTarget)(p: Port): Port = p.copy(tpe = check_width_t(p.info, target)(p.tpe))
+ def check_width_p(minfo: Info, target: ModuleTarget)(p: Port): Unit = check_width_t(p.info, target)(p.tpe)
def check_width_m(circuit: CircuitTarget)(m: DefModule) {
- m map check_width_p(m.info, circuit.module(m.name)) map check_width_s(m.info, circuit.module(m.name))
+ m foreach check_width_p(m.info, circuit.module(m.name))
+ m foreach check_width_s(m.info, circuit.module(m.name))
}
c.modules foreach check_width_m(CircuitTarget(c.main))
diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala
index 4c7458bf..bc9d3a1c 100644
--- a/src/main/scala/firrtl/passes/Checks.scala
+++ b/src/main/scala/firrtl/passes/Checks.scala
@@ -7,6 +7,7 @@ import firrtl.ir._
import firrtl.PrimOps._
import firrtl.Utils._
import firrtl.Mappers._
+import firrtl.traversals.Foreachers._
import firrtl.WrappedType._
object CheckHighForm extends Pass {
@@ -116,32 +117,29 @@ object CheckHighForm extends Pass {
case _ => // Do Nothing
}
- def checkHighFormW(info: Info, mname: String)(w: Width): Width = {
+ def checkHighFormW(info: Info, mname: String)(w: Width): Unit = {
w match {
- case wx: IntWidth if wx.width < 0 =>
- errors.append(new NegWidthException(info, mname))
+ case wx: IntWidth if wx.width < 0 => errors.append(new NegWidthException(info, mname))
case wx => // Do nothing
}
- w
}
- def checkHighFormT(info: Info, mname: String)(t: Type): Type =
- t map checkHighFormT(info, mname) match {
- case tx: VectorType if tx.size < 0 =>
- errors.append(new NegVecSizeException(info, mname))
- t
- case _ => t map checkHighFormW(info, mname)
+ def checkHighFormT(info: Info, mname: String)(t: Type): Unit = {
+ t foreach checkHighFormT(info, mname)
+ t match {
+ case tx: VectorType if tx.size < 0 => errors.append(new NegVecSizeException(info, mname))
+ case _ => t foreach checkHighFormW(info, mname)
}
+ }
- def validSubexp(info: Info, mname: String)(e: Expression): Expression = {
+ def validSubexp(info: Info, mname: String)(e: Expression): Unit = {
e match {
case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess | _: Mux | _: ValidIf => // No error
case _ => errors.append(new InvalidAccessException(info, mname))
}
- e
}
- def checkHighFormE(info: Info, mname: String, names: NameSet)(e: Expression): Expression = {
+ def checkHighFormE(info: Info, mname: String, names: NameSet)(e: Expression): Unit = {
e match {
case ex: WRef if !names(ex.name) =>
errors.append(new UndeclaredReferenceException(info, mname, ex.name))
@@ -150,23 +148,23 @@ object CheckHighForm extends Pass {
case ex: DoPrim => checkHighFormPrimop(info, mname, ex)
case _: WRef | _: UIntLiteral | _: Mux | _: ValidIf =>
case ex: WSubAccess => validSubexp(info, mname)(ex.expr)
- case ex => ex map validSubexp(info, mname)
+ case ex => ex foreach validSubexp(info, mname)
}
- (e map checkHighFormW(info, mname)
- map checkHighFormT(info, mname)
- map checkHighFormE(info, mname, names))
+ e foreach checkHighFormW(info, mname)
+ e foreach checkHighFormT(info, mname)
+ e foreach checkHighFormE(info, mname, names)
}
- def checkName(info: Info, mname: String, names: NameSet)(name: String): String = {
+ def checkName(info: Info, mname: String, names: NameSet)(name: String): Unit = {
if (names(name))
errors.append(new NotUniqueException(info, mname, name))
names += name
- name
}
- def checkHighFormS(minfo: Info, mname: String, names: NameSet)(s: Statement): Statement = {
+ def checkHighFormS(minfo: Info, mname: String, names: NameSet)(s: Statement): Unit = {
val info = get_info(s) match {case NoInfo => minfo case x => x}
- s map checkName(info, mname, names) match {
+ s foreach checkName(info, mname, names)
+ s match {
case sx: DefMemory =>
if (hasFlip(sx.dataType))
errors.append(new MemWithFlipException(info, mname, sx.name))
@@ -184,24 +182,23 @@ object CheckHighForm extends Pass {
case sx: Print => checkFstring(info, mname, sx.string, sx.args.length)
case sx => // Do Nothing
}
- (s map checkHighFormT(info, mname)
- map checkHighFormE(info, mname, names)
- map checkHighFormS(minfo, mname, names))
+ s foreach checkHighFormT(info, mname)
+ s foreach checkHighFormE(info, mname, names)
+ s foreach checkHighFormS(minfo, mname, names)
}
- def checkHighFormP(mname: String, names: NameSet)(p: Port): Port = {
+ def checkHighFormP(mname: String, names: NameSet)(p: Port): Unit = {
if (names(p.name))
errors.append(new NotUniqueException(NoInfo, mname, p.name))
names += p.name
- (p.tpe map checkHighFormT(p.info, mname)
- map checkHighFormW(p.info, mname))
- p
+ p.tpe foreach checkHighFormT(p.info, mname)
+ p.tpe foreach checkHighFormW(p.info, mname)
}
def checkHighFormM(m: DefModule) {
val names = new NameSet
- (m map checkHighFormP(m.name, names)
- map checkHighFormS(m.info, m.name, names))
+ m foreach checkHighFormP(m.name, names)
+ m foreach checkHighFormS(m.info, m.name, names)
}
c.modules foreach checkHighFormM
@@ -333,7 +330,7 @@ object CheckTypes extends Pass {
}
}
- def check_types_e(info:Info, mname: String)(e: Expression): Expression = {
+ def check_types_e(info:Info, mname: String)(e: Expression): Unit = {
e match {
case (e: WSubField) => e.expr.tpe match {
case (t: BundleType) => t.fields find (_.name == e.name) match {
@@ -377,7 +374,7 @@ object CheckTypes extends Pass {
}
case _ =>
}
- e map check_types_e(info, mname)
+ e foreach check_types_e(info, mname)
}
def bulk_equals(t1: Type, t2: Type, flip1: Orientation, flip2: Orientation): Boolean = {
@@ -404,7 +401,7 @@ object CheckTypes extends Pass {
}
}
- def check_types_s(minfo: Info, mname: String)(s: Statement): Statement = {
+ def check_types_s(minfo: Info, mname: String)(s: Statement): Unit = {
val info = get_info(s) match { case NoInfo => minfo case x => x }
s match {
case sx: Connect if wt(sx.loc.tpe) != wt(sx.expr.tpe) =>
@@ -457,10 +454,11 @@ object CheckTypes extends Pass {
}
case _ =>
}
- s map check_types_e(info, mname) map check_types_s(info, mname)
+ s foreach check_types_e(info, mname)
+ s foreach check_types_s(info, mname)
}
- c.modules foreach (m => m map check_types_s(m.info, m.name))
+ c.modules foreach (m => m foreach check_types_s(m.info, m.name))
errors.trigger()
c
}
@@ -504,7 +502,7 @@ object CheckGenders extends Pass {
flip_rec(t, Default)
}
- def check_gender(info:Info, mname: String, genders: GenderMap, desired: Gender)(e:Expression): Expression = {
+ def check_gender(info:Info, mname: String, genders: GenderMap, desired: Gender)(e:Expression): Unit = {
val gender = get_gender(e,genders)
(gender, desired) match {
case (MALE, FEMALE) =>
@@ -516,19 +514,18 @@ object CheckGenders extends Pass {
}
case _ =>
}
- e
}
- def check_genders_e (info:Info, mname: String, genders: GenderMap)(e:Expression): Expression = {
+ def check_genders_e (info:Info, mname: String, genders: GenderMap)(e:Expression): Unit = {
e match {
- case e: Mux => e map check_gender(info, mname, genders, MALE)
- case e: DoPrim => e.args map check_gender(info, mname, genders, MALE)
+ case e: Mux => e foreach check_gender(info, mname, genders, MALE)
+ case e: DoPrim => e.args foreach check_gender(info, mname, genders, MALE)
case _ =>
}
- e map check_genders_e(info, mname, genders)
+ e foreach check_genders_e(info, mname, genders)
}
- def check_genders_s(minfo: Info, mname: String, genders: GenderMap)(s: Statement): Statement = {
+ def check_genders_s(minfo: Info, mname: String, genders: GenderMap)(s: Statement): Unit = {
val info = get_info(s) match { case NoInfo => minfo case x => x }
s match {
case (s: DefWire) => genders(s.name) = BIGENDER
@@ -555,13 +552,14 @@ object CheckGenders extends Pass {
check_gender(info, mname, genders, MALE)(s.clk)
case _ =>
}
- s map check_genders_e(info, mname, genders) map check_genders_s(minfo, mname, genders)
+ s foreach check_genders_e(info, mname, genders)
+ s foreach check_genders_s(minfo, mname, genders)
}
for (m <- c.modules) {
val genders = new GenderMap
genders ++= (m.ports map (p => p.name -> to_gender(p.direction)))
- m map check_genders_s(m.info, m.name, genders)
+ m foreach check_genders_s(m.info, m.name, genders)
}
errors.trigger()
c
diff --git a/src/main/scala/firrtl/passes/memlib/MemIR.scala b/src/main/scala/firrtl/passes/memlib/MemIR.scala
index 5fb837c1..a7ef9d43 100644
--- a/src/main/scala/firrtl/passes/memlib/MemIR.scala
+++ b/src/main/scala/firrtl/passes/memlib/MemIR.scala
@@ -31,4 +31,9 @@ case class DefAnnotatedMemory(
writeLatency, readLatency, readers, writers,
readwriters, readUnderWrite)
def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
+ def foreachStmt(f: Statement => Unit): Unit = Unit
+ def foreachExpr(f: Expression => Unit): Unit = Unit
+ def foreachType(f: Type => Unit): Unit = f(dataType)
+ def foreachString(f: String => Unit): Unit = f(name)
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}