diff options
| author | Donggyu | 2016-09-13 17:26:11 -0700 |
|---|---|---|
| committer | GitHub | 2016-09-13 17:26:11 -0700 |
| commit | 1bb9597a01e77d9a1ece479caf13cf6c3f6229d5 (patch) | |
| tree | 39b3dc1da954faea65777eb595e64fbd2b1a2f45 /src | |
| parent | 96340374f091d5258ca69ef7fc614910e1c2cbb7 (diff) | |
| parent | 1cfda487ec6773a139587c1c0bcf145c03b46800 (diff) | |
Merge pull request #285 from ucb-bar/more_passes_cleanups
More passes cleanups
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/passes/CheckChirrtl.scala | 21 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Checks.scala | 32 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/InferWidths.scala | 15 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/LowerTypes.scala | 292 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/MemUtils.scala | 2 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/PadWidths.scala | 130 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Passes.scala | 163 |
7 files changed, 276 insertions, 379 deletions
diff --git a/src/main/scala/firrtl/passes/CheckChirrtl.scala b/src/main/scala/firrtl/passes/CheckChirrtl.scala index 2ab8749b..f21449a2 100644 --- a/src/main/scala/firrtl/passes/CheckChirrtl.scala +++ b/src/main/scala/firrtl/passes/CheckChirrtl.scala @@ -27,21 +27,14 @@ MODIFICATIONS. package firrtl.passes -import com.typesafe.scalalogging.LazyLogging - -// Datastructures -import scala.collection.mutable.HashSet - import firrtl._ import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ -import firrtl.PrimOps._ -import firrtl.WrappedType._ - -object CheckChirrtl extends Pass with LazyLogging { +object CheckChirrtl extends Pass { def name = "Chirrtl Check" + type NameSet = collection.mutable.HashSet[String] class NotUniqueException(info: Info, mname: String, name: String) extends PassException( s"${info}: [module ${mname}] Reference ${name} does not have a unique name.") @@ -101,7 +94,7 @@ object CheckChirrtl extends Pass with LazyLogging { e } - def checkChirrtlE(info: Info, mname: String, names: HashSet[String])(e: Expression): Expression = { + def checkChirrtlE(info: Info, mname: String, names: NameSet)(e: Expression): Expression = { e match { case _: DoPrim | _:Mux | _:ValidIf | _: UIntLiteral => case e: Reference if !names(e.name) => @@ -114,14 +107,14 @@ object CheckChirrtl extends Pass with LazyLogging { map checkChirrtlE(info, mname, names)) } - def checkName(info: Info, mname: String, names: HashSet[String])(name: String): String = { + def checkName(info: Info, mname: String, names: NameSet)(name: String): String = { if (names(name)) errors append (new NotUniqueException(info, mname, name)) names += name name } - def checkChirrtlS(minfo: Info, mname: String, names: HashSet[String])(s: Statement): Statement = { + def checkChirrtlS(minfo: Info, mname: String, names: NameSet)(s: Statement): Statement = { val info = get_info(s) match {case NoInfo => minfo case x => x} (s map checkName(info, mname, names)) match { case s: DefMemory => @@ -138,7 +131,7 @@ object CheckChirrtl extends Pass with LazyLogging { map checkChirrtlS(info, mname, names)) } - def checkChirrtlP(mname: String, names: HashSet[String])(p: Port): Port = { + def checkChirrtlP(mname: String, names: NameSet)(p: Port): Port = { names += p.name (p.tpe map checkChirrtlT(p.info, mname) map checkChirrtlW(p.info, mname)) @@ -146,7 +139,7 @@ object CheckChirrtl extends Pass with LazyLogging { } def checkChirrtlM(m: DefModule) { - val names = HashSet[String]() + val names = new NameSet (m map checkChirrtlP(m.name, names) map checkChirrtlS(m.info, m.name, names)) } diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala index 16b16ff7..c300f7c7 100644 --- a/src/main/scala/firrtl/passes/Checks.scala +++ b/src/main/scala/firrtl/passes/Checks.scala @@ -27,8 +27,6 @@ MODIFICATIONS. package firrtl.passes -import com.typesafe.scalalogging.LazyLogging - import firrtl._ import firrtl.ir._ import firrtl.PrimOps._ @@ -36,11 +34,9 @@ import firrtl.Utils._ import firrtl.Mappers._ import firrtl.WrappedType._ -// Datastructures -import scala.collection.mutable.{HashMap, HashSet} - -object CheckHighForm extends Pass with LazyLogging { +object CheckHighForm extends Pass { def name = "High Form Check" + type NameSet = collection.mutable.HashSet[String] // Custom Exceptions class NotUniqueException(info: Info, mname: String, name: String) extends PassException( @@ -160,7 +156,7 @@ object CheckHighForm extends Pass with LazyLogging { e } - def checkHighFormE(info: Info, mname: String, names: HashSet[String])(e: Expression): Expression = { + def checkHighFormE(info: Info, mname: String, names: NameSet)(e: Expression): Expression = { e match { case e: WRef if !names(e.name) => errors append new UndeclaredReferenceException(info, mname, e.name) @@ -176,14 +172,14 @@ object CheckHighForm extends Pass with LazyLogging { map checkHighFormE(info, mname, names)) } - def checkName(info: Info, mname: String, names: HashSet[String])(name: String): String = { + def checkName(info: Info, mname: String, names: NameSet)(name: String): String = { if (names(name)) errors append new NotUniqueException(info, mname, name) names += name name } - def checkHighFormS(minfo: Info, mname: String, names: HashSet[String])(s: Statement): Statement = { + def checkHighFormS(minfo: Info, mname: String, names: NameSet)(s: Statement): Statement = { val info = get_info(s) match {case NoInfo => minfo case x => x} (s map checkName(info, mname, names)) match { case s: DefMemory => @@ -208,7 +204,7 @@ object CheckHighForm extends Pass with LazyLogging { map checkHighFormS(minfo, mname, names)) } - def checkHighFormP(mname: String, names: HashSet[String])(p: Port): Port = { + def checkHighFormP(mname: String, names: NameSet)(p: Port): Port = { names += p.name (p.tpe map checkHighFormT(p.info, mname) map checkHighFormW(p.info, mname)) @@ -216,7 +212,7 @@ object CheckHighForm extends Pass with LazyLogging { } def checkHighFormM(m: DefModule) { - val names = HashSet[String]() + val names = new NameSet (m map checkHighFormP(m.name, names) map checkHighFormS(m.info, m.name, names)) } @@ -231,7 +227,7 @@ object CheckHighForm extends Pass with LazyLogging { } } -object CheckTypes extends Pass with LazyLogging { +object CheckTypes extends Pass { def name = "Check Types" // Custom Exceptions @@ -430,6 +426,7 @@ object CheckTypes extends Pass with LazyLogging { object CheckGenders extends Pass { def name = "Check Genders" + type GenderMap = collection.mutable.HashMap[String, Gender] implicit def toStr(g: Gender): String = g match { case MALE => "source" @@ -444,7 +441,7 @@ object CheckGenders extends Pass { def run (c:Circuit): Circuit = { val errors = new Errors() - def get_gender(e: Expression, genders: HashMap[String, Gender]): Gender = e match { + def get_gender(e: Expression, genders: GenderMap): Gender = e match { case (e: WRef) => genders(e.name) case (e: WSubIndex) => get_gender(e.exp, genders) case (e: WSubAccess) => get_gender(e.exp, genders) @@ -466,8 +463,7 @@ object CheckGenders extends Pass { flip_rec(t, Default) } - def check_gender(info:Info, mname: String, - genders: HashMap[String,Gender], desired: Gender)(e:Expression): Expression = { + def check_gender(info:Info, mname: String, genders: GenderMap, desired: Gender)(e:Expression): Expression = { val gender = get_gender(e,genders) (gender, desired) match { case (MALE, FEMALE) => @@ -482,7 +478,7 @@ object CheckGenders extends Pass { e } - def check_genders_e (info:Info, mname: String, genders: HashMap[String,Gender])(e:Expression): Expression = { + def check_genders_e (info:Info, mname: String, genders: GenderMap)(e:Expression): Expression = { e match { case e: Mux => e map check_gender(info, mname, genders, MALE) case e: DoPrim => e.args map check_gender(info, mname, genders, MALE) @@ -491,7 +487,7 @@ object CheckGenders extends Pass { e map check_genders_e(info, mname, genders) } - def check_genders_s(minfo: Info, mname: String, genders: HashMap[String,Gender])(s: Statement): Statement = { + def check_genders_s(minfo: Info, mname: String, genders: GenderMap)(s: Statement): Statement = { val info = get_info(s) match { case NoInfo => minfo case x => x } s match { case (s: DefWire) => genders(s.name) = BIGENDER @@ -522,7 +518,7 @@ object CheckGenders extends Pass { } for (m <- c.modules) { - val genders = HashMap[String, Gender]() + val genders = new GenderMap genders ++= (m.ports map (p => p.name -> to_gender(p.direction))) m map check_genders_s(m.info, m.name, genders) } diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala index 6b2ff6ed..ebec4d80 100644 --- a/src/main/scala/firrtl/passes/InferWidths.scala +++ b/src/main/scala/firrtl/passes/InferWidths.scala @@ -28,20 +28,19 @@ MODIFICATIONS. package firrtl.passes // Datastructures -import scala.collection.mutable.{LinkedHashMap, HashMap, HashSet, ArrayBuffer} +import scala.collection.mutable.ArrayBuffer import scala.collection.immutable.ListMap import firrtl._ import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ -import firrtl.PrimOps._ -import firrtl.WrappedExpression._ object InferWidths extends Pass { def name = "Infer Widths" + type ConstraintMap = collection.mutable.LinkedHashMap[String, Width] - def solve_constraints(l: Seq[WGeq]): LinkedHashMap[String, Width] = { + def solve_constraints(l: Seq[WGeq]): ConstraintMap = { def unique(ls: Seq[Width]) : Seq[Width] = (ls map (new WrappedWidth(_))).distinct map (_.w) def make_unique(ls: Seq[WGeq]): ListMap[String,Width] = { @@ -77,7 +76,7 @@ object InferWidths extends Pass { case _ => w } - def substitute(h: LinkedHashMap[String, Width])(w: Width): Width = { + def substitute(h: ConstraintMap)(w: Width): Width = { //;println-all-debug(["Substituting for [" w "]"]) val wx = simplify(w) //;println-all-debug(["After Simplify: [" wx "]"]) @@ -98,7 +97,7 @@ object InferWidths extends Pass { } } - def b_sub(h: LinkedHashMap[String, Width])(w: Width): Width = { + def b_sub(h: ConstraintMap)(w: Width): Width = { w map b_sub(h) match { case w: VarWidth => h getOrElse (w.name, w) case w => w @@ -145,7 +144,7 @@ object InferWidths extends Pass { //for (x <- u) { println(x) } //println("====================================") - val f = LinkedHashMap[String, Width]() + val f = new ConstraintMap val o = ArrayBuffer[String]() for ((n, e) <- u) { //println("==== SOLUTIONS TABLE ====") @@ -175,7 +174,7 @@ object InferWidths extends Pass { //for (x <- f) println(x) //; Backwards Solve - val b = LinkedHashMap[String, Width]() + val b = new ConstraintMap for (i <- (o.size - 1) to 0 by -1) { val n = o(i) // Should visit `o` backward /* diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala index 57f8fd76..b3969bea 100644 --- a/src/main/scala/firrtl/passes/LowerTypes.scala +++ b/src/main/scala/firrtl/passes/LowerTypes.scala @@ -27,16 +27,11 @@ MODIFICATIONS. package firrtl.passes -import com.typesafe.scalalogging.LazyLogging - import firrtl._ import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ -// Datastructures -import scala.collection.mutable.HashMap - /** Removes all aggregate types from a [[firrtl.ir.Circuit]] * * @note Assumes [[firrtl.ir.SubAccess]]es have been removed @@ -67,8 +62,8 @@ object LowerTypes extends Pass { def loweredName(s: Seq[String]): String = s mkString delim private case class LowerTypesException(msg: String) extends FIRRTLException(msg) - private def error(msg: String)(implicit sinfo: Info, mname: String) = - throw new LowerTypesException(s"$sinfo: [module $mname] $msg") + private def error(msg: String)(info: Info, mname: String) = + throw LowerTypesException(s"$info: [module $mname] $msg") // TODO Improve? Probably not the best way to do this private def splitMemRef(e1: Expression): (WRef, WRef, WRef, Option[Expression]) = { @@ -83,165 +78,146 @@ object LowerTypes extends Pass { } } - // Everything wrapped in run so that it's thread safe - def run(c: Circuit): Circuit = { - // Debug state - implicit var mname: String = "" - implicit var sinfo: Info = NoInfo - - def lowerTypes(m: DefModule): DefModule = { - val memDataTypeMap = HashMap[String, Type]() - - // Lowers an expression of MemKind - // Since mems with Bundle type must be split into multiple ground type - // mem, references to fields addr, en, clk, and rmode must be replicated - // for each resulting memory - // References to data, mask, rdata, wdata, and wmask have already been split in expand connects - // and just need to be converted to refer to the correct new memory - def lowerTypesMemExp(e: Expression): Seq[Expression] = { - val (mem, port, field, tail) = splitMemRef(e) - field.name match { - // Fields that need to be replicated for each resulting mem - case "addr" | "en" | "clk" | "wmode" => - require(tail.isEmpty) // there can't be a tail for these - memDataTypeMap(mem.name) match { - case _: GroundType => Seq(e) - case memType => create_exps(mem.name, memType) map { e => - val loMemName = loweredName(e) - val loMem = WRef(loMemName, UnknownType, kind(mem), UNKNOWNGENDER) - mergeRef(loMem, mergeRef(port, field)) - } - } - // Fields that need not be replicated for each - // eg. mem.reader.data[0].a - // (Connect/IsInvalid must already have been split to ground types) - case "data" | "mask" | "rdata" | "wdata" | "wmask" => - val loMem = tail match { - case Some(e) => - val loMemExp = mergeRef(mem, e) - val loMemName = loweredName(loMemExp) - WRef(loMemName, UnknownType, kind(mem), UNKNOWNGENDER) - case None => mem - } - Seq(mergeRef(loMem, mergeRef(port, field))) - case name => error(s"Error! Unhandled memory field ${name}") + // Lowers an expression of MemKind + // Since mems with Bundle type must be split into multiple ground type + // mem, references to fields addr, en, clk, and rmode must be replicated + // for each resulting memory + // References to data, mask, rdata, wdata, and wmask have already been split in expand connects + // and just need to be converted to refer to the correct new memory + type MemDataTypeMap = collection.mutable.HashMap[String, Type] + def lowerTypesMemExp(memDataTypeMap: MemDataTypeMap, + info: Info, mname: String)(e: Expression): Seq[Expression] = { + val (mem, port, field, tail) = splitMemRef(e) + field.name match { + // Fields that need to be replicated for each resulting mem + case "addr" | "en" | "clk" | "wmode" => + require(tail.isEmpty) // there can't be a tail for these + memDataTypeMap(mem.name) match { + case _: GroundType => Seq(e) + case memType => create_exps(mem.name, memType) map { e => + val loMemName = loweredName(e) + val loMem = WRef(loMemName, UnknownType, kind(mem), UNKNOWNGENDER) + mergeRef(loMem, mergeRef(port, field)) + } } - } + // Fields that need not be replicated for each + // eg. mem.reader.data[0].a + // (Connect/IsInvalid must already have been split to ground types) + case "data" | "mask" | "rdata" | "wdata" | "wmask" => + val loMem = tail match { + case Some(e) => + val loMemExp = mergeRef(mem, e) + val loMemName = loweredName(loMemExp) + WRef(loMemName, UnknownType, kind(mem), UNKNOWNGENDER) + case None => mem + } + Seq(mergeRef(loMem, mergeRef(port, field))) + case name => error(s"Error! Unhandled memory field ${name}")(info, mname) + } + } - def lowerTypesExp(e: Expression): Expression = e match { - case e: WRef => e - case (_: WSubField | _: WSubIndex) => kind(e) match { - case k: InstanceKind => - val (root, tail) = splitRef(e) - val name = loweredName(tail) - WSubField(root, name, e.tpe, gender(e)) - case k: MemKind => - val exps = lowerTypesMemExp(e) - exps.size match { - case 1 => exps.head - case _ => error("Error! lowerTypesExp called on MemKind " + - "SubField that needs to be expanded!") - } - case _ => WRef(loweredName(e), e.tpe, kind(e), gender(e)) + def lowerTypesExp(memDataTypeMap: MemDataTypeMap, + info: Info, mname: String)(e: Expression): Expression = e match { + case e: WRef => e + case (_: WSubField | _: WSubIndex) => kind(e) match { + case k: InstanceKind => + val (root, tail) = splitRef(e) + val name = loweredName(tail) + WSubField(root, name, e.tpe, gender(e)) + case k: MemKind => + val exps = lowerTypesMemExp(memDataTypeMap, info, mname)(e) + exps.size match { + case 1 => exps.head + case _ => error("Error! lowerTypesExp called on MemKind " + + "SubField that needs to be expanded!")(info, mname) } - case e: Mux => e map (lowerTypesExp) - case e: ValidIf => e map (lowerTypesExp) - case e: DoPrim => e map (lowerTypesExp) - case e @ (_: UIntLiteral | _: SIntLiteral) => e - } + case _ => WRef(loweredName(e), e.tpe, kind(e), gender(e)) + } + case e: Mux => e map lowerTypesExp(memDataTypeMap, info, mname) + case e: ValidIf => e map lowerTypesExp(memDataTypeMap, info, mname) + case e: DoPrim => e map lowerTypesExp(memDataTypeMap, info, mname) + case e @ (_: UIntLiteral | _: SIntLiteral) => e + } - def lowerTypesStmt(s: Statement): Statement = s map lowerTypesStmt match { - case s: DefWire => - sinfo = s.info - s.tpe match { - case _: GroundType => s - case _ => Block(create_exps(s.name, s.tpe) map ( - e => DefWire(s.info, loweredName(e), e.tpe))) - } - case s: DefRegister => - sinfo = s.info - s.tpe match { - case _: GroundType => s map lowerTypesExp - case _ => - val es = create_exps(s.name, s.tpe) - val inits = create_exps(s.init) map (lowerTypesExp) - val clock = lowerTypesExp(s.clock) - val reset = lowerTypesExp(s.reset) - Block(es zip inits map { case (e, i) => - DefRegister(s.info, loweredName(e), e.tpe, clock, reset, i) - }) - } - // Could instead just save the type of each Module as it gets processed - case s: WDefInstance => - sinfo = s.info - s.tpe match { - case t: BundleType => - val fieldsx = t.fields flatMap (f => - create_exps(WRef(f.name, f.tpe, ExpKind(), times(f.flip, MALE))) map ( - // Flip because inst genders are reversed from Module type - e => Field(loweredName(e), swap(to_flip(gender(e))), e.tpe) - ) - ) - WDefInstance(s.info, s.name, s.module, BundleType(fieldsx)) - case _ => error("WDefInstance type should be Bundle!") - } - case s: DefMemory => - sinfo = s.info - memDataTypeMap(s.name) = s.dataType - s.dataType match { - case _: GroundType => s - case _ => Block(create_exps(s.name, s.dataType) map (e => - DefMemory(s.info, loweredName(e), e.tpe, s.depth, - s.writeLatency, s.readLatency, s.readers, s.writers, - s.readwriters))) - } - // wire foo : { a , b } - // node x = foo - // node y = x.a - // -> - // node x_a = foo_a - // node x_b = foo_b - // node y = x_a - case s: DefNode => - sinfo = s.info - val names = create_exps(s.name, s.value.tpe) map (lowerTypesExp) - val exps = create_exps(s.value) map (lowerTypesExp) - Block(names zip exps map {case (n, e) => DefNode(s.info, loweredName(n), e)}) - case s: IsInvalid => - sinfo = s.info - kind(s.expr) match { - case k: MemKind => - Block(lowerTypesMemExp(s.expr) map (IsInvalid(s.info, _))) - case _ => s map (lowerTypesExp) - } - case s: Connect => - sinfo = s.info - kind(s.loc) match { - case k: MemKind => - val exp = lowerTypesExp(s.expr) - val locs = lowerTypesMemExp(s.loc) - Block(locs map (Connect(s.info, _, exp))) - case _ => s map (lowerTypesExp) - } - case s => s map (lowerTypesExp) + def lowerTypesStmt(memDataTypeMap: MemDataTypeMap, + minfo: Info, mname: String)(s: Statement): Statement = { + val info = get_info(s) match {case NoInfo => minfo case x => x} + s map lowerTypesStmt(memDataTypeMap, info, mname) match { + case s: DefWire => s.tpe match { + case _: GroundType => s + case _ => Block(create_exps(s.name, s.tpe) map ( + e => DefWire(s.info, loweredName(e), e.tpe))) } - - sinfo = m.info - mname = m.name - // Lower Ports - val portsx = m.ports flatMap ( p => - create_exps(WRef(p.name, p.tpe, PortKind(), to_gender(p.direction))) map ( - e => Port(p.info, loweredName(e), to_dir(gender(e)), e.tpe) - ) - ) - m match { - case m: ExtModule => m.copy(ports = portsx) - case m: Module => Module(m.info, m.name, portsx, lowerTypesStmt(m.body)) + case s: DefRegister => s.tpe match { + case _: GroundType => s map lowerTypesExp(memDataTypeMap, info, mname) + case _ => + val es = create_exps(s.name, s.tpe) + val inits = create_exps(s.init) map lowerTypesExp(memDataTypeMap, info, mname) + val clock = lowerTypesExp(memDataTypeMap, info, mname)(s.clock) + val reset = lowerTypesExp(memDataTypeMap, info, mname)(s.reset) + Block(es zip inits map { case (e, i) => + DefRegister(s.info, loweredName(e), e.tpe, clock, reset, i) + }) + } + // Could instead just save the type of each Module as it gets processed + case s: WDefInstance => s.tpe match { + case t: BundleType => + val fieldsx = t.fields flatMap (f => + create_exps(WRef(f.name, f.tpe, ExpKind(), times(f.flip, MALE))) map ( + // Flip because inst genders are reversed from Module type + e => Field(loweredName(e), swap(to_flip(gender(e))), e.tpe))) + WDefInstance(s.info, s.name, s.module, BundleType(fieldsx)) + case _ => error("WDefInstance type should be Bundle!")(info, mname) } + case s: DefMemory => + memDataTypeMap(s.name) = s.dataType + s.dataType match { + case _: GroundType => s + case _ => Block(create_exps(s.name, s.dataType) map (e => + s copy (name = loweredName(e), dataType = e.tpe))) + } + // wire foo : { a , b } + // node x = foo + // node y = x.a + // -> + // node x_a = foo_a + // node x_b = foo_b + // node y = x_a + case s: DefNode => + val names = create_exps(s.name, s.value.tpe) map lowerTypesExp(memDataTypeMap, info, mname) + val exps = create_exps(s.value) map lowerTypesExp(memDataTypeMap, info, mname) + Block(names zip exps map { case (n, e) => DefNode(info, loweredName(n), e) }) + case s: IsInvalid => kind(s.expr) match { + case _: MemKind => + Block(lowerTypesMemExp(memDataTypeMap, info, mname)(s.expr) map (IsInvalid(info, _))) + case _ => s map lowerTypesExp(memDataTypeMap, info, mname) + } + case s: Connect => kind(s.loc) match { + case k: MemKind => + val exp = lowerTypesExp(memDataTypeMap, info, mname)(s.expr) + val locs = lowerTypesMemExp(memDataTypeMap, info, mname)(s.loc) + Block(locs map (Connect(info, _, exp))) + case _ => s map lowerTypesExp(memDataTypeMap, info, mname) + } + case s => s map lowerTypesExp(memDataTypeMap, info, mname) } + } - sinfo = c.info - Circuit(c.info, c.modules map lowerTypes, c.main) + def lowerTypes(m: DefModule): DefModule = { + val memDataTypeMap = new MemDataTypeMap + // Lower Ports + val portsx = m.ports flatMap { p => + val exps = create_exps(WRef(p.name, p.tpe, PortKind(), to_gender(p.direction))) + exps map (e => Port(p.info, loweredName(e), to_dir(gender(e)), e.tpe)) + } + m match { + case m: ExtModule => + m copy (ports = portsx) + case m: Module => + m copy (ports = portsx) map lowerTypesStmt(memDataTypeMap, m.info, m.name) + } } + + def run(c: Circuit): Circuit = c copy (modules = (c.modules map lowerTypes)) } diff --git a/src/main/scala/firrtl/passes/MemUtils.scala b/src/main/scala/firrtl/passes/MemUtils.scala index 87033176..57a7120b 100644 --- a/src/main/scala/firrtl/passes/MemUtils.scala +++ b/src/main/scala/firrtl/passes/MemUtils.scala @@ -27,8 +27,6 @@ package firrtl.passes -import com.typesafe.scalalogging.LazyLogging - import firrtl._ import firrtl.ir._ import firrtl.Utils._ diff --git a/src/main/scala/firrtl/passes/PadWidths.scala b/src/main/scala/firrtl/passes/PadWidths.scala index bef9ac33..4c198bab 100644 --- a/src/main/scala/firrtl/passes/PadWidths.scala +++ b/src/main/scala/firrtl/passes/PadWidths.scala @@ -7,80 +7,58 @@ import firrtl.Mappers._ // Makes all implicit width extensions and truncations explicit object PadWidths extends Pass { - def name = "Pad Widths" - private def width(t: Type): Int = bitWidth(t).toInt - private def width(e: Expression): Int = width(e.tpe) - // Returns an expression with the correct integer width - private def fixup(i: Int)(e: Expression) = { - def tx = e.tpe match { - case t: UIntType => UIntType(IntWidth(i)) - case t: SIntType => SIntType(IntWidth(i)) - // default case should never be reached - } - if (i > width(e)) { - DoPrim(Pad, Seq(e), Seq(i), tx) - } else if (i < width(e)) { - val e2 = DoPrim(Bits, Seq(e), Seq(i - 1, 0), UIntType(IntWidth(i))) - // Bit Select always returns UInt, cast if selecting from SInt - e.tpe match { - case UIntType(_) => e2 - case SIntType(_) => DoPrim(AsSInt, Seq(e2), Seq.empty, SIntType(IntWidth(i))) - } - } else { - e - } - } - // Recursive, updates expression so children exp's have correct widths - private def onExp(e: Expression): Expression = { - val sensitiveOps = Seq( Lt, Leq, Gt, Geq, Eq, Neq, Not, And, Or, Xor, - Add, Sub, Mul, Div, Rem, Shr) - val x = e map onExp - x match { - case Mux(cond, tval, fval, tpe) => { - val tvalx = fixup(width(tpe))(tval) - val fvalx = fixup(width(tpe))(fval) - Mux(cond, tvalx, fvalx, tpe) - } - case DoPrim(op, args, consts, tpe) => op match { - case _ if sensitiveOps.contains(op) => { - val i = args.map(a => width(a)).foldLeft(0) {(a, b) => math.max(a, b)} - x map fixup(i) - } - case Dshl => { - // special case as args aren't all same width - val ax = fixup(width(tpe))(args(0)) - DoPrim(Dshlw, Seq(ax, args(1)), consts, tpe) - } - case Shl => { - // special case as arg should be same width as result - val ax = fixup(width(tpe))(args(0)) - DoPrim(Shlw, Seq(ax), consts, tpe) - } - case _ => x - } - case ValidIf(cond, value, tpe) => ValidIf(cond, fixup(width(tpe))(value), tpe) - case x => x - } - } - // Recursive. Fixes assignments and register initialization widths - private def onStmt(s: Statement): Statement = { - s map onExp match { - case s: Connect => { - val ex = fixup(width(s.loc))(s.expr) - Connect(s.info, s.loc, ex) - } - case s: DefRegister => { - val ex = fixup(width(s.tpe))(s.init) - DefRegister(s.info, s.name, s.tpe, s.clock, s.reset, ex) - } - case s => s map onStmt - } - } - private def onModule(m: DefModule): DefModule = { - m match { - case m: Module => Module(m.info, m.name, m.ports, onStmt(m.body)) - case m: ExtModule => m - } - } - def run(c: Circuit): Circuit = Circuit(c.info, c.modules.map(onModule _), c.main) + def name = "Pad Widths" + private def width(t: Type): Int = bitWidth(t).toInt + private def width(e: Expression): Int = width(e.tpe) + // Returns an expression with the correct integer width + private def fixup(i: Int)(e: Expression) = { + def tx = e.tpe match { + case t: UIntType => UIntType(IntWidth(i)) + case t: SIntType => SIntType(IntWidth(i)) + // default case should never be reached + } + width(e) match { + case j if i > j => DoPrim(Pad, Seq(e), Seq(i), tx) + case j if i < j => + val e2 = DoPrim(Bits, Seq(e), Seq(i - 1, 0), UIntType(IntWidth(i))) + // Bit Select always returns UInt, cast if selecting from SInt + e.tpe match { + case UIntType(_) => e2 + case SIntType(_) => DoPrim(AsSInt, Seq(e2), Seq.empty, SIntType(IntWidth(i))) + } + case _ => e + } + } + + // Recursive, updates expression so children exp's have correct widths + private def onExp(e: Expression): Expression = e map onExp match { + case Mux(cond, tval, fval, tpe) => + Mux(cond, fixup(width(tpe))(tval), fixup(width(tpe))(fval), tpe) + case e: ValidIf => e copy (value = fixup(width(e.tpe))(e.value)) + case e: DoPrim => e.op match { + case Lt | Leq | Gt | Geq | Eq | Neq | Not | And | Or | Xor | + Add | Sub | Mul | Div | Rem | Shr => + // sensitive ops + e map fixup((e.args map (width(_)) foldLeft 0)(math.max(_, _))) + case Dshl => + // special case as args aren't all same width + e copy (op = Dshlw, args = Seq(fixup(width(e.tpe))(e.args(0)), e.args(1))) + case Shl => + // special case as arg should be same width as result + e copy (op = Shlw, args = Seq(fixup(width(e.tpe))(e.args(0)))) + case _ => e + } + case e => e + } + + // Recursive. Fixes assignments and register initialization widths + private def onStmt(s: Statement): Statement = s map onExp match { + case s: Connect => + s copy (expr = fixup(width(s.loc))(s.expr)) + case s: DefRegister => + s copy (init = fixup(width(s.tpe))(s.init)) + case s => s map onStmt + } + + def run(c: Circuit): Circuit = c copy (modules = (c.modules map (_ map onStmt))) } diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala index d5d9a3b6..965ae339 100644 --- a/src/main/scala/firrtl/passes/Passes.scala +++ b/src/main/scala/firrtl/passes/Passes.scala @@ -28,20 +28,12 @@ MODIFICATIONS. package firrtl.passes import com.typesafe.scalalogging.LazyLogging -import java.nio.file.{Paths, Files} - -// Datastructures -import scala.collection.mutable.LinkedHashMap -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet -import scala.collection.mutable.ArrayBuffer import firrtl._ import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ import firrtl.PrimOps._ -import firrtl.WrappedExpression._ trait Pass extends LazyLogging { def name: String @@ -52,7 +44,7 @@ trait Pass extends LazyLogging { class PassException(message: String) extends Exception(message) class PassExceptions(exceptions: Seq[PassException]) extends Exception("\n" + exceptions.mkString("\n")) class Errors { - val errors = ArrayBuffer[PassException]() + val errors = collection.mutable.ArrayBuffer[PassException]() def append(pe: PassException) = errors.append(pe) def trigger = errors.size match { case 0 => @@ -65,33 +57,23 @@ class Errors { // These should be distributed into separate files object ToWorkingIR extends Pass { - private var mname = "" - def name = "Working IR" - def run (c:Circuit): Circuit = { - def toExp (e:Expression) : Expression = { - e map (toExp) match { - case e:Reference => WRef(e.name, e.tpe, NodeKind(), UNKNOWNGENDER) - case e:SubField => WSubField(e.expr, e.name, e.tpe, UNKNOWNGENDER) - case e:SubIndex => WSubIndex(e.expr, e.value, e.tpe, UNKNOWNGENDER) - case e:SubAccess => WSubAccess(e.expr, e.index, e.tpe, UNKNOWNGENDER) - case e => e - } - } - def toStmt (s:Statement) : Statement = { - s map (toExp) match { - case s:DefInstance => WDefInstance(s.info,s.name,s.module,UnknownType) - case s => s map (toStmt) - } - } - val modulesx = c.modules.map { m => - mname = m.name - m match { - case m:Module => Module(m.info,m.name, m.ports, toStmt(m.body)) - case m:ExtModule => m - } - } - Circuit(c.info,modulesx,c.main) - } + def name = "Working IR" + + def toExp(e:Expression) : Expression = e map (toExp) match { + case e: Reference => WRef(e.name, e.tpe, NodeKind(), UNKNOWNGENDER) + case e: SubField => WSubField(e.expr, e.name, e.tpe, UNKNOWNGENDER) + case e: SubIndex => WSubIndex(e.expr, e.value, e.tpe, UNKNOWNGENDER) + case e: SubAccess => WSubAccess(e.expr, e.index, e.tpe, UNKNOWNGENDER) + case e => e + } + + def toStmt(s: Statement): Statement = s map (toExp) match { + case s: DefInstance => WDefInstance(s.info, s.name, s.module, UnknownType) + case s => s map (toStmt) + } + + def run (c:Circuit): Circuit = + c copy (modules = (c.modules map (_ map toStmt))) } object PullMuxes extends Pass { @@ -140,7 +122,7 @@ object ExpandConnects extends Pass { def name = "Expand Connects" def run(c: Circuit): Circuit = { def expand_connects(m: Module): Module = { - val genders = LinkedHashMap[String,Gender]() + val genders = collection.mutable.LinkedHashMap[String,Gender]() def expand_s(s: Statement): Statement = { def set_gender(e: Expression): Expression = e map (set_gender) match { case (e: WRef) => WRef(e.name, e.tpe, e.kind, genders(e.name)) @@ -276,78 +258,53 @@ object Legalize extends Pass { } legalizedStmt map legalizeS map legalizeE } - def legalizeM (m: DefModule): DefModule = m map (legalizeS) - Circuit(c.info, c.modules.map(legalizeM), c.main) + c copy (modules = (c.modules map (_ map legalizeS))) } } object VerilogWrap extends Pass { - def name = "Verilog Wrap" - var mname = "" - def v_wrap_e (e:Expression) : Expression = { - e map (v_wrap_e) match { - case (e:DoPrim) => { - def a0 () = e.args(0) - if (e.op == Tail) { - (a0()) match { - case (e0:DoPrim) => { - if (e0.op == Add) DoPrim(Addw,e0.args,Seq(),e.tpe) - else if (e0.op == Sub) DoPrim(Subw,e0.args,Seq(),e.tpe) - else e - } - case (e0) => e - } - } - else e - } - case (e) => e - } - } - def v_wrap_s (s:Statement) : Statement = { - s map (v_wrap_s) map (v_wrap_e) match { - case s: Print => - Print(s.info, VerilogStringLitHandler.format(s.string), s.args, s.clk, s.en) - case s => s + def name = "Verilog Wrap" + def vWrapE(e: Expression): Expression = e map vWrapE match { + case e: DoPrim => e.op match { + case Tail => e.args.head match { + case e0: DoPrim => e0.op match { + case Add => DoPrim(Addw, e0.args, Nil, e.tpe) + case Sub => DoPrim(Subw, e0.args, Nil, e.tpe) + case _ => e + } + case _ => e } - } - def run (c:Circuit): Circuit = { - val modulesx = c.modules.map{ m => { - (m) match { - case (m:Module) => { - mname = m.name - Module(m.info,m.name,m.ports,v_wrap_s(m.body)) - } - case (m:ExtModule) => m - } - }} - Circuit(c.info,modulesx,c.main) - } + case _ => e + } + case _ => e + } + def vWrapS(s: Statement): Statement = { + s map vWrapS map vWrapE match { + case s: Print => s copy (string = VerilogStringLitHandler.format(s.string)) + case s => s + } + } + + def run(c: Circuit): Circuit = + c copy (modules = (c.modules map (_ map vWrapS))) } object VerilogRename extends Pass { - def name = "Verilog Rename" - def run (c:Circuit): Circuit = { - def verilog_rename_n (n:String) : String = { - if (v_keywords.contains(n)) (n + "$") else n - } - def verilog_rename_e (e:Expression) : Expression = { - (e) match { - case (e:WRef) => WRef(verilog_rename_n(e.name),e.tpe,kind(e),gender(e)) - case (e) => e map (verilog_rename_e) - } - } - def verilog_rename_s (s:Statement) : Statement = { - s map (verilog_rename_s) map (verilog_rename_e) map (verilog_rename_n) - } - val modulesx = c.modules.map{ m => { - val portsx = m.ports.map{ p => { - Port(p.info,verilog_rename_n(p.name),p.direction,p.tpe) - }} - m match { - case (m:Module) => Module(m.info,m.name,portsx,verilog_rename_s(m.body)) - case (m:ExtModule) => m - } - }} - Circuit(c.info,modulesx,c.main) - } + def name = "Verilog Rename" + def verilogRenameN(n: String): String = + if (v_keywords(n)) "%s$".format(n) else n + + def verilogRenameE(e: Expression): Expression = e match { + case e: WRef => e copy (name = verilogRenameN(e.name)) + case e => e map verilogRenameE + } + + def verilogRenameS(s: Statement): Statement = + s map verilogRenameS map verilogRenameE map verilogRenameN + + def verilogRenameP(p: Port): Port = + p copy (name = verilogRenameN(p.name)) + + def run(c: Circuit): Circuit = + c copy (modules = (c.modules map (_ map verilogRenameP map verilogRenameS))) } |
