diff options
Diffstat (limited to 'src/main/scala/firrtl/passes')
| -rw-r--r-- | src/main/scala/firrtl/passes/CheckWidths.scala | 41 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Checks.scala | 86 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/memlib/MemIR.scala | 5 |
3 files changed, 68 insertions, 64 deletions
diff --git a/src/main/scala/firrtl/passes/CheckWidths.scala b/src/main/scala/firrtl/passes/CheckWidths.scala index 4a72b18c..061c6b16 100644 --- a/src/main/scala/firrtl/passes/CheckWidths.scala +++ b/src/main/scala/firrtl/passes/CheckWidths.scala @@ -5,9 +5,9 @@ package firrtl.passes import firrtl._ import firrtl.ir._ import firrtl.PrimOps._ -import firrtl.Mappers._ +import firrtl.traversals.Foreachers._ import firrtl.Utils._ -import firrtl.annotations.{Target, TargetToken, CircuitTarget, ModuleTarget} +import firrtl.annotations.{CircuitTarget, ModuleTarget, Target, TargetToken} object CheckWidths extends Pass { /** The maximum allowed width for any circuit element */ @@ -36,7 +36,7 @@ object CheckWidths extends Pass { def run(c: Circuit): Circuit = { val errors = new Errors() - def check_width_w(info: Info, target: Target)(w: Width): Width = { + def check_width_w(info: Info, target: Target)(w: Width): Unit = { w match { case IntWidth(width) if width >= MaxWidth => errors.append(new WidthTooBig(info, target.serialize, width)) @@ -46,7 +46,6 @@ object CheckWidths extends Pass { case _ => errors append new UninferredWidth(info, target.prettyPrint(" ")) } - w } def hasWidth(tpe: Type): Boolean = tpe match { @@ -55,18 +54,18 @@ object CheckWidths extends Pass { case _ => throwInternalError(s"hasWidth - $tpe") } - def check_width_t(info: Info, target: Target)(t: Type): Type = { - val tx = t match { - case tt: BundleType => BundleType(tt.fields.map(check_width_f(info, target))) - case tt => tt map check_width_t(info, target) + def check_width_t(info: Info, target: Target)(t: Type): Unit = { + t match { + case tt: BundleType => tt.fields.foreach(check_width_f(info, target)) + case tt => tt foreach check_width_t(info, target) } - tx map check_width_w(info, target) + t foreach check_width_w(info, target) } - def check_width_f(info: Info, target: Target)(f: Field): Field = f - .copy(tpe = check_width_t(info, target.modify(tokens = target.tokens :+ TargetToken.Field(f.name)))(f.tpe)) + def check_width_f(info: Info, target: Target)(f: Field): Unit = + check_width_t(info, target.modify(tokens = target.tokens :+ TargetToken.Field(f.name)))(f.tpe) - def check_width_e(info: Info, target: Target)(e: Expression): Expression = { + def check_width_e(info: Info, target: Target)(e: Expression): Unit = { e match { case e: UIntLiteral => e.width match { case w: IntWidth if math.max(1, e.value.bitLength) > w.width => @@ -89,34 +88,36 @@ object CheckWidths extends Pass { case _ => } //e map check_width_t(info, mname) map check_width_e(info, mname) - e map check_width_e(info, target) + e foreach check_width_e(info, target) } - def check_width_s(minfo: Info, target: ModuleTarget)(s: Statement): Statement = { + def check_width_s(minfo: Info, target: ModuleTarget)(s: Statement): Unit = { val info = get_info(s) match { case NoInfo => minfo case x => x } val subRef = s match { case sx: HasName => target.ref(sx.name) case _ => target } - s map check_width_e(info, target) map check_width_s(info, target) map check_width_t(info, subRef) match { + s foreach check_width_e(info, target) + s foreach check_width_s(info, target) + s foreach check_width_t(info, subRef) + s match { case Attach(infox, exprs) => exprs.tail.foreach ( e => if (bitWidth(e.tpe) != bitWidth(exprs.head.tpe)) errors.append(new AttachWidthsNotEqual(infox, target.serialize, e.serialize, exprs.head.serialize)) ) - s case sx: DefRegister => sx.reset.tpe match { case UIntType(IntWidth(w)) if w == 1 => case _ => errors.append(new CheckTypes.IllegalResetType(info, target.serialize, sx.name)) } - s - case _ => s + case _ => } } - def check_width_p(minfo: Info, target: ModuleTarget)(p: Port): Port = p.copy(tpe = check_width_t(p.info, target)(p.tpe)) + def check_width_p(minfo: Info, target: ModuleTarget)(p: Port): Unit = check_width_t(p.info, target)(p.tpe) def check_width_m(circuit: CircuitTarget)(m: DefModule) { - m map check_width_p(m.info, circuit.module(m.name)) map check_width_s(m.info, circuit.module(m.name)) + m foreach check_width_p(m.info, circuit.module(m.name)) + m foreach check_width_s(m.info, circuit.module(m.name)) } c.modules foreach check_width_m(CircuitTarget(c.main)) diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala index 4c7458bf..bc9d3a1c 100644 --- a/src/main/scala/firrtl/passes/Checks.scala +++ b/src/main/scala/firrtl/passes/Checks.scala @@ -7,6 +7,7 @@ import firrtl.ir._ import firrtl.PrimOps._ import firrtl.Utils._ import firrtl.Mappers._ +import firrtl.traversals.Foreachers._ import firrtl.WrappedType._ object CheckHighForm extends Pass { @@ -116,32 +117,29 @@ object CheckHighForm extends Pass { case _ => // Do Nothing } - def checkHighFormW(info: Info, mname: String)(w: Width): Width = { + def checkHighFormW(info: Info, mname: String)(w: Width): Unit = { w match { - case wx: IntWidth if wx.width < 0 => - errors.append(new NegWidthException(info, mname)) + case wx: IntWidth if wx.width < 0 => errors.append(new NegWidthException(info, mname)) case wx => // Do nothing } - w } - def checkHighFormT(info: Info, mname: String)(t: Type): Type = - t map checkHighFormT(info, mname) match { - case tx: VectorType if tx.size < 0 => - errors.append(new NegVecSizeException(info, mname)) - t - case _ => t map checkHighFormW(info, mname) + def checkHighFormT(info: Info, mname: String)(t: Type): Unit = { + t foreach checkHighFormT(info, mname) + t match { + case tx: VectorType if tx.size < 0 => errors.append(new NegVecSizeException(info, mname)) + case _ => t foreach checkHighFormW(info, mname) } + } - def validSubexp(info: Info, mname: String)(e: Expression): Expression = { + def validSubexp(info: Info, mname: String)(e: Expression): Unit = { e match { case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess | _: Mux | _: ValidIf => // No error case _ => errors.append(new InvalidAccessException(info, mname)) } - e } - def checkHighFormE(info: Info, mname: String, names: NameSet)(e: Expression): Expression = { + def checkHighFormE(info: Info, mname: String, names: NameSet)(e: Expression): Unit = { e match { case ex: WRef if !names(ex.name) => errors.append(new UndeclaredReferenceException(info, mname, ex.name)) @@ -150,23 +148,23 @@ object CheckHighForm extends Pass { case ex: DoPrim => checkHighFormPrimop(info, mname, ex) case _: WRef | _: UIntLiteral | _: Mux | _: ValidIf => case ex: WSubAccess => validSubexp(info, mname)(ex.expr) - case ex => ex map validSubexp(info, mname) + case ex => ex foreach validSubexp(info, mname) } - (e map checkHighFormW(info, mname) - map checkHighFormT(info, mname) - map checkHighFormE(info, mname, names)) + e foreach checkHighFormW(info, mname) + e foreach checkHighFormT(info, mname) + e foreach checkHighFormE(info, mname, names) } - def checkName(info: Info, mname: String, names: NameSet)(name: String): String = { + def checkName(info: Info, mname: String, names: NameSet)(name: String): Unit = { if (names(name)) errors.append(new NotUniqueException(info, mname, name)) names += name - name } - def checkHighFormS(minfo: Info, mname: String, names: NameSet)(s: Statement): Statement = { + def checkHighFormS(minfo: Info, mname: String, names: NameSet)(s: Statement): Unit = { val info = get_info(s) match {case NoInfo => minfo case x => x} - s map checkName(info, mname, names) match { + s foreach checkName(info, mname, names) + s match { case sx: DefMemory => if (hasFlip(sx.dataType)) errors.append(new MemWithFlipException(info, mname, sx.name)) @@ -184,24 +182,23 @@ object CheckHighForm extends Pass { case sx: Print => checkFstring(info, mname, sx.string, sx.args.length) case sx => // Do Nothing } - (s map checkHighFormT(info, mname) - map checkHighFormE(info, mname, names) - map checkHighFormS(minfo, mname, names)) + s foreach checkHighFormT(info, mname) + s foreach checkHighFormE(info, mname, names) + s foreach checkHighFormS(minfo, mname, names) } - def checkHighFormP(mname: String, names: NameSet)(p: Port): Port = { + def checkHighFormP(mname: String, names: NameSet)(p: Port): Unit = { if (names(p.name)) errors.append(new NotUniqueException(NoInfo, mname, p.name)) names += p.name - (p.tpe map checkHighFormT(p.info, mname) - map checkHighFormW(p.info, mname)) - p + p.tpe foreach checkHighFormT(p.info, mname) + p.tpe foreach checkHighFormW(p.info, mname) } def checkHighFormM(m: DefModule) { val names = new NameSet - (m map checkHighFormP(m.name, names) - map checkHighFormS(m.info, m.name, names)) + m foreach checkHighFormP(m.name, names) + m foreach checkHighFormS(m.info, m.name, names) } c.modules foreach checkHighFormM @@ -333,7 +330,7 @@ object CheckTypes extends Pass { } } - def check_types_e(info:Info, mname: String)(e: Expression): Expression = { + def check_types_e(info:Info, mname: String)(e: Expression): Unit = { e match { case (e: WSubField) => e.expr.tpe match { case (t: BundleType) => t.fields find (_.name == e.name) match { @@ -377,7 +374,7 @@ object CheckTypes extends Pass { } case _ => } - e map check_types_e(info, mname) + e foreach check_types_e(info, mname) } def bulk_equals(t1: Type, t2: Type, flip1: Orientation, flip2: Orientation): Boolean = { @@ -404,7 +401,7 @@ object CheckTypes extends Pass { } } - def check_types_s(minfo: Info, mname: String)(s: Statement): Statement = { + def check_types_s(minfo: Info, mname: String)(s: Statement): Unit = { val info = get_info(s) match { case NoInfo => minfo case x => x } s match { case sx: Connect if wt(sx.loc.tpe) != wt(sx.expr.tpe) => @@ -457,10 +454,11 @@ object CheckTypes extends Pass { } case _ => } - s map check_types_e(info, mname) map check_types_s(info, mname) + s foreach check_types_e(info, mname) + s foreach check_types_s(info, mname) } - c.modules foreach (m => m map check_types_s(m.info, m.name)) + c.modules foreach (m => m foreach check_types_s(m.info, m.name)) errors.trigger() c } @@ -504,7 +502,7 @@ object CheckGenders extends Pass { flip_rec(t, Default) } - def check_gender(info:Info, mname: String, genders: GenderMap, desired: Gender)(e:Expression): Expression = { + def check_gender(info:Info, mname: String, genders: GenderMap, desired: Gender)(e:Expression): Unit = { val gender = get_gender(e,genders) (gender, desired) match { case (MALE, FEMALE) => @@ -516,19 +514,18 @@ object CheckGenders extends Pass { } case _ => } - e } - def check_genders_e (info:Info, mname: String, genders: GenderMap)(e:Expression): Expression = { + def check_genders_e (info:Info, mname: String, genders: GenderMap)(e:Expression): Unit = { e match { - case e: Mux => e map check_gender(info, mname, genders, MALE) - case e: DoPrim => e.args map check_gender(info, mname, genders, MALE) + case e: Mux => e foreach check_gender(info, mname, genders, MALE) + case e: DoPrim => e.args foreach check_gender(info, mname, genders, MALE) case _ => } - e map check_genders_e(info, mname, genders) + e foreach check_genders_e(info, mname, genders) } - def check_genders_s(minfo: Info, mname: String, genders: GenderMap)(s: Statement): Statement = { + def check_genders_s(minfo: Info, mname: String, genders: GenderMap)(s: Statement): Unit = { val info = get_info(s) match { case NoInfo => minfo case x => x } s match { case (s: DefWire) => genders(s.name) = BIGENDER @@ -555,13 +552,14 @@ object CheckGenders extends Pass { check_gender(info, mname, genders, MALE)(s.clk) case _ => } - s map check_genders_e(info, mname, genders) map check_genders_s(minfo, mname, genders) + s foreach check_genders_e(info, mname, genders) + s foreach check_genders_s(minfo, mname, genders) } for (m <- c.modules) { val genders = new GenderMap genders ++= (m.ports map (p => p.name -> to_gender(p.direction))) - m map check_genders_s(m.info, m.name, genders) + m foreach check_genders_s(m.info, m.name, genders) } errors.trigger() c diff --git a/src/main/scala/firrtl/passes/memlib/MemIR.scala b/src/main/scala/firrtl/passes/memlib/MemIR.scala index 5fb837c1..a7ef9d43 100644 --- a/src/main/scala/firrtl/passes/memlib/MemIR.scala +++ b/src/main/scala/firrtl/passes/memlib/MemIR.scala @@ -31,4 +31,9 @@ case class DefAnnotatedMemory( writeLatency, readLatency, readers, writers, readwriters, readUnderWrite) def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = Unit + def foreachExpr(f: Expression => Unit): Unit = Unit + def foreachType(f: Type => Unit): Unit = f(dataType) + def foreachString(f: String => Unit): Unit = f(name) + def foreachInfo(f: Info => Unit): Unit = f(info) } |
