aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorDonggyu Kim2016-08-30 19:02:40 -0700
committerDonggyu Kim2016-09-08 13:15:28 -0700
commitb020d66212a7381261231ba71c47010e64c6782f (patch)
tree4b5d226f65846335d82f2edc67ef0196e24968a2 /src
parente995f4993b21778c0568a56a31ed970948e39cf8 (diff)
refactor InferTypes
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/InferTypes.scala296
1 files changed, 108 insertions, 188 deletions
diff --git a/src/main/scala/firrtl/passes/InferTypes.scala b/src/main/scala/firrtl/passes/InferTypes.scala
index 23bc7e11..b36298e8 100644
--- a/src/main/scala/firrtl/passes/InferTypes.scala
+++ b/src/main/scala/firrtl/passes/InferTypes.scala
@@ -27,215 +27,135 @@ MODIFICATIONS.
package firrtl.passes
-import com.typesafe.scalalogging.LazyLogging
-import java.nio.file.{Paths, Files}
-
-// Datastructures
-import scala.collection.mutable.LinkedHashMap
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
-import scala.collection.mutable.ArrayBuffer
-
import firrtl._
import firrtl.ir._
import firrtl.Utils._
import firrtl.Mappers._
-import firrtl.PrimOps._
-import firrtl.WrappedExpression._
object InferTypes extends Pass {
- private var mname = ""
def name = "Infer Types"
- def set_type (s:Statement, t:Type) : Statement = {
- s match {
- case s:DefWire => DefWire(s.info,s.name,t)
- case s:DefRegister => DefRegister(s.info,s.name,t,s.clock,s.reset,s.init)
- case s:DefMemory => DefMemory(s.info,s.name,t,s.depth,s.writeLatency,s.readLatency,s.readers,s.writers,s.readwriters)
- case s:DefNode => s
- }
- }
- def remove_unknowns_w (w:Width)(implicit namespace: Namespace):Width = {
- w match {
- case UnknownWidth => VarWidth(namespace.newName("w"))
- case w => w
+ type TypeMap = collection.mutable.LinkedHashMap[String, Type]
+
+ def run(c: Circuit): Circuit = {
+ val namespace = Namespace()
+ val mtypes = (c.modules map (m => m.name -> module_type(m))).toMap
+
+ def remove_unknowns_w(w: Width): Width = w match {
+ case UnknownWidth => VarWidth(namespace.newName("w"))
+ case w => w
}
- }
- def remove_unknowns (t:Type)(implicit n: Namespace): Type = mapr(remove_unknowns_w _,t)
- def run (c:Circuit): Circuit = {
- val module_types = LinkedHashMap[String,Type]()
- implicit val wnamespace = Namespace()
- def infer_types (m:DefModule) : DefModule = {
- val types = LinkedHashMap[String,Type]()
- def infer_types_e (e:Expression) : Expression = {
- e map (infer_types_e) match {
- case e:ValidIf => ValidIf(e.cond,e.value,e.value.tpe)
- case e:WRef => WRef(e.name, types(e.name),e.kind,e.gender)
- case e:WSubField => WSubField(e.exp,e.name,field_type(e.exp.tpe,e.name),e.gender)
- case e:WSubIndex => WSubIndex(e.exp,e.value,sub_type(e.exp.tpe),e.gender)
- case e:WSubAccess => WSubAccess(e.exp,e.index,sub_type(e.exp.tpe),e.gender)
- case e:DoPrim => set_primop_type(e)
- case e:Mux => Mux(e.cond,e.tval,e.fval,mux_type_and_widths(e.tval,e.fval))
- case e:UIntLiteral => e
- case e:SIntLiteral => e
- }
- }
- def infer_types_s (s:Statement) : Statement = {
- s match {
- case s:DefRegister => {
- val t = remove_unknowns(get_type(s))
- types(s.name) = t
- set_type(s,t) map (infer_types_e)
- }
- case s:DefWire => {
- val sx = s map(infer_types_e)
- val t = remove_unknowns(get_type(sx))
- types(s.name) = t
- set_type(sx,t)
- }
- case s:DefNode => {
- val sx = s map (infer_types_e)
- val t = remove_unknowns(get_type(sx))
- types(s.name) = t
- set_type(sx,t)
- }
- case s:DefMemory => {
- val t = remove_unknowns(get_type(s))
- types(s.name) = t
- val dt = remove_unknowns(s.dataType)
- set_type(s,dt)
- }
- case s:WDefInstance => {
- types(s.name) = module_types(s.module)
- WDefInstance(s.info,s.name,s.module,module_types(s.module))
- }
- case s => s map (infer_types_s) map (infer_types_e)
- }
- }
- mname = m.name
- m.ports.foreach(p => types(p.name) = p.tpe)
- m match {
- case m:Module => Module(m.info,m.name,m.ports,infer_types_s(m.body))
- case m:ExtModule => m
+ def remove_unknowns(t: Type): Type =
+ t map remove_unknowns map remove_unknowns_w
+
+ def infer_types_e(types: TypeMap)(e: Expression): Expression =
+ e map infer_types_e(types) match {
+ case e: WRef => e copy (tpe = types(e.name))
+ case e: WSubField => e copy (tpe = field_type(e.exp.tpe, e.name))
+ case e: WSubIndex => e copy (tpe = sub_type(e.exp.tpe))
+ case e: WSubAccess => e copy (tpe = sub_type(e.exp.tpe))
+ case e: DoPrim => PrimOps.set_primop_type(e)
+ case e: Mux => e copy (tpe = mux_type_and_widths(e.tval, e.fval))
+ case e: ValidIf => e copy (tpe = e.value.tpe)
+ case e @ (_: UIntLiteral | _: SIntLiteral) => e
}
+
+ def infer_types_s(types: TypeMap)(s: Statement): Statement = s match {
+ case s: WDefInstance =>
+ val t = mtypes(s.module)
+ types(s.name) = t
+ s copy (tpe = t)
+ case s: DefWire =>
+ val t = remove_unknowns(get_type(s))
+ types(s.name) = t
+ s copy (tpe = t)
+ case s: DefNode =>
+ val sx = s map infer_types_e(types)
+ val t = remove_unknowns(get_type(sx))
+ types(s.name) = t
+ sx map infer_types_e(types)
+ case s: DefRegister =>
+ val t = remove_unknowns(get_type(s))
+ types(s.name) = t
+ s copy (tpe = t) map infer_types_e(types)
+ case s: DefMemory =>
+ val t = remove_unknowns(get_type(s))
+ types(s.name) = t
+ s copy (dataType = remove_unknowns(s.dataType))
+ case s => s map infer_types_s(types) map infer_types_e(types)
}
- val modulesx = c.modules.map {
- m => {
- mname = m.name
- val portsx = m.ports.map(p => Port(p.info,p.name,p.direction,remove_unknowns(p.tpe)))
- m match {
- case m:Module => Module(m.info,m.name,portsx,m.body)
- case m:ExtModule => ExtModule(m.info,m.name,portsx)
- }
- }
+ def infer_types_p(types: TypeMap)(p: Port): Port = {
+ val t = remove_unknowns(p.tpe)
+ types(p.name) = t
+ p copy (tpe = t)
+ }
+
+ def infer_types(m: DefModule): DefModule = {
+ val types = new TypeMap
+ m map infer_types_p(types) map infer_types_s(types)
}
- modulesx.foreach(m => module_types(m.name) = module_type(m))
- Circuit(c.info,modulesx.map({m => mname = m.name; infer_types(m)}) , c.main )
+
+ c copy (modules = (c.modules map infer_types))
}
}
object CInferTypes extends Pass {
def name = "CInfer Types"
- var mname = ""
- def set_type (s:Statement, t:Type) : Statement = {
- (s) match {
- case (s:DefWire) => DefWire(s.info,s.name,t)
- case (s:DefRegister) => DefRegister(s.info,s.name,t,s.clock,s.reset,s.init)
- case (s:CDefMemory) => CDefMemory(s.info,s.name,t,s.size,s.seq)
- case (s:CDefMPort) => CDefMPort(s.info,s.name,t,s.mem,s.exps,s.direction)
- case (s:DefNode) => s
- }
- }
-
- def to_field (p:Port) : Field = {
- if (p.direction == Output) Field(p.name,Default,p.tpe)
- else if (p.direction == Input) Field(p.name,Flip,p.tpe)
- else error("Shouldn't be here"); Field(p.name,Flip,p.tpe)
- }
- def module_type (m:DefModule) : Type = BundleType(m.ports.map(p => to_field(p)))
- def field_type (v:Type,s:String) : Type = {
- (v) match {
- case (v:BundleType) => {
- val ft = v.fields.find(p => p.name == s)
- if (ft != None) ft.get.tpe
- else UnknownType
+ type TypeMap = collection.mutable.LinkedHashMap[String, Type]
+
+ def run(c: Circuit): Circuit = {
+ val namespace = Namespace()
+ val mtypes = (c.modules map (m => m.name -> module_type(m))).toMap
+
+ def infer_types_e(types: TypeMap)(e: Expression) : Expression =
+ e map infer_types_e(types) match {
+ case (e: Reference) => e copy (tpe = (types getOrElse (e.name, UnknownType)))
+ case (e: SubField) => e copy (tpe = field_type(e.expr.tpe, e.name))
+ case (e: SubIndex) => e copy (tpe = sub_type(e.expr.tpe))
+ case (e: SubAccess) => e copy (tpe = sub_type(e.expr.tpe))
+ case (e: DoPrim) => PrimOps.set_primop_type(e)
+ case (e: Mux) => e copy (tpe = mux_type(e.tval,e.tval))
+ case (e: ValidIf) => e copy (tpe = e.value.tpe)
+ case e @ (_: UIntLiteral | _: SIntLiteral) => e
}
- case (v) => UnknownType
- }
- }
- def sub_type (v:Type) : Type =
- (v) match {
- case (v:VectorType) => v.tpe
- case (v) => UnknownType
+
+ def infer_types_s(types: TypeMap)(s: Statement): Statement = s match {
+ case (s: DefRegister) =>
+ types(s.name) = s.tpe
+ s map infer_types_e(types)
+ case (s: DefWire) =>
+ types(s.name) = s.tpe
+ s
+ case (s: DefNode) =>
+ types(s.name) = get_type(s)
+ s
+ case (s: DefMemory) =>
+ types(s.name) = get_type(s)
+ s
+ case (s: CDefMPort) =>
+ val t = types getOrElse(s.mem, UnknownType)
+ types(s.name) = t
+ s copy (tpe = t)
+ case (s: CDefMemory) =>
+ types(s.name) = s.tpe
+ s
+ case (s: DefInstance) =>
+ types(s.name) = mtypes(s.module)
+ s
+ case (s) => s map infer_types_s(types) map infer_types_e(types)
}
- def run (c:Circuit) : Circuit = {
- val module_types = LinkedHashMap[String,Type]()
- def infer_types (m:DefModule) : DefModule = {
- val types = LinkedHashMap[String,Type]()
- def infer_types_e (e:Expression) : Expression = {
- e map infer_types_e match {
- case (e:Reference) => Reference(e.name, types.getOrElse(e.name,UnknownType))
- case (e:SubField) => SubField(e.expr,e.name,field_type(e.expr.tpe,e.name))
- case (e:SubIndex) => SubIndex(e.expr,e.value,sub_type(e.expr.tpe))
- case (e:SubAccess) => SubAccess(e.expr,e.index,sub_type(e.expr.tpe))
- case (e:DoPrim) => set_primop_type(e)
- case (e:Mux) => Mux(e.cond,e.tval,e.fval,mux_type(e.tval,e.tval))
- case (e:ValidIf) => ValidIf(e.cond,e.value,e.value.tpe)
- case (_:UIntLiteral | _:SIntLiteral) => e
- }
- }
- def infer_types_s (s:Statement) : Statement = {
- s match {
- case (s:DefRegister) => {
- types(s.name) = s.tpe
- s map infer_types_e
- s
- }
- case (s:DefWire) => {
- types(s.name) = s.tpe
- s
- }
- case (s:DefNode) => {
- val sx = s map infer_types_e
- val t = get_type(sx)
- types(s.name) = t
- sx
- }
- case (s:DefMemory) => {
- types(s.name) = get_type(s)
- s
- }
- case (s:CDefMPort) => {
- val t = types.getOrElse(s.mem,UnknownType)
- types(s.name) = t
- CDefMPort(s.info,s.name,t,s.mem,s.exps,s.direction)
- }
- case (s:CDefMemory) => {
- types(s.name) = s.tpe
- s
- }
- case (s:DefInstance) => {
- types(s.name) = module_types.getOrElse(s.module,UnknownType)
- s
- }
- case (s) => s map infer_types_s map infer_types_e
- }
- }
- for (p <- m.ports) {
- types(p.name) = p.tpe
- }
- m match {
- case (m:Module) => Module(m.info,m.name,m.ports,infer_types_s(m.body))
- case (m:ExtModule) => m
- }
+
+ def infer_types_p(types: TypeMap)(p: Port): Port = {
+ types(p.name) = p.tpe
+ p
}
-
- //; MAIN
- for (m <- c.modules) {
- module_types(m.name) = module_type(m)
+
+ def infer_types(m: DefModule): DefModule = {
+ val types = new TypeMap
+ m map infer_types_p(types) map infer_types_s(types)
}
- val modulesx = c.modules.map(m => infer_types(m))
- Circuit(c.info, modulesx, c.main)
+
+ c copy (modules = (c.modules map infer_types))
}
}