aboutsummaryrefslogtreecommitdiff
path: root/src/main
diff options
context:
space:
mode:
Diffstat (limited to 'src/main')
-rw-r--r--src/main/scala/firrtl/passes/InferTypes.scala15
-rw-r--r--src/main/scala/firrtl/passes/ResolveFlows.scala5
2 files changed, 11 insertions, 9 deletions
diff --git a/src/main/scala/firrtl/passes/InferTypes.scala b/src/main/scala/firrtl/passes/InferTypes.scala
index 5524e0ea..6cc9f2b9 100644
--- a/src/main/scala/firrtl/passes/InferTypes.scala
+++ b/src/main/scala/firrtl/passes/InferTypes.scala
@@ -20,7 +20,6 @@ object InferTypes extends Pass {
def run(c: Circuit): Circuit = {
val namespace = Namespace()
- val mtypes = (c.modules map (m => m.name -> module_type(m))).toMap
def remove_unknowns_b(b: Bound): Bound = b match {
case UnknownBound => VarBound(namespace.newName("b"))
@@ -40,6 +39,11 @@ object InferTypes extends Pass {
}
}
+ // we first need to remove the unknown widths and bounds from all ports,
+ // as their type will determine the module types
+ val portsKnown = c.modules.map(_.map{ p: Port => p.copy(tpe = remove_unknowns(p.tpe)) })
+ val mtypes = portsKnown.map(m => m.name -> module_type(m)).toMap
+
def infer_types_e(types: TypeLookup)(e: Expression): Expression =
e map infer_types_e(types) match {
case e: WRef => e copy (tpe = types(e.name))
@@ -71,9 +75,10 @@ object InferTypes extends Pass {
types(sx.name) = t
sx copy (tpe = t) map infer_types_e(types)
case sx: DefMemory =>
- val t = remove_unknowns(MemPortUtils.memType(sx))
- types(sx.name) = t
- sx copy (dataType = remove_unknowns(sx.dataType))
+ // we need to remove the unknowns from the data type so that all ports get the same VarWidth
+ val knownDataType = sx.copy(dataType = remove_unknowns(sx.dataType))
+ types(sx.name) = MemPortUtils.memType(knownDataType)
+ knownDataType
case sx => sx map infer_types_s(types) map infer_types_e(types)
}
@@ -88,7 +93,7 @@ object InferTypes extends Pass {
m map infer_types_p(types) map infer_types_s(types)
}
- c copy (modules = c.modules map infer_types)
+ c.copy(modules = portsKnown.map(infer_types))
}
}
diff --git a/src/main/scala/firrtl/passes/ResolveFlows.scala b/src/main/scala/firrtl/passes/ResolveFlows.scala
index c3455327..85a0a26f 100644
--- a/src/main/scala/firrtl/passes/ResolveFlows.scala
+++ b/src/main/scala/firrtl/passes/ResolveFlows.scala
@@ -9,10 +9,7 @@ import firrtl.options.Dependency
object ResolveFlows extends Pass {
- override def prerequisites =
- Seq( Dependency(passes.ResolveKinds),
- Dependency(passes.InferTypes),
- Dependency(passes.Uniquify) ) ++ firrtl.stage.Forms.WorkingIR
+ override def prerequisites = Seq(Dependency(passes.InferTypes)) ++ firrtl.stage.Forms.WorkingIR
override def invalidates(a: Transform) = false