aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes
diff options
context:
space:
mode:
authorAdam Izraelevitz2016-09-07 13:14:22 -0700
committerGitHub2016-09-07 13:14:22 -0700
commit8647a25fec8c5e18d766ff3e3602d3345cd8549c (patch)
tree429f7acf1f95b0c1e3e9b9b1f2d528c49761356b /src/main/scala/firrtl/passes
parent0c6db9ef0669e3fb92fcc0bda2085f934d065f0b (diff)
parentb1b977407d12878fb5d8ea92950888002beb258b (diff)
Merge pull request #271 from ucb-bar/cleanup_utils
Clean up Utils
Diffstat (limited to 'src/main/scala/firrtl/passes')
-rw-r--r--src/main/scala/firrtl/passes/CheckChirrtl.scala4
-rw-r--r--src/main/scala/firrtl/passes/Checks.scala71
-rw-r--r--src/main/scala/firrtl/passes/ConstProp.scala10
-rw-r--r--src/main/scala/firrtl/passes/ExpandWhens.scala4
-rw-r--r--src/main/scala/firrtl/passes/Inline.scala1
-rw-r--r--src/main/scala/firrtl/passes/LowerTypes.scala82
-rw-r--r--src/main/scala/firrtl/passes/PadWidths.scala6
-rw-r--r--src/main/scala/firrtl/passes/Passes.scala72
-rw-r--r--src/main/scala/firrtl/passes/RemoveAccesses.scala6
-rw-r--r--src/main/scala/firrtl/passes/SplitExpressions.scala14
-rw-r--r--src/main/scala/firrtl/passes/Uniquify.scala11
11 files changed, 148 insertions, 133 deletions
diff --git a/src/main/scala/firrtl/passes/CheckChirrtl.scala b/src/main/scala/firrtl/passes/CheckChirrtl.scala
index 60a49bac..e0e7c57a 100644
--- a/src/main/scala/firrtl/passes/CheckChirrtl.scala
+++ b/src/main/scala/firrtl/passes/CheckChirrtl.scala
@@ -105,7 +105,7 @@ object CheckChirrtl extends Pass with LazyLogging {
e
}
def checkChirrtlS(s: Statement): Statement = {
- sinfo = s.getInfo
+ sinfo = get_info(s)
def checkName(name: String): String = {
if (names.contains(name)) errors.append(new NotUniqueException(name))
else names(name) = true
@@ -138,7 +138,7 @@ object CheckChirrtl extends Pass with LazyLogging {
for (p <- m.ports) {
sinfo = p.info
names(p.name) = true
- val tpe = p.getType
+ val tpe = p.tpe
tpe map (checkChirrtlT)
tpe map (checkChirrtlW)
}
diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala
index 9ee20c0a..6e49ce93 100644
--- a/src/main/scala/firrtl/passes/Checks.scala
+++ b/src/main/scala/firrtl/passes/Checks.scala
@@ -241,7 +241,7 @@ object CheckHighForm extends Pass with LazyLogging {
else names(name) = true
name
}
- sinfo = s.getInfo
+ sinfo = get_info(s)
s map (checkName)
s map (checkHighFormT)
@@ -276,7 +276,7 @@ object CheckHighForm extends Pass with LazyLogging {
for (p <- m.ports) {
// FIXME should we set sinfo here?
names(p.name) = true
- val tpe = p.getType
+ val tpe = p.tpe
tpe map (checkHighFormT)
tpe map (checkHighFormW)
}
@@ -336,27 +336,36 @@ object CheckTypes extends Pass with LazyLogging {
def all_same_type (ls:Seq[Expression]) : Unit = {
var error = false
for (x <- ls) {
- if (wt(tpe(ls.head)) != wt(tpe(x))) error = true
+ if (wt(ls.head.tpe) != wt(x.tpe)) error = true
}
if (error) errors.append(new OpNotAllSameType(info,e.op.serialize))
}
def all_ground (ls:Seq[Expression]) : Unit = {
var error = false
for (x <- ls ) {
- if (!(tpe(x).typeof[UIntType] || tpe(x).typeof[SIntType])) error = true
+ x.tpe match {
+ case _: UIntType | _: SIntType =>
+ case _ => error = true
+ }
}
if (error) errors.append(new OpNotGround(info,e.op.serialize))
}
def all_uint (ls:Seq[Expression]) : Unit = {
var error = false
for (x <- ls ) {
- if (!(tpe(x).typeof[UIntType])) error = true
+ x.tpe match {
+ case _: UIntType =>
+ case _ => error = true
+ }
}
if (error) errors.append(new OpNotAllUInt(info,e.op.serialize))
}
def is_uint (x:Expression) : Unit = {
var error = false
- if (!(tpe(x).typeof[UIntType])) error = true
+ x.tpe match {
+ case _: UIntType =>
+ case _ => error = true
+ }
if (error) errors.append(new OpNotUInt(info,e.op.serialize,x.serialize))
}
@@ -417,7 +426,7 @@ object CheckTypes extends Pass with LazyLogging {
(e map (check_types_e(info))) match {
case (e:WRef) => e
case (e:WSubField) => {
- (tpe(e.exp)) match {
+ (e.exp.tpe) match {
case (t:BundleType) => {
val ft = t.fields.find(p => p.name == e.name)
if (ft == None) errors.append(new SubfieldNotInBundle(info,e.name))
@@ -426,7 +435,7 @@ object CheckTypes extends Pass with LazyLogging {
}
}
case (e:WSubIndex) => {
- (tpe(e.exp)) match {
+ (e.exp.tpe) match {
case (t:VectorType) => {
if (e.value >= t.size) errors.append(new IndexTooLarge(info,e.value))
}
@@ -434,24 +443,30 @@ object CheckTypes extends Pass with LazyLogging {
}
}
case (e:WSubAccess) => {
- (tpe(e.exp)) match {
+ (e.exp.tpe) match {
case (t:VectorType) => false
case (t) => errors.append(new IndexOnNonVector(info))
}
- (tpe(e.index)) match {
+ (e.index.tpe) match {
case (t:UIntType) => false
case (t) => errors.append(new AccessIndexNotUInt(info))
}
}
case (e:DoPrim) => check_types_primop(e,errors,info)
case (e:Mux) => {
- if (wt(tpe(e.tval)) != wt(tpe(e.fval))) errors.append(new MuxSameType(info))
- if (!passive(tpe(e))) errors.append(new MuxPassiveTypes(info))
- if (!(tpe(e.cond).typeof[UIntType])) errors.append(new MuxCondUInt(info))
+ if (wt(e.tval.tpe) != wt(e.fval.tpe)) errors.append(new MuxSameType(info))
+ if (!passive(e.tpe)) errors.append(new MuxPassiveTypes(info))
+ e.cond.tpe match {
+ case _: UIntType =>
+ case _ => errors.append(new MuxCondUInt(info))
+ }
}
case (e:ValidIf) => {
- if (!passive(tpe(e))) errors.append(new ValidIfPassiveTypes(info))
- if (!(tpe(e.cond).typeof[UIntType])) errors.append(new ValidIfCondUInt(info))
+ if (!passive(e.tpe)) errors.append(new ValidIfPassiveTypes(info))
+ e.cond.tpe match {
+ case _: UIntType =>
+ case _ => errors.append(new ValidIfCondUInt(info))
+ }
}
case (_:UIntLiteral | _:SIntLiteral) => false
}
@@ -484,22 +499,22 @@ object CheckTypes extends Pass with LazyLogging {
def check_types_s (s:Statement) : Statement = {
s map (check_types_e(get_info(s))) match {
- case (s:Connect) => if (wt(tpe(s.loc)) != wt(tpe(s.expr))) errors.append(new InvalidConnect(s.info, s.loc.serialize, s.expr.serialize))
- case (s:DefRegister) => if (wt(s.tpe) != wt(tpe(s.init))) errors.append(new InvalidRegInit(s.info))
- case (s:PartialConnect) => if (!bulk_equals(tpe(s.loc),tpe(s.expr),Default,Default) ) errors.append(new InvalidConnect(s.info, s.loc.serialize, s.expr.serialize))
+ case (s:Connect) => if (wt(s.loc.tpe) != wt(s.expr.tpe)) errors.append(new InvalidConnect(s.info, s.loc.serialize, s.expr.serialize))
+ case (s:DefRegister) => if (wt(s.tpe) != wt(s.init.tpe)) errors.append(new InvalidRegInit(s.info))
+ case (s:PartialConnect) => if (!bulk_equals(s.loc.tpe,s.expr.tpe,Default,Default) ) errors.append(new InvalidConnect(s.info, s.loc.serialize, s.expr.serialize))
case (s:Stop) => {
- if (wt(tpe(s.clk)) != wt(ClockType) ) errors.append(new ReqClk(s.info))
- if (wt(tpe(s.en)) != wt(ut()) ) errors.append(new EnNotUInt(s.info))
+ if (wt(s.clk.tpe) != wt(ClockType) ) errors.append(new ReqClk(s.info))
+ if (wt(s.en.tpe) != wt(ut()) ) errors.append(new EnNotUInt(s.info))
}
case (s:Print)=> {
for (x <- s.args ) {
- if (wt(tpe(x)) != wt(ut()) && wt(tpe(x)) != wt(st()) ) errors.append(new PrintfArgNotGround(s.info))
+ if (wt(x.tpe) != wt(ut()) && wt(x.tpe) != wt(st()) ) errors.append(new PrintfArgNotGround(s.info))
}
- if (wt(tpe(s.clk)) != wt(ClockType) ) errors.append(new ReqClk(s.info))
- if (wt(tpe(s.en)) != wt(ut()) ) errors.append(new EnNotUInt(s.info))
+ if (wt(s.clk.tpe) != wt(ClockType) ) errors.append(new ReqClk(s.info))
+ if (wt(s.en.tpe) != wt(ut()) ) errors.append(new EnNotUInt(s.info))
}
- case (s:Conditionally) => if (wt(tpe(s.pred)) != wt(ut()) ) errors.append(new PredNotUInt(s.info))
- case (s:DefNode) => if (!passive(tpe(s.value)) ) errors.append(new NodePassiveType(s.info))
+ case (s:Conditionally) => if (wt(s.pred.tpe) != wt(ut()) ) errors.append(new PredNotUInt(s.info))
+ case (s:DefNode) => if (!passive(s.value.tpe) ) errors.append(new NodePassiveType(s.info))
case (s) => false
}
s map (check_types_s)
@@ -571,7 +586,7 @@ object CheckGenders extends Pass {
fQ
}
- val has_flipQ = flipQ(tpe(e))
+ val has_flipQ = flipQ(e.tpe)
//println(e)
//println(gender)
//println(desired)
@@ -597,7 +612,7 @@ object CheckGenders extends Pass {
(e) match {
case (e:WRef) => genders(e.name)
case (e:WSubField) =>
- val f = tpe(e.exp).as[BundleType].get.fields.find(f => f.name == e.name).get
+ val f = e.exp.tpe.asInstanceOf[BundleType].fields.find(f => f.name == e.name).get
times(get_gender(e.exp,genders),f.flip)
case (e:WSubIndex) => get_gender(e.exp,genders)
case (e:WSubAccess) => get_gender(e.exp,genders)
@@ -735,7 +750,7 @@ object CheckWidths extends Pass {
}
def check_width_s (s:Statement) : Statement = {
s map (check_width_s) map (check_width_e(get_info(s)))
- def tm (t:Type) : Type = mapr(check_width_w(info(s)) _,t)
+ def tm (t:Type) : Type = mapr(check_width_w(get_info(s)) _,t)
s map (tm)
}
diff --git a/src/main/scala/firrtl/passes/ConstProp.scala b/src/main/scala/firrtl/passes/ConstProp.scala
index 57782a3c..2e8b53f3 100644
--- a/src/main/scala/firrtl/passes/ConstProp.scala
+++ b/src/main/scala/firrtl/passes/ConstProp.scala
@@ -129,7 +129,7 @@ object ConstProp extends Pass {
private def foldComparison(e: DoPrim) = {
def foldIfZeroedArg(x: Expression): Expression = {
- def isUInt(e: Expression): Boolean = tpe(e) match {
+ def isUInt(e: Expression): Boolean = e.tpe match {
case UIntType(_) => true
case _ => false
}
@@ -163,7 +163,7 @@ object ConstProp extends Pass {
def range(e: Expression): Range = e match {
case UIntLiteral(value, _) => Range(value, value)
case SIntLiteral(value, _) => Range(value, value)
- case _ => tpe(e) match {
+ case _ => e.tpe match {
case SIntType(IntWidth(width)) => Range(
min = BigInt(0) - BigInt(2).pow(width.toInt - 1),
max = BigInt(2).pow(width.toInt - 1) - BigInt(1)
@@ -226,7 +226,7 @@ object ConstProp extends Pass {
case Pad => e.args(0) match {
case UIntLiteral(v, _) => UIntLiteral(v, IntWidth(e.consts(0)))
case SIntLiteral(v, _) => SIntLiteral(v, IntWidth(e.consts(0)))
- case _ if long_BANG(tpe(e.args(0))) == e.consts(0) => e.args(0)
+ case _ if long_BANG(e.args(0).tpe) == e.consts(0) => e.args(0)
case _ => e
}
case Bits => e.args(0) match {
@@ -234,9 +234,9 @@ object ConstProp extends Pass {
val hi = e.consts(0).toInt
val lo = e.consts(1).toInt
require(hi >= lo)
- UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), widthBANG(tpe(e)))
+ UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), widthBANG(e.tpe))
}
- case x if long_BANG(tpe(e)) == long_BANG(tpe(x)) => tpe(x) match {
+ case x if long_BANG(e.tpe) == long_BANG(x.tpe) => x.tpe match {
case t: UIntType => x
case _ => asUInt(x, e.tpe)
}
diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala
index 921693c7..3d26298a 100644
--- a/src/main/scala/firrtl/passes/ExpandWhens.scala
+++ b/src/main/scala/firrtl/passes/ExpandWhens.scala
@@ -131,8 +131,8 @@ object ExpandWhens extends Pass {
val falseValue = altNetlist.getOrElse(lvalue, defaultValue)
(trueValue, falseValue) match {
case (WInvalid(), WInvalid()) => WInvalid()
- case (WInvalid(), fv) => ValidIf(NOT(s.pred), fv, tpe(fv))
- case (tv, WInvalid()) => ValidIf(s.pred, tv, tpe(tv))
+ case (WInvalid(), fv) => ValidIf(NOT(s.pred), fv, fv.tpe)
+ case (tv, WInvalid()) => ValidIf(s.pred, tv, tv.tpe)
case (tv, fv) => Mux(s.pred, tv, fv, mux_type_and_widths(tv, fv))
}
case None =>
diff --git a/src/main/scala/firrtl/passes/Inline.scala b/src/main/scala/firrtl/passes/Inline.scala
index 7793c85c..a8fda1bf 100644
--- a/src/main/scala/firrtl/passes/Inline.scala
+++ b/src/main/scala/firrtl/passes/Inline.scala
@@ -5,7 +5,6 @@ package passes
import scala.collection.mutable
import firrtl.Mappers.{ExpMap,StmtMap}
-import firrtl.Utils.WithAs
import firrtl.ir._
import firrtl.passes.{PassException,PassExceptions}
import Annotations.{Loose, Unstable, Annotation, TransID, Named, ModuleName, ComponentName, CircuitName, AnnotationMap}
diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala
index 585598a8..a4c584ed 100644
--- a/src/main/scala/firrtl/passes/LowerTypes.scala
+++ b/src/main/scala/firrtl/passes/LowerTypes.scala
@@ -105,15 +105,15 @@ object LowerTypes extends Pass {
require(tail.isEmpty) // there can't be a tail for these
val memType = memDataTypeMap(mem.name)
- if (memType.isGround) {
- Seq(e)
- } else {
- val exps = create_exps(mem.name, memType)
- exps map { e =>
- val loMemName = loweredName(e)
- val loMem = WRef(loMemName, UnknownType, kind(mem), UNKNOWNGENDER)
- mergeRef(loMem, mergeRef(port, field))
- }
+ memType match {
+ case _: GroundType => Seq(e)
+ case _ =>
+ val exps = create_exps(mem.name, memType)
+ exps map { e =>
+ val loMemName = loweredName(e)
+ val loMem = WRef(loMemName, UnknownType, kind(mem), UNKNOWNGENDER)
+ mergeRef(loMem, mergeRef(port, field))
+ }
}
// Fields that need not be replicated for each
// eg. mem.reader.data[0].a
@@ -138,7 +138,7 @@ object LowerTypes extends Pass {
case k: InstanceKind =>
val (root, tail) = splitRef(e)
val name = loweredName(tail)
- WSubField(root, name, tpe(e), gender(e))
+ WSubField(root, name, e.tpe, gender(e))
case k: MemKind =>
val exps = lowerTypesMemExp(e)
if (exps.length > 1)
@@ -146,7 +146,7 @@ object LowerTypes extends Pass {
" to be expanded!")
exps(0)
case k =>
- WRef(loweredName(e), tpe(e), kind(e), gender(e))
+ WRef(loweredName(e), e.tpe, kind(e), gender(e))
}
case e: Mux => e map (lowerTypesExp)
case e: ValidIf => e map (lowerTypesExp)
@@ -158,26 +158,26 @@ object LowerTypes extends Pass {
s map lowerTypesStmt match {
case s: DefWire =>
sinfo = s.info
- if (s.tpe.isGround) {
- s
- } else {
- val exps = create_exps(s.name, s.tpe)
- val stmts = exps map (e => DefWire(s.info, loweredName(e), tpe(e)))
- Block(stmts)
+ s.tpe match {
+ case _: GroundType => s
+ case _ =>
+ val exps = create_exps(s.name, s.tpe)
+ val stmts = exps map (e => DefWire(s.info, loweredName(e), e.tpe))
+ Block(stmts)
}
case s: DefRegister =>
sinfo = s.info
- if (s.tpe.isGround) {
- s map lowerTypesExp
- } else {
- val es = create_exps(s.name, s.tpe)
- val inits = create_exps(s.init) map (lowerTypesExp)
- val clock = lowerTypesExp(s.clock)
- val reset = lowerTypesExp(s.reset)
- val stmts = es zip inits map { case (e, i) =>
- DefRegister(s.info, loweredName(e), tpe(e), clock, reset, i)
- }
- Block(stmts)
+ s.tpe match {
+ case _: GroundType => s map lowerTypesExp
+ case _ =>
+ val es = create_exps(s.name, s.tpe)
+ val inits = create_exps(s.init) map (lowerTypesExp)
+ val clock = lowerTypesExp(s.clock)
+ val reset = lowerTypesExp(s.reset)
+ val stmts = es zip inits map { case (e, i) =>
+ DefRegister(s.info, loweredName(e), e.tpe, clock, reset, i)
+ }
+ Block(stmts)
}
// Could instead just save the type of each Module as it gets processed
case s: WDefInstance =>
@@ -188,7 +188,7 @@ object LowerTypes extends Pass {
val exps = create_exps(WRef(f.name, f.tpe, ExpKind(), times(f.flip, MALE)))
exps map ( e =>
// Flip because inst genders are reversed from Module type
- Field(loweredName(e), toFlip(gender(e)).flip, tpe(e))
+ Field(loweredName(e), swap(to_flip(gender(e))), e.tpe)
)
}
WDefInstance(s.info, s.name, s.module, BundleType(fieldsx))
@@ -197,16 +197,16 @@ object LowerTypes extends Pass {
case s: DefMemory =>
sinfo = s.info
memDataTypeMap += (s.name -> s.dataType)
- if (s.dataType.isGround) {
- s
- } else {
- val exps = create_exps(s.name, s.dataType)
- val stmts = exps map { e =>
- DefMemory(s.info, loweredName(e), tpe(e), s.depth,
- s.writeLatency, s.readLatency, s.readers, s.writers,
- s.readwriters)
- }
- Block(stmts)
+ s.dataType match {
+ case _: GroundType => s
+ case _ =>
+ val exps = create_exps(s.name, s.dataType)
+ val stmts = exps map { e =>
+ DefMemory(s.info, loweredName(e), e.tpe, s.depth,
+ s.writeLatency, s.readLatency, s.readers, s.writers,
+ s.readwriters)
+ }
+ Block(stmts)
}
// wire foo : { a , b }
// node x = foo
@@ -217,7 +217,7 @@ object LowerTypes extends Pass {
// node y = x_a
case s: DefNode =>
sinfo = s.info
- val names = create_exps(s.name, tpe(s.value)) map (lowerTypesExp)
+ val names = create_exps(s.name, s.value.tpe) map (lowerTypesExp)
val exps = create_exps(s.value) map (lowerTypesExp)
val stmts = names zip exps map { case (n, e) =>
DefNode(s.info, loweredName(n), e)
@@ -249,7 +249,7 @@ object LowerTypes extends Pass {
// Lower Ports
val portsx = m.ports flatMap { p =>
val exps = create_exps(WRef(p.name, p.tpe, PortKind(), to_gender(p.direction)))
- exps map ( e => Port(p.info, loweredName(e), to_dir(gender(e)), tpe(e)) )
+ exps map ( e => Port(p.info, loweredName(e), to_dir(gender(e)), e.tpe) )
}
m match {
case m: ExtModule => m.copy(ports = portsx)
diff --git a/src/main/scala/firrtl/passes/PadWidths.scala b/src/main/scala/firrtl/passes/PadWidths.scala
index 0cabc293..f2117761 100644
--- a/src/main/scala/firrtl/passes/PadWidths.scala
+++ b/src/main/scala/firrtl/passes/PadWidths.scala
@@ -2,7 +2,7 @@ package firrtl
package passes
import firrtl.Mappers.{ExpMap, StmtMap}
-import firrtl.Utils.{tpe, long_BANG}
+import firrtl.Utils.long_BANG
import firrtl.PrimOps._
import firrtl.ir._
@@ -10,10 +10,10 @@ import firrtl.ir._
object PadWidths extends Pass {
def name = "Pad Widths"
private def width(t: Type): Int = long_BANG(t).toInt
- private def width(e: Expression): Int = width(tpe(e))
+ private def width(e: Expression): Int = width(e.tpe)
// Returns an expression with the correct integer width
private def fixup(i: Int)(e: Expression) = {
- def tx = tpe(e) match {
+ def tx = e.tpe match {
case t: UIntType => UIntType(IntWidth(i))
case t: SIntType => SIntType(IntWidth(i))
// default case should never be reached
diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala
index 7b4f9aa2..6b6dc811 100644
--- a/src/main/scala/firrtl/passes/Passes.scala
+++ b/src/main/scala/firrtl/passes/Passes.scala
@@ -103,7 +103,7 @@ object ResolveKinds extends Pass {
def resolve (body:Statement) = {
def resolve_expr (e:Expression):Expression = {
e match {
- case e:WRef => WRef(e.name,tpe(e),kinds(e.name),e.gender)
+ case e:WRef => WRef(e.name,e.tpe,kinds(e.name),e.gender)
case e => e map (resolve_expr)
}
}
@@ -170,11 +170,11 @@ object InferTypes extends Pass {
val types = LinkedHashMap[String,Type]()
def infer_types_e (e:Expression) : Expression = {
e map (infer_types_e) match {
- case e:ValidIf => ValidIf(e.cond,e.value,tpe(e.value))
+ case e:ValidIf => ValidIf(e.cond,e.value,e.value.tpe)
case e:WRef => WRef(e.name, types(e.name),e.kind,e.gender)
- case e:WSubField => WSubField(e.exp,e.name,field_type(tpe(e.exp),e.name),e.gender)
- case e:WSubIndex => WSubIndex(e.exp,e.value,sub_type(tpe(e.exp)),e.gender)
- case e:WSubAccess => WSubAccess(e.exp,e.index,sub_type(tpe(e.exp)),e.gender)
+ case e:WSubField => WSubField(e.exp,e.name,field_type(e.exp.tpe,e.name),e.gender)
+ case e:WSubIndex => WSubIndex(e.exp,e.value,sub_type(e.exp.tpe),e.gender)
+ case e:WSubAccess => WSubAccess(e.exp,e.index,sub_type(e.exp.tpe),e.gender)
case e:DoPrim => set_primop_type(e)
case e:Mux => Mux(e.cond,e.tval,e.fval,mux_type_and_widths(e.tval,e.fval))
case e:UIntLiteral => e
@@ -246,7 +246,7 @@ object ResolveGenders extends Pass {
case e:WRef => WRef(e.name,e.tpe,e.kind,g)
case e:WSubField => {
val expx =
- field_flip(tpe(e.exp),e.name) match {
+ field_flip(e.exp.tpe,e.name) match {
case Default => resolve_e(g)(e.exp)
case Flip => resolve_e(swap(g))(e.exp)
}
@@ -474,7 +474,7 @@ object InferWidths extends Pass {
case (t:SIntType) => t.width
case ClockType => IntWidth(1)
case (t) => error("No width!"); IntWidth(-1) } }
- def width_BANG (e:Expression) : Width = width_BANG(tpe(e))
+ def width_BANG (e:Expression) : Width = width_BANG(e.tpe)
def reduce_var_widths(c: Circuit, h: LinkedHashMap[String,Width]): Circuit = {
def evaluate(w: Width): Width = {
@@ -549,40 +549,40 @@ object InferWidths extends Pass {
def get_constraints_e (e:Expression) : Expression = {
(e map (get_constraints_e)) match {
case (e:Mux) => {
- constrain(width_BANG(e.cond),ONE)
- constrain(ONE,width_BANG(e.cond))
+ constrain(width_BANG(e.cond),IntWidth(1))
+ constrain(IntWidth(1),width_BANG(e.cond))
e }
case (e) => e }}
def get_constraints (s:Statement) : Statement = {
(s map (get_constraints_e)) match {
case (s:Connect) => {
- val n = get_size(tpe(s.loc))
+ val n = get_size(s.loc.tpe)
val ce_loc = create_exps(s.loc)
val ce_exp = create_exps(s.expr)
for (i <- 0 until n) {
val locx = ce_loc(i)
val expx = ce_exp(i)
- get_flip(tpe(s.loc),i,Default) match {
+ get_flip(s.loc.tpe,i,Default) match {
case Default => constrain(width_BANG(locx),width_BANG(expx))
case Flip => constrain(width_BANG(expx),width_BANG(locx)) }}
s }
case (s:PartialConnect) => {
- val ls = get_valid_points(tpe(s.loc),tpe(s.expr),Default,Default)
+ val ls = get_valid_points(s.loc.tpe,s.expr.tpe,Default,Default)
for (x <- ls) {
val locx = create_exps(s.loc)(x._1)
val expx = create_exps(s.expr)(x._2)
- get_flip(tpe(s.loc),x._1,Default) match {
+ get_flip(s.loc.tpe,x._1,Default) match {
case Default => constrain(width_BANG(locx),width_BANG(expx))
case Flip => constrain(width_BANG(expx),width_BANG(locx)) }}
s }
case (s:DefRegister) => {
- constrain(width_BANG(s.reset),ONE)
- constrain(ONE,width_BANG(s.reset))
- get_constraints_t(s.tpe,tpe(s.init),Default)
+ constrain(width_BANG(s.reset),IntWidth(1))
+ constrain(IntWidth(1),width_BANG(s.reset))
+ get_constraints_t(s.tpe,s.init.tpe,Default)
s }
case (s:Conditionally) => {
- v += WGeq(width_BANG(s.pred),ONE)
- v += WGeq(ONE,width_BANG(s.pred))
+ v += WGeq(width_BANG(s.pred),IntWidth(1))
+ v += WGeq(IntWidth(1),width_BANG(s.pred))
s map (get_constraints) }
case (s) => s map (get_constraints) }}
@@ -661,7 +661,7 @@ object ExpandConnects extends Pass {
e map (set_gender) match {
case (e:WRef) => WRef(e.name,e.tpe,e.kind,genders(e.name))
case (e:WSubField) => {
- val f = get_field(tpe(e.exp),e.name)
+ val f = get_field(e.exp.tpe,e.name)
val genderx = times(gender(e.exp),f.flip)
WSubField(e.exp,e.name,e.tpe,genderx)
}
@@ -677,7 +677,7 @@ object ExpandConnects extends Pass {
case (s:DefMemory) => { genders(s.name) = MALE; s }
case (s:DefNode) => { genders(s.name) = MALE; s }
case (s:IsInvalid) => {
- val n = get_size(tpe(s.expr))
+ val n = get_size(s.expr.tpe)
val invalids = ArrayBuffer[Statement]()
val exps = create_exps(s.expr)
for (i <- 0 until n) {
@@ -696,14 +696,14 @@ object ExpandConnects extends Pass {
} else Block(invalids)
}
case (s:Connect) => {
- val n = get_size(tpe(s.loc))
+ val n = get_size(s.loc.tpe)
val connects = ArrayBuffer[Statement]()
val locs = create_exps(s.loc)
val exps = create_exps(s.expr)
for (i <- 0 until n) {
val locx = locs(i)
val expx = exps(i)
- val sx = get_flip(tpe(s.loc),i,Default) match {
+ val sx = get_flip(s.loc.tpe,i,Default) match {
case Default => Connect(s.info,locx,expx)
case Flip => Connect(s.info,expx,locx)
}
@@ -712,14 +712,14 @@ object ExpandConnects extends Pass {
Block(connects)
}
case (s:PartialConnect) => {
- val ls = get_valid_points(tpe(s.loc),tpe(s.expr),Default,Default)
+ val ls = get_valid_points(s.loc.tpe,s.expr.tpe,Default,Default)
val connects = ArrayBuffer[Statement]()
val locs = create_exps(s.loc)
val exps = create_exps(s.expr)
ls.foreach { x => {
val locx = locs(x._1)
val expx = exps(x._2)
- val sx = get_flip(tpe(s.loc),x._1,Default) match {
+ val sx = get_flip(s.loc.tpe,x._1,Default) match {
case Default => Connect(s.info,locx,expx)
case Flip => Connect(s.info,expx,locx)
}
@@ -755,7 +755,7 @@ object Legalize extends Pass {
def legalizeShiftRight (e: DoPrim): Expression = e.op match {
case Shr => {
val amount = e.consts(0).toInt
- val width = long_BANG(tpe(e.args(0)))
+ val width = long_BANG(e.args(0).tpe)
lazy val msb = width - 1
if (amount >= width) {
e.tpe match {
@@ -771,9 +771,9 @@ object Legalize extends Pass {
case _ => e
}
def legalizeConnect(c: Connect): Statement = {
- val t = tpe(c.loc)
+ val t = c.loc.tpe
val w = long_BANG(t)
- if (w >= long_BANG(tpe(c.expr))) c
+ if (w >= long_BANG(c.expr.tpe)) c
else {
val newType = t match {
case _: UIntType => UIntType(IntWidth(w))
@@ -811,8 +811,8 @@ object VerilogWrap extends Pass {
if (e.op == Tail) {
(a0()) match {
case (e0:DoPrim) => {
- if (e0.op == Add) DoPrim(Addw,e0.args,Seq(),tpe(e))
- else if (e0.op == Sub) DoPrim(Subw,e0.args,Seq(),tpe(e))
+ if (e0.op == Add) DoPrim(Addw,e0.args,Seq(),e.tpe)
+ else if (e0.op == Sub) DoPrim(Subw,e0.args,Seq(),e.tpe)
else e
}
case (e0) => e
@@ -913,12 +913,12 @@ object CInferTypes extends Pass {
def infer_types_e (e:Expression) : Expression = {
e map infer_types_e match {
case (e:Reference) => Reference(e.name, types.getOrElse(e.name,UnknownType))
- case (e:SubField) => SubField(e.expr,e.name,field_type(tpe(e.expr),e.name))
- case (e:SubIndex) => SubIndex(e.expr,e.value,sub_type(tpe(e.expr)))
- case (e:SubAccess) => SubAccess(e.expr,e.index,sub_type(tpe(e.expr)))
+ case (e:SubField) => SubField(e.expr,e.name,field_type(e.expr.tpe,e.name))
+ case (e:SubIndex) => SubIndex(e.expr,e.value,sub_type(e.expr.tpe))
+ case (e:SubAccess) => SubAccess(e.expr,e.index,sub_type(e.expr.tpe))
case (e:DoPrim) => set_primop_type(e)
case (e:Mux) => Mux(e.cond,e.tval,e.fval,mux_type(e.tval,e.tval))
- case (e:ValidIf) => ValidIf(e.cond,e.value,tpe(e.value))
+ case (e:ValidIf) => ValidIf(e.cond,e.value,e.value.tpe)
case (_:UIntLiteral | _:SIntLiteral) => e
}
}
@@ -1067,8 +1067,8 @@ object RemoveCHIRRTL extends Pass {
val e2s = create_exps(e.fval)
(e1s,e2s).zipped map ((e1,e2) => Mux(e.cond,e1,e2,mux_type(e1,e2)))
case (e:ValidIf) =>
- create_exps(e.value) map (e1 => ValidIf(e.cond,e1,tpe(e1)))
- case (e) => (tpe(e)) match {
+ create_exps(e.value) map (e1 => ValidIf(e.cond,e1,e1.tpe))
+ case (e) => (e.tpe) match {
case (_:GroundType) => Seq(e)
case (t:BundleType) => (t.fields foldLeft Seq[Expression]())((exps, f) =>
exps ++ create_exps(SubField(e,f.name,f.tpe)))
@@ -1276,7 +1276,7 @@ object RemoveCHIRRTL extends Pass {
case Some(en) => stmts += Connect(s.info,en,one)
}
if (has_write_mport) {
- val ls = get_valid_points(tpe(s.loc),tpe(s.expr),Default,Default)
+ val ls = get_valid_points(s.loc.tpe,s.expr.tpe,Default,Default)
val locs = create_exps(get_mask(s.loc))
for (x <- ls ) {
val locx = locs(x._1)
diff --git a/src/main/scala/firrtl/passes/RemoveAccesses.scala b/src/main/scala/firrtl/passes/RemoveAccesses.scala
index a3ce49f7..880d6b1c 100644
--- a/src/main/scala/firrtl/passes/RemoveAccesses.scala
+++ b/src/main/scala/firrtl/passes/RemoveAccesses.scala
@@ -76,7 +76,7 @@ object RemoveAccesses extends Pass {
def onStmt(s: Statement): Statement = {
def create_temp(e: Expression): (Statement, Expression) = {
val n = namespace.newTemp
- (DefWire(info(s), n, e.tpe), WRef(n, e.tpe, kind(e), gender(e)))
+ (DefWire(get_info(s), n, e.tpe), WRef(n, e.tpe, kind(e), gender(e)))
}
/** Replaces a subaccess in a given male expression
@@ -94,9 +94,9 @@ object RemoveAccesses extends Pass {
stmts += wire
rs.zipWithIndex foreach {
case (x, i) if i < temps.size =>
- stmts += Connect(info(s),getTemp(i),x.base)
+ stmts += Connect(get_info(s),getTemp(i),x.base)
case (x, i) =>
- stmts += Conditionally(info(s),x.guard,Connect(info(s),getTemp(i),x.base),EmptyStmt)
+ stmts += Conditionally(get_info(s),x.guard,Connect(get_info(s),getTemp(i),x.base),EmptyStmt)
}
temp
}
diff --git a/src/main/scala/firrtl/passes/SplitExpressions.scala b/src/main/scala/firrtl/passes/SplitExpressions.scala
index 1c9674e1..3b6021ed 100644
--- a/src/main/scala/firrtl/passes/SplitExpressions.scala
+++ b/src/main/scala/firrtl/passes/SplitExpressions.scala
@@ -2,7 +2,7 @@ package firrtl
package passes
import firrtl.Mappers.{ExpMap, StmtMap}
-import firrtl.Utils.{tpe, kind, gender, info}
+import firrtl.Utils.{kind, gender, get_info}
import firrtl.ir._
import scala.collection.mutable
@@ -20,18 +20,18 @@ object SplitExpressions extends Pass {
def split(e: Expression): Expression = e match {
case e: DoPrim => {
val name = namespace.newTemp
- v += DefNode(info(s), name, e)
- WRef(name, tpe(e), kind(e), gender(e))
+ v += DefNode(get_info(s), name, e)
+ WRef(name, e.tpe, kind(e), gender(e))
}
case e: Mux => {
val name = namespace.newTemp
- v += DefNode(info(s), name, e)
- WRef(name, tpe(e), kind(e), gender(e))
+ v += DefNode(get_info(s), name, e)
+ WRef(name, e.tpe, kind(e), gender(e))
}
case e: ValidIf => {
val name = namespace.newTemp
- v += DefNode(info(s), name, e)
- WRef(name, tpe(e), kind(e), gender(e))
+ v += DefNode(get_info(s), name, e)
+ WRef(name, e.tpe, kind(e), gender(e))
}
case e => e
}
diff --git a/src/main/scala/firrtl/passes/Uniquify.scala b/src/main/scala/firrtl/passes/Uniquify.scala
index b1a20fdd..d034719a 100644
--- a/src/main/scala/firrtl/passes/Uniquify.scala
+++ b/src/main/scala/firrtl/passes/Uniquify.scala
@@ -109,8 +109,9 @@ object Uniquify extends Pass {
val newName = findValidPrefix(f.name, Seq(""), namespace)
namespace += newName
Field(newName, f.flip, f.tpe)
- } map { f =>
- if (f.tpe.isAggregate) {
+ } map { f => f.tpe match {
+ case _: GroundType => f
+ case _ =>
val tpe = recUniquifyNames(f.tpe, collection.mutable.HashSet())
val elts = enumerateNames(tpe)
// Need leading _ for findValidPrefix, it doesn't add _ for checks
@@ -123,8 +124,6 @@ object Uniquify extends Pass {
}
namespace ++= (elts map (e => LowerTypes.loweredName(prefix +: e)))
Field(prefix, f.flip, tpe)
- } else {
- f
}
}
BundleType(newFields)
@@ -349,7 +348,9 @@ object Uniquify extends Pass {
def uniquifyPorts(m: DefModule): DefModule = {
def uniquifyPorts(ports: Seq[Port]): Seq[Port] = {
- val portsType = BundleType(ports map (_.toField))
+ val portsType = BundleType(ports map {
+ case Port(_, name, dir, tpe) => Field(name, to_flip(dir), tpe)
+ })
val uniquePortsType = uniquifyNames(portsType, collection.mutable.HashSet())
val localMap = createNameMapping(portsType, uniquePortsType)
portNameMap += (m.name -> localMap)