diff options
Diffstat (limited to 'src/main')
| -rw-r--r-- | src/main/scala/firrtl/passes/InferTypes.scala | 15 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/ResolveFlows.scala | 5 |
2 files changed, 11 insertions, 9 deletions
diff --git a/src/main/scala/firrtl/passes/InferTypes.scala b/src/main/scala/firrtl/passes/InferTypes.scala index 5524e0ea..6cc9f2b9 100644 --- a/src/main/scala/firrtl/passes/InferTypes.scala +++ b/src/main/scala/firrtl/passes/InferTypes.scala @@ -20,7 +20,6 @@ object InferTypes extends Pass { def run(c: Circuit): Circuit = { val namespace = Namespace() - val mtypes = (c.modules map (m => m.name -> module_type(m))).toMap def remove_unknowns_b(b: Bound): Bound = b match { case UnknownBound => VarBound(namespace.newName("b")) @@ -40,6 +39,11 @@ object InferTypes extends Pass { } } + // we first need to remove the unknown widths and bounds from all ports, + // as their type will determine the module types + val portsKnown = c.modules.map(_.map{ p: Port => p.copy(tpe = remove_unknowns(p.tpe)) }) + val mtypes = portsKnown.map(m => m.name -> module_type(m)).toMap + def infer_types_e(types: TypeLookup)(e: Expression): Expression = e map infer_types_e(types) match { case e: WRef => e copy (tpe = types(e.name)) @@ -71,9 +75,10 @@ object InferTypes extends Pass { types(sx.name) = t sx copy (tpe = t) map infer_types_e(types) case sx: DefMemory => - val t = remove_unknowns(MemPortUtils.memType(sx)) - types(sx.name) = t - sx copy (dataType = remove_unknowns(sx.dataType)) + // we need to remove the unknowns from the data type so that all ports get the same VarWidth + val knownDataType = sx.copy(dataType = remove_unknowns(sx.dataType)) + types(sx.name) = MemPortUtils.memType(knownDataType) + knownDataType case sx => sx map infer_types_s(types) map infer_types_e(types) } @@ -88,7 +93,7 @@ object InferTypes extends Pass { m map infer_types_p(types) map infer_types_s(types) } - c copy (modules = c.modules map infer_types) + c.copy(modules = portsKnown.map(infer_types)) } } diff --git a/src/main/scala/firrtl/passes/ResolveFlows.scala b/src/main/scala/firrtl/passes/ResolveFlows.scala index c3455327..85a0a26f 100644 --- a/src/main/scala/firrtl/passes/ResolveFlows.scala +++ b/src/main/scala/firrtl/passes/ResolveFlows.scala @@ -9,10 +9,7 @@ import firrtl.options.Dependency object ResolveFlows extends Pass { - override def prerequisites = - Seq( Dependency(passes.ResolveKinds), - Dependency(passes.InferTypes), - Dependency(passes.Uniquify) ) ++ firrtl.stage.Forms.WorkingIR + override def prerequisites = Seq(Dependency(passes.InferTypes)) ++ firrtl.stage.Forms.WorkingIR override def invalidates(a: Transform) = false |
