aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJack2015-10-15 13:50:36 -0700
committerJack2015-10-15 13:50:36 -0700
commit7a7936c8fbddbffc1c4775fafeb5106ba1002dd4 (patch)
treebc3cb4d8efeb3243a63f80d2d25e9ee6282215ff
parentedd57efbadf493b331e69c8686662500fe859372 (diff)
Added infer-types pass, seems to work. Added infer-types error checking, modified Logger slightly, added Primops object for utility functions, minor changes in Utils
-rw-r--r--Makefile1
-rw-r--r--src/main/scala/firrtl/DebugUtils.scala7
-rw-r--r--src/main/scala/firrtl/IR.scala6
-rw-r--r--src/main/scala/firrtl/Passes.scala102
-rw-r--r--src/main/scala/firrtl/Primops.scala69
-rw-r--r--src/main/scala/firrtl/Test.scala12
-rw-r--r--src/main/scala/firrtl/Utils.scala20
7 files changed, 199 insertions, 18 deletions
diff --git a/Makefile b/Makefile
index 357c3189..4e2e3b17 100644
--- a/Makefile
+++ b/Makefile
@@ -76,6 +76,7 @@ build-scala:
test-scala:
cd $(test_dir)/parser && lit -v . --path=$(root_dir)/utils/bin/
+ cd $(test_dir)/passes/infer-types && lit -v . --path=$(root_dir)/utils/bin/
set-scala:
ln -f -s $(root_dir)/utils/bin/firrtl-scala $(root_dir)/utils/bin/firrtl
diff --git a/src/main/scala/firrtl/DebugUtils.scala b/src/main/scala/firrtl/DebugUtils.scala
index 80c0d240..01fe4fe4 100644
--- a/src/main/scala/firrtl/DebugUtils.scala
+++ b/src/main/scala/firrtl/DebugUtils.scala
@@ -24,6 +24,9 @@ private object DebugUtils {
val circuitEnable = printVars contains 'circuit
val debugFlags = printVars.map(_ -> true).toMap.withDefaultValue(false)
+ def println(message: => String){
+ writer.println(message)
+ }
def error(message: => String){
if (errorEnable) writer.println(message.split("\n").map("[error] " + _).mkString("\n"))
}
@@ -39,8 +42,8 @@ private object DebugUtils {
def trace(message: => String){
if (traceEnable) writer.println(message.split("\n").map("[trace] " + _).mkString("\n"))
}
- def printDebug(circuit: Circuit){
- if (circuitEnable) this.debug(circuit.serialize(debugFlags))
+ def printlnDebug(circuit: Circuit){
+ if (circuitEnable) this.println(circuit.serialize(debugFlags))
}
// Used if not autoflushing
def flush() = writer.flush()
diff --git a/src/main/scala/firrtl/IR.scala b/src/main/scala/firrtl/IR.scala
index 7905612e..bd9bd484 100644
--- a/src/main/scala/firrtl/IR.scala
+++ b/src/main/scala/firrtl/IR.scala
@@ -1,3 +1,9 @@
+
+/* TODO
+ * - Should FileInfo be a FIRRTL node?
+ *
+ */
+
package firrtl
import scala.collection.Seq
diff --git a/src/main/scala/firrtl/Passes.scala b/src/main/scala/firrtl/Passes.scala
new file mode 100644
index 00000000..4b31b1ff
--- /dev/null
+++ b/src/main/scala/firrtl/Passes.scala
@@ -0,0 +1,102 @@
+
+package firrtl
+
+import Utils._
+import DebugUtils._
+import Primops._
+
+object Passes {
+
+ private def toField(p: Port)(implicit logger: Logger): Field = {
+ logger.trace(s"toField called on port ${p.serialize}")
+ p.dir match {
+ case Input => Field(p.name, Reverse, p.tpe)
+ case Output => Field(p.name, Default, p.tpe)
+ }
+ }
+
+ /** INFER TYPES
+ *
+ * This pass infers the type field in all IR nodes by updating
+ * and passing an environment to all statements in pre-order
+ * traversal, and resolving types in expressions in post-
+ * order traversal.
+ * Type propagation for primary ops are defined here.
+ * Notable cases: LetRec requires updating environment before
+ * resolving the subexpressions in its elements.
+ * Type errors are not checked in this pass, as this is
+ * postponed for a later/earlier pass.
+ */
+ // input -> flip
+ private type TypeMap = Map[String, Type]
+ private val TypeMap = Map[String, Type]().withDefaultValue(UnknownType)
+ private def getBundleSubtype(t: Type, name: String): Type = {
+ t match {
+ case b: BundleType => {
+ val tpe = b.fields.find( _.name == name )
+ if (tpe.isEmpty) UnknownType
+ else tpe.get.tpe
+ }
+ case _ => UnknownType
+ }
+ }
+ private def getVectorSubtype(t: Type): Type = t.getType // Added for clarity
+ // TODO Add genders
+ private def inferExpTypes(typeMap: TypeMap)(exp: Exp)(implicit logger: Logger): Exp = {
+ logger.trace(s"inferTypes called on ${exp.getClass.getSimpleName}")
+ exp.map(inferExpTypes(typeMap)) match {
+ case e: UIntValue => e
+ case e: SIntValue => e
+ case e: Ref => Ref(e.name, typeMap(e.name))
+ case e: Subfield => Subfield(e.exp, e.name, getBundleSubtype(e.exp.getType, e.name))
+ case e: Index => Index(e.exp, e.value, getVectorSubtype(e.exp.getType))
+ case e: DoPrimOp => lowerAndTypePrimop(e)
+ case e: Exp => e
+ }
+ }
+ private def inferTypes(typeMap: TypeMap, stmt: Stmt)(implicit logger: Logger): (Stmt, TypeMap) = {
+ logger.trace(s"inferTypes called on ${stmt.getClass.getSimpleName} ")
+ stmt.map(inferExpTypes(typeMap)) match {
+ case b: Block => {
+ var tMap = typeMap
+ // TODO FIXME is map correctly called in sequential order
+ val body = b.stmts.map { s =>
+ val ret = inferTypes(tMap, s)
+ tMap = ret._2
+ ret._1
+ }
+ (Block(body), tMap)
+ }
+ case s: DefWire => (s, typeMap ++ Map(s.name -> s.tpe))
+ case s: DefReg => (s, typeMap ++ Map(s.name -> s.tpe))
+ case s: DefMemory => (s, typeMap ++ Map(s.name -> s.tpe))
+ case s: DefInst => (s, typeMap ++ Map(s.name -> s.module.getType))
+ case s: DefNode => (s, typeMap ++ Map(s.name -> s.value.getType))
+ case s: DefPoison => (s, typeMap ++ Map(s.name -> s.tpe))
+ case s: DefAccessor => (s, typeMap ++ Map(s.name -> getVectorSubtype(s.source.getType)))
+ case s: When => { // TODO Check: Assuming else block won't see when scope
+ val (conseq, cMap) = inferTypes(typeMap, s.conseq)
+ val (alt, aMap) = inferTypes(typeMap, s.alt)
+ (When(s.info, s.pred, conseq, alt), cMap ++ aMap)
+ }
+ case s: Stmt => (s, typeMap)
+ }
+ }
+ private def inferTypes(typeMap: TypeMap, m: Module)(implicit logger: Logger): Module = {
+ logger.trace(s"inferTypes called on module ${m.name}")
+
+ val pTypeMap = m.ports.map( p => p.name -> p.tpe ).toMap
+
+ Module(m.info, m.name, m.ports, inferTypes(typeMap ++ pTypeMap, m.stmt)._1)
+ }
+ def inferTypes(c: Circuit)(implicit logger: Logger): Circuit = {
+ logger.trace(s"inferTypes called on circuit ${c.name}")
+
+ // initialize typeMap with each module of circuit mapped to their bundled IO (ports converted to fields)
+ val typeMap = c.modules.map(m => m.name -> BundleType(m.ports.map(toField(_)))).toMap
+
+ //val typeMap = c.modules.flatMap(buildTypeMap).toMap
+ Circuit(c.info, c.name, c.modules.map(inferTypes(typeMap, _)))
+ }
+
+}
diff --git a/src/main/scala/firrtl/Primops.scala b/src/main/scala/firrtl/Primops.scala
new file mode 100644
index 00000000..5301390c
--- /dev/null
+++ b/src/main/scala/firrtl/Primops.scala
@@ -0,0 +1,69 @@
+
+package firrtl
+
+import Utils._
+import DebugUtils._
+
+object Primops {
+
+ // Borrowed from Stanza implementation
+ def lowerAndTypePrimop(e: DoPrimOp)(implicit logger: Logger): DoPrimOp = {
+ def uAnd(op1: Exp, op2: Exp): Type = {
+ (op1.getType, op2.getType) match {
+ case (t1: UIntType, t2: UIntType) => UIntType(UnknownWidth)
+ case (t1: SIntType, t2) => SIntType(UnknownWidth)
+ case (t1, t2: SIntType) => SIntType(UnknownWidth)
+ case _ => UnknownType
+ }
+ }
+ def ofType(op: Exp): Type = {
+ op.getType match {
+ case t: UIntType => UIntType(UnknownWidth)
+ case t: SIntType => SIntType(UnknownWidth)
+ case _ => UnknownType
+ }
+ }
+
+ logger.debug(s"lowerAndTypePrimop on ${e.op.getClass.getSimpleName}")
+ val tpe = e.op match {
+ case Add => uAnd(e.args(0), e.args(1))
+ case Sub => SIntType(UnknownWidth)
+ case Addw => uAnd(e.args(0), e.args(1))
+ case Subw => uAnd(e.args(0), e.args(1))
+ case Mul => uAnd(e.args(0), e.args(1))
+ case Div => uAnd(e.args(0), e.args(1))
+ case Mod => ofType(e.args(0))
+ case Quo => uAnd(e.args(0), e.args(1))
+ case Rem => ofType(e.args(1))
+ case Lt => UIntType(UnknownWidth)
+ case Leq => UIntType(UnknownWidth)
+ case Gt => UIntType(UnknownWidth)
+ case Geq => UIntType(UnknownWidth)
+ case Eq => UIntType(UnknownWidth)
+ case Neq => UIntType(UnknownWidth)
+ case Mux => ofType(e.args(1))
+ case Pad => ofType(e.args(0))
+ case AsUInt => UIntType(UnknownWidth)
+ case AsSInt => SIntType(UnknownWidth)
+ case Shl => ofType(e.args(0))
+ case Shr => ofType(e.args(0))
+ case Dshl => ofType(e.args(0))
+ case Dshr => ofType(e.args(0))
+ case Cvt => SIntType(UnknownWidth)
+ case Neg => SIntType(UnknownWidth)
+ case Not => ofType(e.args(0))
+ case And => ofType(e.args(0))
+ case Or => ofType(e.args(0))
+ case Xor => ofType(e.args(0))
+ case Andr => UIntType(UnknownWidth)
+ case Orr => UIntType(UnknownWidth)
+ case Xorr => UIntType(UnknownWidth)
+ case Cat => UIntType(UnknownWidth)
+ case Bit => UIntType(UnknownWidth)
+ case Bits => UIntType(UnknownWidth)
+ case _ => ???
+ }
+ DoPrimOp(e.op, e.args, e.consts, tpe)
+ }
+
+}
diff --git a/src/main/scala/firrtl/Test.scala b/src/main/scala/firrtl/Test.scala
index 86c3616a..3a89aeef 100644
--- a/src/main/scala/firrtl/Test.scala
+++ b/src/main/scala/firrtl/Test.scala
@@ -3,6 +3,7 @@ package firrtl
import java.io._
import Utils._
import DebugUtils._
+import Passes._
object Test
{
@@ -18,22 +19,25 @@ object Test
val writer = new PrintWriter(new File(output))
writer.write(ast.serialize())
writer.close()
- logger.printDebug(ast)
+ logger.printlnDebug(ast)
}
private def verilog(input: String, output: String)(implicit logger: Logger)
{
logger.warn("Verilog compiler not fully implemented")
val ast = time("parse"){ Parser.parse(input) }
// Execute passes
- //val ast2 = time("inferTypes"){ inferTypes(ast) }
- val ast2 = ast
+
+ logger.println("Infer Types")
+ val ast2 = time("inferTypes"){ inferTypes(ast) }
+ logger.printlnDebug(ast2)
+ logger.println("Finished Infer Types")
+ //val ast2 = ast
// Output
val writer = new PrintWriter(new File(output))
var outString = time("serialize"){ ast2.serialize() }
writer.write(outString)
writer.close()
- logger.printDebug(ast2)
}
def main(args: Array[String])
diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala
index 0cf19f04..f024edc2 100644
--- a/src/main/scala/firrtl/Utils.scala
+++ b/src/main/scala/firrtl/Utils.scala
@@ -8,6 +8,7 @@
package firrtl
import scala.collection.mutable.StringBuilder
+import java.io.PrintWriter
//import scala.reflect.runtime.universe._
object Utils {
@@ -21,8 +22,7 @@ object Utils {
var str = ""
if (flags('types)) {
val tpe = node.getType
- //if( tpe != UnknownType ) str += s"@<t:${tpe.wipeWidth.serialize}>"
- str += s"@<t:${tpe.wipeWidth.serialize}>"
+ if( tpe != UnknownType ) str += s"@<t:${tpe.wipeWidth.serialize}>"
}
str
}
@@ -46,8 +46,6 @@ object Utils {
case p: Port => p.getType
case _ => UnknownType
}
-
- //def foreach
}
implicit class PrimOpUtils(op: PrimOp) {
@@ -209,11 +207,11 @@ object Utils {
def getType(): Type =
stmt match {
- case w: DefWire => w.tpe
- case r: DefReg => r.tpe
- case m: DefMemory => m.tpe
- case p: DefPoison => p.tpe
- case s: Stmt => UnknownType
+ case s: DefWire => s.tpe
+ case s: DefReg => s.tpe
+ case s: DefMemory => s.tpe
+ case s: DefPoison => s.tpe
+ case _ => UnknownType
}
}
@@ -256,11 +254,9 @@ object Utils {
case t: BundleType => s"{${t.fields.map(_.serialize).mkString(commas)}}"
case t: VectorType => s"${t.tpe.serialize}[${t.size}]"
}
- //s + debug(t)
- s
+ s + debug(t)
}
- // TODO how does this work?
def getType(): Type =
t match {
case v: VectorType => v.tpe