aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorDonggyu2016-09-13 17:26:11 -0700
committerGitHub2016-09-13 17:26:11 -0700
commit1bb9597a01e77d9a1ece479caf13cf6c3f6229d5 (patch)
tree39b3dc1da954faea65777eb595e64fbd2b1a2f45 /src
parent96340374f091d5258ca69ef7fc614910e1c2cbb7 (diff)
parent1cfda487ec6773a139587c1c0bcf145c03b46800 (diff)
Merge pull request #285 from ucb-bar/more_passes_cleanups
More passes cleanups
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/CheckChirrtl.scala21
-rw-r--r--src/main/scala/firrtl/passes/Checks.scala32
-rw-r--r--src/main/scala/firrtl/passes/InferWidths.scala15
-rw-r--r--src/main/scala/firrtl/passes/LowerTypes.scala292
-rw-r--r--src/main/scala/firrtl/passes/MemUtils.scala2
-rw-r--r--src/main/scala/firrtl/passes/PadWidths.scala130
-rw-r--r--src/main/scala/firrtl/passes/Passes.scala163
7 files changed, 276 insertions, 379 deletions
diff --git a/src/main/scala/firrtl/passes/CheckChirrtl.scala b/src/main/scala/firrtl/passes/CheckChirrtl.scala
index 2ab8749b..f21449a2 100644
--- a/src/main/scala/firrtl/passes/CheckChirrtl.scala
+++ b/src/main/scala/firrtl/passes/CheckChirrtl.scala
@@ -27,21 +27,14 @@ MODIFICATIONS.
package firrtl.passes
-import com.typesafe.scalalogging.LazyLogging
-
-// Datastructures
-import scala.collection.mutable.HashSet
-
import firrtl._
import firrtl.ir._
import firrtl.Utils._
import firrtl.Mappers._
-import firrtl.PrimOps._
-import firrtl.WrappedType._
-
-object CheckChirrtl extends Pass with LazyLogging {
+object CheckChirrtl extends Pass {
def name = "Chirrtl Check"
+ type NameSet = collection.mutable.HashSet[String]
class NotUniqueException(info: Info, mname: String, name: String) extends PassException(
s"${info}: [module ${mname}] Reference ${name} does not have a unique name.")
@@ -101,7 +94,7 @@ object CheckChirrtl extends Pass with LazyLogging {
e
}
- def checkChirrtlE(info: Info, mname: String, names: HashSet[String])(e: Expression): Expression = {
+ def checkChirrtlE(info: Info, mname: String, names: NameSet)(e: Expression): Expression = {
e match {
case _: DoPrim | _:Mux | _:ValidIf | _: UIntLiteral =>
case e: Reference if !names(e.name) =>
@@ -114,14 +107,14 @@ object CheckChirrtl extends Pass with LazyLogging {
map checkChirrtlE(info, mname, names))
}
- def checkName(info: Info, mname: String, names: HashSet[String])(name: String): String = {
+ def checkName(info: Info, mname: String, names: NameSet)(name: String): String = {
if (names(name))
errors append (new NotUniqueException(info, mname, name))
names += name
name
}
- def checkChirrtlS(minfo: Info, mname: String, names: HashSet[String])(s: Statement): Statement = {
+ def checkChirrtlS(minfo: Info, mname: String, names: NameSet)(s: Statement): Statement = {
val info = get_info(s) match {case NoInfo => minfo case x => x}
(s map checkName(info, mname, names)) match {
case s: DefMemory =>
@@ -138,7 +131,7 @@ object CheckChirrtl extends Pass with LazyLogging {
map checkChirrtlS(info, mname, names))
}
- def checkChirrtlP(mname: String, names: HashSet[String])(p: Port): Port = {
+ def checkChirrtlP(mname: String, names: NameSet)(p: Port): Port = {
names += p.name
(p.tpe map checkChirrtlT(p.info, mname)
map checkChirrtlW(p.info, mname))
@@ -146,7 +139,7 @@ object CheckChirrtl extends Pass with LazyLogging {
}
def checkChirrtlM(m: DefModule) {
- val names = HashSet[String]()
+ val names = new NameSet
(m map checkChirrtlP(m.name, names)
map checkChirrtlS(m.info, m.name, names))
}
diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala
index 16b16ff7..c300f7c7 100644
--- a/src/main/scala/firrtl/passes/Checks.scala
+++ b/src/main/scala/firrtl/passes/Checks.scala
@@ -27,8 +27,6 @@ MODIFICATIONS.
package firrtl.passes
-import com.typesafe.scalalogging.LazyLogging
-
import firrtl._
import firrtl.ir._
import firrtl.PrimOps._
@@ -36,11 +34,9 @@ import firrtl.Utils._
import firrtl.Mappers._
import firrtl.WrappedType._
-// Datastructures
-import scala.collection.mutable.{HashMap, HashSet}
-
-object CheckHighForm extends Pass with LazyLogging {
+object CheckHighForm extends Pass {
def name = "High Form Check"
+ type NameSet = collection.mutable.HashSet[String]
// Custom Exceptions
class NotUniqueException(info: Info, mname: String, name: String) extends PassException(
@@ -160,7 +156,7 @@ object CheckHighForm extends Pass with LazyLogging {
e
}
- def checkHighFormE(info: Info, mname: String, names: HashSet[String])(e: Expression): Expression = {
+ def checkHighFormE(info: Info, mname: String, names: NameSet)(e: Expression): Expression = {
e match {
case e: WRef if !names(e.name) =>
errors append new UndeclaredReferenceException(info, mname, e.name)
@@ -176,14 +172,14 @@ object CheckHighForm extends Pass with LazyLogging {
map checkHighFormE(info, mname, names))
}
- def checkName(info: Info, mname: String, names: HashSet[String])(name: String): String = {
+ def checkName(info: Info, mname: String, names: NameSet)(name: String): String = {
if (names(name))
errors append new NotUniqueException(info, mname, name)
names += name
name
}
- def checkHighFormS(minfo: Info, mname: String, names: HashSet[String])(s: Statement): Statement = {
+ def checkHighFormS(minfo: Info, mname: String, names: NameSet)(s: Statement): Statement = {
val info = get_info(s) match {case NoInfo => minfo case x => x}
(s map checkName(info, mname, names)) match {
case s: DefMemory =>
@@ -208,7 +204,7 @@ object CheckHighForm extends Pass with LazyLogging {
map checkHighFormS(minfo, mname, names))
}
- def checkHighFormP(mname: String, names: HashSet[String])(p: Port): Port = {
+ def checkHighFormP(mname: String, names: NameSet)(p: Port): Port = {
names += p.name
(p.tpe map checkHighFormT(p.info, mname)
map checkHighFormW(p.info, mname))
@@ -216,7 +212,7 @@ object CheckHighForm extends Pass with LazyLogging {
}
def checkHighFormM(m: DefModule) {
- val names = HashSet[String]()
+ val names = new NameSet
(m map checkHighFormP(m.name, names)
map checkHighFormS(m.info, m.name, names))
}
@@ -231,7 +227,7 @@ object CheckHighForm extends Pass with LazyLogging {
}
}
-object CheckTypes extends Pass with LazyLogging {
+object CheckTypes extends Pass {
def name = "Check Types"
// Custom Exceptions
@@ -430,6 +426,7 @@ object CheckTypes extends Pass with LazyLogging {
object CheckGenders extends Pass {
def name = "Check Genders"
+ type GenderMap = collection.mutable.HashMap[String, Gender]
implicit def toStr(g: Gender): String = g match {
case MALE => "source"
@@ -444,7 +441,7 @@ object CheckGenders extends Pass {
def run (c:Circuit): Circuit = {
val errors = new Errors()
- def get_gender(e: Expression, genders: HashMap[String, Gender]): Gender = e match {
+ def get_gender(e: Expression, genders: GenderMap): Gender = e match {
case (e: WRef) => genders(e.name)
case (e: WSubIndex) => get_gender(e.exp, genders)
case (e: WSubAccess) => get_gender(e.exp, genders)
@@ -466,8 +463,7 @@ object CheckGenders extends Pass {
flip_rec(t, Default)
}
- def check_gender(info:Info, mname: String,
- genders: HashMap[String,Gender], desired: Gender)(e:Expression): Expression = {
+ def check_gender(info:Info, mname: String, genders: GenderMap, desired: Gender)(e:Expression): Expression = {
val gender = get_gender(e,genders)
(gender, desired) match {
case (MALE, FEMALE) =>
@@ -482,7 +478,7 @@ object CheckGenders extends Pass {
e
}
- def check_genders_e (info:Info, mname: String, genders: HashMap[String,Gender])(e:Expression): Expression = {
+ def check_genders_e (info:Info, mname: String, genders: GenderMap)(e:Expression): Expression = {
e match {
case e: Mux => e map check_gender(info, mname, genders, MALE)
case e: DoPrim => e.args map check_gender(info, mname, genders, MALE)
@@ -491,7 +487,7 @@ object CheckGenders extends Pass {
e map check_genders_e(info, mname, genders)
}
- def check_genders_s(minfo: Info, mname: String, genders: HashMap[String,Gender])(s: Statement): Statement = {
+ def check_genders_s(minfo: Info, mname: String, genders: GenderMap)(s: Statement): Statement = {
val info = get_info(s) match { case NoInfo => minfo case x => x }
s match {
case (s: DefWire) => genders(s.name) = BIGENDER
@@ -522,7 +518,7 @@ object CheckGenders extends Pass {
}
for (m <- c.modules) {
- val genders = HashMap[String, Gender]()
+ val genders = new GenderMap
genders ++= (m.ports map (p => p.name -> to_gender(p.direction)))
m map check_genders_s(m.info, m.name, genders)
}
diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala
index 6b2ff6ed..ebec4d80 100644
--- a/src/main/scala/firrtl/passes/InferWidths.scala
+++ b/src/main/scala/firrtl/passes/InferWidths.scala
@@ -28,20 +28,19 @@ MODIFICATIONS.
package firrtl.passes
// Datastructures
-import scala.collection.mutable.{LinkedHashMap, HashMap, HashSet, ArrayBuffer}
+import scala.collection.mutable.ArrayBuffer
import scala.collection.immutable.ListMap
import firrtl._
import firrtl.ir._
import firrtl.Utils._
import firrtl.Mappers._
-import firrtl.PrimOps._
-import firrtl.WrappedExpression._
object InferWidths extends Pass {
def name = "Infer Widths"
+ type ConstraintMap = collection.mutable.LinkedHashMap[String, Width]
- def solve_constraints(l: Seq[WGeq]): LinkedHashMap[String, Width] = {
+ def solve_constraints(l: Seq[WGeq]): ConstraintMap = {
def unique(ls: Seq[Width]) : Seq[Width] =
(ls map (new WrappedWidth(_))).distinct map (_.w)
def make_unique(ls: Seq[WGeq]): ListMap[String,Width] = {
@@ -77,7 +76,7 @@ object InferWidths extends Pass {
case _ => w
}
- def substitute(h: LinkedHashMap[String, Width])(w: Width): Width = {
+ def substitute(h: ConstraintMap)(w: Width): Width = {
//;println-all-debug(["Substituting for [" w "]"])
val wx = simplify(w)
//;println-all-debug(["After Simplify: [" wx "]"])
@@ -98,7 +97,7 @@ object InferWidths extends Pass {
}
}
- def b_sub(h: LinkedHashMap[String, Width])(w: Width): Width = {
+ def b_sub(h: ConstraintMap)(w: Width): Width = {
w map b_sub(h) match {
case w: VarWidth => h getOrElse (w.name, w)
case w => w
@@ -145,7 +144,7 @@ object InferWidths extends Pass {
//for (x <- u) { println(x) }
//println("====================================")
- val f = LinkedHashMap[String, Width]()
+ val f = new ConstraintMap
val o = ArrayBuffer[String]()
for ((n, e) <- u) {
//println("==== SOLUTIONS TABLE ====")
@@ -175,7 +174,7 @@ object InferWidths extends Pass {
//for (x <- f) println(x)
//; Backwards Solve
- val b = LinkedHashMap[String, Width]()
+ val b = new ConstraintMap
for (i <- (o.size - 1) to 0 by -1) {
val n = o(i) // Should visit `o` backward
/*
diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala
index 57f8fd76..b3969bea 100644
--- a/src/main/scala/firrtl/passes/LowerTypes.scala
+++ b/src/main/scala/firrtl/passes/LowerTypes.scala
@@ -27,16 +27,11 @@ MODIFICATIONS.
package firrtl.passes
-import com.typesafe.scalalogging.LazyLogging
-
import firrtl._
import firrtl.ir._
import firrtl.Utils._
import firrtl.Mappers._
-// Datastructures
-import scala.collection.mutable.HashMap
-
/** Removes all aggregate types from a [[firrtl.ir.Circuit]]
*
* @note Assumes [[firrtl.ir.SubAccess]]es have been removed
@@ -67,8 +62,8 @@ object LowerTypes extends Pass {
def loweredName(s: Seq[String]): String = s mkString delim
private case class LowerTypesException(msg: String) extends FIRRTLException(msg)
- private def error(msg: String)(implicit sinfo: Info, mname: String) =
- throw new LowerTypesException(s"$sinfo: [module $mname] $msg")
+ private def error(msg: String)(info: Info, mname: String) =
+ throw LowerTypesException(s"$info: [module $mname] $msg")
// TODO Improve? Probably not the best way to do this
private def splitMemRef(e1: Expression): (WRef, WRef, WRef, Option[Expression]) = {
@@ -83,165 +78,146 @@ object LowerTypes extends Pass {
}
}
- // Everything wrapped in run so that it's thread safe
- def run(c: Circuit): Circuit = {
- // Debug state
- implicit var mname: String = ""
- implicit var sinfo: Info = NoInfo
-
- def lowerTypes(m: DefModule): DefModule = {
- val memDataTypeMap = HashMap[String, Type]()
-
- // Lowers an expression of MemKind
- // Since mems with Bundle type must be split into multiple ground type
- // mem, references to fields addr, en, clk, and rmode must be replicated
- // for each resulting memory
- // References to data, mask, rdata, wdata, and wmask have already been split in expand connects
- // and just need to be converted to refer to the correct new memory
- def lowerTypesMemExp(e: Expression): Seq[Expression] = {
- val (mem, port, field, tail) = splitMemRef(e)
- field.name match {
- // Fields that need to be replicated for each resulting mem
- case "addr" | "en" | "clk" | "wmode" =>
- require(tail.isEmpty) // there can't be a tail for these
- memDataTypeMap(mem.name) match {
- case _: GroundType => Seq(e)
- case memType => create_exps(mem.name, memType) map { e =>
- val loMemName = loweredName(e)
- val loMem = WRef(loMemName, UnknownType, kind(mem), UNKNOWNGENDER)
- mergeRef(loMem, mergeRef(port, field))
- }
- }
- // Fields that need not be replicated for each
- // eg. mem.reader.data[0].a
- // (Connect/IsInvalid must already have been split to ground types)
- case "data" | "mask" | "rdata" | "wdata" | "wmask" =>
- val loMem = tail match {
- case Some(e) =>
- val loMemExp = mergeRef(mem, e)
- val loMemName = loweredName(loMemExp)
- WRef(loMemName, UnknownType, kind(mem), UNKNOWNGENDER)
- case None => mem
- }
- Seq(mergeRef(loMem, mergeRef(port, field)))
- case name => error(s"Error! Unhandled memory field ${name}")
+ // Lowers an expression of MemKind
+ // Since mems with Bundle type must be split into multiple ground type
+ // mem, references to fields addr, en, clk, and rmode must be replicated
+ // for each resulting memory
+ // References to data, mask, rdata, wdata, and wmask have already been split in expand connects
+ // and just need to be converted to refer to the correct new memory
+ type MemDataTypeMap = collection.mutable.HashMap[String, Type]
+ def lowerTypesMemExp(memDataTypeMap: MemDataTypeMap,
+ info: Info, mname: String)(e: Expression): Seq[Expression] = {
+ val (mem, port, field, tail) = splitMemRef(e)
+ field.name match {
+ // Fields that need to be replicated for each resulting mem
+ case "addr" | "en" | "clk" | "wmode" =>
+ require(tail.isEmpty) // there can't be a tail for these
+ memDataTypeMap(mem.name) match {
+ case _: GroundType => Seq(e)
+ case memType => create_exps(mem.name, memType) map { e =>
+ val loMemName = loweredName(e)
+ val loMem = WRef(loMemName, UnknownType, kind(mem), UNKNOWNGENDER)
+ mergeRef(loMem, mergeRef(port, field))
+ }
}
- }
+ // Fields that need not be replicated for each
+ // eg. mem.reader.data[0].a
+ // (Connect/IsInvalid must already have been split to ground types)
+ case "data" | "mask" | "rdata" | "wdata" | "wmask" =>
+ val loMem = tail match {
+ case Some(e) =>
+ val loMemExp = mergeRef(mem, e)
+ val loMemName = loweredName(loMemExp)
+ WRef(loMemName, UnknownType, kind(mem), UNKNOWNGENDER)
+ case None => mem
+ }
+ Seq(mergeRef(loMem, mergeRef(port, field)))
+ case name => error(s"Error! Unhandled memory field ${name}")(info, mname)
+ }
+ }
- def lowerTypesExp(e: Expression): Expression = e match {
- case e: WRef => e
- case (_: WSubField | _: WSubIndex) => kind(e) match {
- case k: InstanceKind =>
- val (root, tail) = splitRef(e)
- val name = loweredName(tail)
- WSubField(root, name, e.tpe, gender(e))
- case k: MemKind =>
- val exps = lowerTypesMemExp(e)
- exps.size match {
- case 1 => exps.head
- case _ => error("Error! lowerTypesExp called on MemKind " +
- "SubField that needs to be expanded!")
- }
- case _ => WRef(loweredName(e), e.tpe, kind(e), gender(e))
+ def lowerTypesExp(memDataTypeMap: MemDataTypeMap,
+ info: Info, mname: String)(e: Expression): Expression = e match {
+ case e: WRef => e
+ case (_: WSubField | _: WSubIndex) => kind(e) match {
+ case k: InstanceKind =>
+ val (root, tail) = splitRef(e)
+ val name = loweredName(tail)
+ WSubField(root, name, e.tpe, gender(e))
+ case k: MemKind =>
+ val exps = lowerTypesMemExp(memDataTypeMap, info, mname)(e)
+ exps.size match {
+ case 1 => exps.head
+ case _ => error("Error! lowerTypesExp called on MemKind " +
+ "SubField that needs to be expanded!")(info, mname)
}
- case e: Mux => e map (lowerTypesExp)
- case e: ValidIf => e map (lowerTypesExp)
- case e: DoPrim => e map (lowerTypesExp)
- case e @ (_: UIntLiteral | _: SIntLiteral) => e
- }
+ case _ => WRef(loweredName(e), e.tpe, kind(e), gender(e))
+ }
+ case e: Mux => e map lowerTypesExp(memDataTypeMap, info, mname)
+ case e: ValidIf => e map lowerTypesExp(memDataTypeMap, info, mname)
+ case e: DoPrim => e map lowerTypesExp(memDataTypeMap, info, mname)
+ case e @ (_: UIntLiteral | _: SIntLiteral) => e
+ }
- def lowerTypesStmt(s: Statement): Statement = s map lowerTypesStmt match {
- case s: DefWire =>
- sinfo = s.info
- s.tpe match {
- case _: GroundType => s
- case _ => Block(create_exps(s.name, s.tpe) map (
- e => DefWire(s.info, loweredName(e), e.tpe)))
- }
- case s: DefRegister =>
- sinfo = s.info
- s.tpe match {
- case _: GroundType => s map lowerTypesExp
- case _ =>
- val es = create_exps(s.name, s.tpe)
- val inits = create_exps(s.init) map (lowerTypesExp)
- val clock = lowerTypesExp(s.clock)
- val reset = lowerTypesExp(s.reset)
- Block(es zip inits map { case (e, i) =>
- DefRegister(s.info, loweredName(e), e.tpe, clock, reset, i)
- })
- }
- // Could instead just save the type of each Module as it gets processed
- case s: WDefInstance =>
- sinfo = s.info
- s.tpe match {
- case t: BundleType =>
- val fieldsx = t.fields flatMap (f =>
- create_exps(WRef(f.name, f.tpe, ExpKind(), times(f.flip, MALE))) map (
- // Flip because inst genders are reversed from Module type
- e => Field(loweredName(e), swap(to_flip(gender(e))), e.tpe)
- )
- )
- WDefInstance(s.info, s.name, s.module, BundleType(fieldsx))
- case _ => error("WDefInstance type should be Bundle!")
- }
- case s: DefMemory =>
- sinfo = s.info
- memDataTypeMap(s.name) = s.dataType
- s.dataType match {
- case _: GroundType => s
- case _ => Block(create_exps(s.name, s.dataType) map (e =>
- DefMemory(s.info, loweredName(e), e.tpe, s.depth,
- s.writeLatency, s.readLatency, s.readers, s.writers,
- s.readwriters)))
- }
- // wire foo : { a , b }
- // node x = foo
- // node y = x.a
- // ->
- // node x_a = foo_a
- // node x_b = foo_b
- // node y = x_a
- case s: DefNode =>
- sinfo = s.info
- val names = create_exps(s.name, s.value.tpe) map (lowerTypesExp)
- val exps = create_exps(s.value) map (lowerTypesExp)
- Block(names zip exps map {case (n, e) => DefNode(s.info, loweredName(n), e)})
- case s: IsInvalid =>
- sinfo = s.info
- kind(s.expr) match {
- case k: MemKind =>
- Block(lowerTypesMemExp(s.expr) map (IsInvalid(s.info, _)))
- case _ => s map (lowerTypesExp)
- }
- case s: Connect =>
- sinfo = s.info
- kind(s.loc) match {
- case k: MemKind =>
- val exp = lowerTypesExp(s.expr)
- val locs = lowerTypesMemExp(s.loc)
- Block(locs map (Connect(s.info, _, exp)))
- case _ => s map (lowerTypesExp)
- }
- case s => s map (lowerTypesExp)
+ def lowerTypesStmt(memDataTypeMap: MemDataTypeMap,
+ minfo: Info, mname: String)(s: Statement): Statement = {
+ val info = get_info(s) match {case NoInfo => minfo case x => x}
+ s map lowerTypesStmt(memDataTypeMap, info, mname) match {
+ case s: DefWire => s.tpe match {
+ case _: GroundType => s
+ case _ => Block(create_exps(s.name, s.tpe) map (
+ e => DefWire(s.info, loweredName(e), e.tpe)))
}
-
- sinfo = m.info
- mname = m.name
- // Lower Ports
- val portsx = m.ports flatMap ( p =>
- create_exps(WRef(p.name, p.tpe, PortKind(), to_gender(p.direction))) map (
- e => Port(p.info, loweredName(e), to_dir(gender(e)), e.tpe)
- )
- )
- m match {
- case m: ExtModule => m.copy(ports = portsx)
- case m: Module => Module(m.info, m.name, portsx, lowerTypesStmt(m.body))
+ case s: DefRegister => s.tpe match {
+ case _: GroundType => s map lowerTypesExp(memDataTypeMap, info, mname)
+ case _ =>
+ val es = create_exps(s.name, s.tpe)
+ val inits = create_exps(s.init) map lowerTypesExp(memDataTypeMap, info, mname)
+ val clock = lowerTypesExp(memDataTypeMap, info, mname)(s.clock)
+ val reset = lowerTypesExp(memDataTypeMap, info, mname)(s.reset)
+ Block(es zip inits map { case (e, i) =>
+ DefRegister(s.info, loweredName(e), e.tpe, clock, reset, i)
+ })
+ }
+ // Could instead just save the type of each Module as it gets processed
+ case s: WDefInstance => s.tpe match {
+ case t: BundleType =>
+ val fieldsx = t.fields flatMap (f =>
+ create_exps(WRef(f.name, f.tpe, ExpKind(), times(f.flip, MALE))) map (
+ // Flip because inst genders are reversed from Module type
+ e => Field(loweredName(e), swap(to_flip(gender(e))), e.tpe)))
+ WDefInstance(s.info, s.name, s.module, BundleType(fieldsx))
+ case _ => error("WDefInstance type should be Bundle!")(info, mname)
}
+ case s: DefMemory =>
+ memDataTypeMap(s.name) = s.dataType
+ s.dataType match {
+ case _: GroundType => s
+ case _ => Block(create_exps(s.name, s.dataType) map (e =>
+ s copy (name = loweredName(e), dataType = e.tpe)))
+ }
+ // wire foo : { a , b }
+ // node x = foo
+ // node y = x.a
+ // ->
+ // node x_a = foo_a
+ // node x_b = foo_b
+ // node y = x_a
+ case s: DefNode =>
+ val names = create_exps(s.name, s.value.tpe) map lowerTypesExp(memDataTypeMap, info, mname)
+ val exps = create_exps(s.value) map lowerTypesExp(memDataTypeMap, info, mname)
+ Block(names zip exps map { case (n, e) => DefNode(info, loweredName(n), e) })
+ case s: IsInvalid => kind(s.expr) match {
+ case _: MemKind =>
+ Block(lowerTypesMemExp(memDataTypeMap, info, mname)(s.expr) map (IsInvalid(info, _)))
+ case _ => s map lowerTypesExp(memDataTypeMap, info, mname)
+ }
+ case s: Connect => kind(s.loc) match {
+ case k: MemKind =>
+ val exp = lowerTypesExp(memDataTypeMap, info, mname)(s.expr)
+ val locs = lowerTypesMemExp(memDataTypeMap, info, mname)(s.loc)
+ Block(locs map (Connect(info, _, exp)))
+ case _ => s map lowerTypesExp(memDataTypeMap, info, mname)
+ }
+ case s => s map lowerTypesExp(memDataTypeMap, info, mname)
}
+ }
- sinfo = c.info
- Circuit(c.info, c.modules map lowerTypes, c.main)
+ def lowerTypes(m: DefModule): DefModule = {
+ val memDataTypeMap = new MemDataTypeMap
+ // Lower Ports
+ val portsx = m.ports flatMap { p =>
+ val exps = create_exps(WRef(p.name, p.tpe, PortKind(), to_gender(p.direction)))
+ exps map (e => Port(p.info, loweredName(e), to_dir(gender(e)), e.tpe))
+ }
+ m match {
+ case m: ExtModule =>
+ m copy (ports = portsx)
+ case m: Module =>
+ m copy (ports = portsx) map lowerTypesStmt(memDataTypeMap, m.info, m.name)
+ }
}
+
+ def run(c: Circuit): Circuit = c copy (modules = (c.modules map lowerTypes))
}
diff --git a/src/main/scala/firrtl/passes/MemUtils.scala b/src/main/scala/firrtl/passes/MemUtils.scala
index 87033176..57a7120b 100644
--- a/src/main/scala/firrtl/passes/MemUtils.scala
+++ b/src/main/scala/firrtl/passes/MemUtils.scala
@@ -27,8 +27,6 @@
package firrtl.passes
-import com.typesafe.scalalogging.LazyLogging
-
import firrtl._
import firrtl.ir._
import firrtl.Utils._
diff --git a/src/main/scala/firrtl/passes/PadWidths.scala b/src/main/scala/firrtl/passes/PadWidths.scala
index bef9ac33..4c198bab 100644
--- a/src/main/scala/firrtl/passes/PadWidths.scala
+++ b/src/main/scala/firrtl/passes/PadWidths.scala
@@ -7,80 +7,58 @@ import firrtl.Mappers._
// Makes all implicit width extensions and truncations explicit
object PadWidths extends Pass {
- def name = "Pad Widths"
- private def width(t: Type): Int = bitWidth(t).toInt
- private def width(e: Expression): Int = width(e.tpe)
- // Returns an expression with the correct integer width
- private def fixup(i: Int)(e: Expression) = {
- def tx = e.tpe match {
- case t: UIntType => UIntType(IntWidth(i))
- case t: SIntType => SIntType(IntWidth(i))
- // default case should never be reached
- }
- if (i > width(e)) {
- DoPrim(Pad, Seq(e), Seq(i), tx)
- } else if (i < width(e)) {
- val e2 = DoPrim(Bits, Seq(e), Seq(i - 1, 0), UIntType(IntWidth(i)))
- // Bit Select always returns UInt, cast if selecting from SInt
- e.tpe match {
- case UIntType(_) => e2
- case SIntType(_) => DoPrim(AsSInt, Seq(e2), Seq.empty, SIntType(IntWidth(i)))
- }
- } else {
- e
- }
- }
- // Recursive, updates expression so children exp's have correct widths
- private def onExp(e: Expression): Expression = {
- val sensitiveOps = Seq( Lt, Leq, Gt, Geq, Eq, Neq, Not, And, Or, Xor,
- Add, Sub, Mul, Div, Rem, Shr)
- val x = e map onExp
- x match {
- case Mux(cond, tval, fval, tpe) => {
- val tvalx = fixup(width(tpe))(tval)
- val fvalx = fixup(width(tpe))(fval)
- Mux(cond, tvalx, fvalx, tpe)
- }
- case DoPrim(op, args, consts, tpe) => op match {
- case _ if sensitiveOps.contains(op) => {
- val i = args.map(a => width(a)).foldLeft(0) {(a, b) => math.max(a, b)}
- x map fixup(i)
- }
- case Dshl => {
- // special case as args aren't all same width
- val ax = fixup(width(tpe))(args(0))
- DoPrim(Dshlw, Seq(ax, args(1)), consts, tpe)
- }
- case Shl => {
- // special case as arg should be same width as result
- val ax = fixup(width(tpe))(args(0))
- DoPrim(Shlw, Seq(ax), consts, tpe)
- }
- case _ => x
- }
- case ValidIf(cond, value, tpe) => ValidIf(cond, fixup(width(tpe))(value), tpe)
- case x => x
- }
- }
- // Recursive. Fixes assignments and register initialization widths
- private def onStmt(s: Statement): Statement = {
- s map onExp match {
- case s: Connect => {
- val ex = fixup(width(s.loc))(s.expr)
- Connect(s.info, s.loc, ex)
- }
- case s: DefRegister => {
- val ex = fixup(width(s.tpe))(s.init)
- DefRegister(s.info, s.name, s.tpe, s.clock, s.reset, ex)
- }
- case s => s map onStmt
- }
- }
- private def onModule(m: DefModule): DefModule = {
- m match {
- case m: Module => Module(m.info, m.name, m.ports, onStmt(m.body))
- case m: ExtModule => m
- }
- }
- def run(c: Circuit): Circuit = Circuit(c.info, c.modules.map(onModule _), c.main)
+ def name = "Pad Widths"
+ private def width(t: Type): Int = bitWidth(t).toInt
+ private def width(e: Expression): Int = width(e.tpe)
+ // Returns an expression with the correct integer width
+ private def fixup(i: Int)(e: Expression) = {
+ def tx = e.tpe match {
+ case t: UIntType => UIntType(IntWidth(i))
+ case t: SIntType => SIntType(IntWidth(i))
+ // default case should never be reached
+ }
+ width(e) match {
+ case j if i > j => DoPrim(Pad, Seq(e), Seq(i), tx)
+ case j if i < j =>
+ val e2 = DoPrim(Bits, Seq(e), Seq(i - 1, 0), UIntType(IntWidth(i)))
+ // Bit Select always returns UInt, cast if selecting from SInt
+ e.tpe match {
+ case UIntType(_) => e2
+ case SIntType(_) => DoPrim(AsSInt, Seq(e2), Seq.empty, SIntType(IntWidth(i)))
+ }
+ case _ => e
+ }
+ }
+
+ // Recursive, updates expression so children exp's have correct widths
+ private def onExp(e: Expression): Expression = e map onExp match {
+ case Mux(cond, tval, fval, tpe) =>
+ Mux(cond, fixup(width(tpe))(tval), fixup(width(tpe))(fval), tpe)
+ case e: ValidIf => e copy (value = fixup(width(e.tpe))(e.value))
+ case e: DoPrim => e.op match {
+ case Lt | Leq | Gt | Geq | Eq | Neq | Not | And | Or | Xor |
+ Add | Sub | Mul | Div | Rem | Shr =>
+ // sensitive ops
+ e map fixup((e.args map (width(_)) foldLeft 0)(math.max(_, _)))
+ case Dshl =>
+ // special case as args aren't all same width
+ e copy (op = Dshlw, args = Seq(fixup(width(e.tpe))(e.args(0)), e.args(1)))
+ case Shl =>
+ // special case as arg should be same width as result
+ e copy (op = Shlw, args = Seq(fixup(width(e.tpe))(e.args(0))))
+ case _ => e
+ }
+ case e => e
+ }
+
+ // Recursive. Fixes assignments and register initialization widths
+ private def onStmt(s: Statement): Statement = s map onExp match {
+ case s: Connect =>
+ s copy (expr = fixup(width(s.loc))(s.expr))
+ case s: DefRegister =>
+ s copy (init = fixup(width(s.tpe))(s.init))
+ case s => s map onStmt
+ }
+
+ def run(c: Circuit): Circuit = c copy (modules = (c.modules map (_ map onStmt)))
}
diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala
index d5d9a3b6..965ae339 100644
--- a/src/main/scala/firrtl/passes/Passes.scala
+++ b/src/main/scala/firrtl/passes/Passes.scala
@@ -28,20 +28,12 @@ MODIFICATIONS.
package firrtl.passes
import com.typesafe.scalalogging.LazyLogging
-import java.nio.file.{Paths, Files}
-
-// Datastructures
-import scala.collection.mutable.LinkedHashMap
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
-import scala.collection.mutable.ArrayBuffer
import firrtl._
import firrtl.ir._
import firrtl.Utils._
import firrtl.Mappers._
import firrtl.PrimOps._
-import firrtl.WrappedExpression._
trait Pass extends LazyLogging {
def name: String
@@ -52,7 +44,7 @@ trait Pass extends LazyLogging {
class PassException(message: String) extends Exception(message)
class PassExceptions(exceptions: Seq[PassException]) extends Exception("\n" + exceptions.mkString("\n"))
class Errors {
- val errors = ArrayBuffer[PassException]()
+ val errors = collection.mutable.ArrayBuffer[PassException]()
def append(pe: PassException) = errors.append(pe)
def trigger = errors.size match {
case 0 =>
@@ -65,33 +57,23 @@ class Errors {
// These should be distributed into separate files
object ToWorkingIR extends Pass {
- private var mname = ""
- def name = "Working IR"
- def run (c:Circuit): Circuit = {
- def toExp (e:Expression) : Expression = {
- e map (toExp) match {
- case e:Reference => WRef(e.name, e.tpe, NodeKind(), UNKNOWNGENDER)
- case e:SubField => WSubField(e.expr, e.name, e.tpe, UNKNOWNGENDER)
- case e:SubIndex => WSubIndex(e.expr, e.value, e.tpe, UNKNOWNGENDER)
- case e:SubAccess => WSubAccess(e.expr, e.index, e.tpe, UNKNOWNGENDER)
- case e => e
- }
- }
- def toStmt (s:Statement) : Statement = {
- s map (toExp) match {
- case s:DefInstance => WDefInstance(s.info,s.name,s.module,UnknownType)
- case s => s map (toStmt)
- }
- }
- val modulesx = c.modules.map { m =>
- mname = m.name
- m match {
- case m:Module => Module(m.info,m.name, m.ports, toStmt(m.body))
- case m:ExtModule => m
- }
- }
- Circuit(c.info,modulesx,c.main)
- }
+ def name = "Working IR"
+
+ def toExp(e:Expression) : Expression = e map (toExp) match {
+ case e: Reference => WRef(e.name, e.tpe, NodeKind(), UNKNOWNGENDER)
+ case e: SubField => WSubField(e.expr, e.name, e.tpe, UNKNOWNGENDER)
+ case e: SubIndex => WSubIndex(e.expr, e.value, e.tpe, UNKNOWNGENDER)
+ case e: SubAccess => WSubAccess(e.expr, e.index, e.tpe, UNKNOWNGENDER)
+ case e => e
+ }
+
+ def toStmt(s: Statement): Statement = s map (toExp) match {
+ case s: DefInstance => WDefInstance(s.info, s.name, s.module, UnknownType)
+ case s => s map (toStmt)
+ }
+
+ def run (c:Circuit): Circuit =
+ c copy (modules = (c.modules map (_ map toStmt)))
}
object PullMuxes extends Pass {
@@ -140,7 +122,7 @@ object ExpandConnects extends Pass {
def name = "Expand Connects"
def run(c: Circuit): Circuit = {
def expand_connects(m: Module): Module = {
- val genders = LinkedHashMap[String,Gender]()
+ val genders = collection.mutable.LinkedHashMap[String,Gender]()
def expand_s(s: Statement): Statement = {
def set_gender(e: Expression): Expression = e map (set_gender) match {
case (e: WRef) => WRef(e.name, e.tpe, e.kind, genders(e.name))
@@ -276,78 +258,53 @@ object Legalize extends Pass {
}
legalizedStmt map legalizeS map legalizeE
}
- def legalizeM (m: DefModule): DefModule = m map (legalizeS)
- Circuit(c.info, c.modules.map(legalizeM), c.main)
+ c copy (modules = (c.modules map (_ map legalizeS)))
}
}
object VerilogWrap extends Pass {
- def name = "Verilog Wrap"
- var mname = ""
- def v_wrap_e (e:Expression) : Expression = {
- e map (v_wrap_e) match {
- case (e:DoPrim) => {
- def a0 () = e.args(0)
- if (e.op == Tail) {
- (a0()) match {
- case (e0:DoPrim) => {
- if (e0.op == Add) DoPrim(Addw,e0.args,Seq(),e.tpe)
- else if (e0.op == Sub) DoPrim(Subw,e0.args,Seq(),e.tpe)
- else e
- }
- case (e0) => e
- }
- }
- else e
- }
- case (e) => e
- }
- }
- def v_wrap_s (s:Statement) : Statement = {
- s map (v_wrap_s) map (v_wrap_e) match {
- case s: Print =>
- Print(s.info, VerilogStringLitHandler.format(s.string), s.args, s.clk, s.en)
- case s => s
+ def name = "Verilog Wrap"
+ def vWrapE(e: Expression): Expression = e map vWrapE match {
+ case e: DoPrim => e.op match {
+ case Tail => e.args.head match {
+ case e0: DoPrim => e0.op match {
+ case Add => DoPrim(Addw, e0.args, Nil, e.tpe)
+ case Sub => DoPrim(Subw, e0.args, Nil, e.tpe)
+ case _ => e
+ }
+ case _ => e
}
- }
- def run (c:Circuit): Circuit = {
- val modulesx = c.modules.map{ m => {
- (m) match {
- case (m:Module) => {
- mname = m.name
- Module(m.info,m.name,m.ports,v_wrap_s(m.body))
- }
- case (m:ExtModule) => m
- }
- }}
- Circuit(c.info,modulesx,c.main)
- }
+ case _ => e
+ }
+ case _ => e
+ }
+ def vWrapS(s: Statement): Statement = {
+ s map vWrapS map vWrapE match {
+ case s: Print => s copy (string = VerilogStringLitHandler.format(s.string))
+ case s => s
+ }
+ }
+
+ def run(c: Circuit): Circuit =
+ c copy (modules = (c.modules map (_ map vWrapS)))
}
object VerilogRename extends Pass {
- def name = "Verilog Rename"
- def run (c:Circuit): Circuit = {
- def verilog_rename_n (n:String) : String = {
- if (v_keywords.contains(n)) (n + "$") else n
- }
- def verilog_rename_e (e:Expression) : Expression = {
- (e) match {
- case (e:WRef) => WRef(verilog_rename_n(e.name),e.tpe,kind(e),gender(e))
- case (e) => e map (verilog_rename_e)
- }
- }
- def verilog_rename_s (s:Statement) : Statement = {
- s map (verilog_rename_s) map (verilog_rename_e) map (verilog_rename_n)
- }
- val modulesx = c.modules.map{ m => {
- val portsx = m.ports.map{ p => {
- Port(p.info,verilog_rename_n(p.name),p.direction,p.tpe)
- }}
- m match {
- case (m:Module) => Module(m.info,m.name,portsx,verilog_rename_s(m.body))
- case (m:ExtModule) => m
- }
- }}
- Circuit(c.info,modulesx,c.main)
- }
+ def name = "Verilog Rename"
+ def verilogRenameN(n: String): String =
+ if (v_keywords(n)) "%s$".format(n) else n
+
+ def verilogRenameE(e: Expression): Expression = e match {
+ case e: WRef => e copy (name = verilogRenameN(e.name))
+ case e => e map verilogRenameE
+ }
+
+ def verilogRenameS(s: Statement): Statement =
+ s map verilogRenameS map verilogRenameE map verilogRenameN
+
+ def verilogRenameP(p: Port): Port =
+ p copy (name = verilogRenameN(p.name))
+
+ def run(c: Circuit): Circuit =
+ c copy (modules = (c.modules map (_ map verilogRenameP map verilogRenameS)))
}