aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorDonggyu Kim2016-08-30 18:51:17 -0700
committerDonggyu Kim2016-09-07 16:58:06 -0700
commitd7bf6fb7b415d35f967d247119b8975c3dc885a3 (patch)
treebe6afdc75f2f209f4a412d5aafae5015da98cc2a /src
parent296a65ebb895d100c3cbde6df7c0303d6942e5d5 (diff)
refactor checks
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/Utils.scala53
-rw-r--r--src/main/scala/firrtl/passes/CheckChirrtl.scala189
-rw-r--r--src/main/scala/firrtl/passes/CheckInitialization.scala38
-rw-r--r--src/main/scala/firrtl/passes/Checks.scala1168
-rw-r--r--src/test/scala/firrtlTests/CheckInitializationSpec.scala4
5 files changed, 663 insertions, 789 deletions
diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala
index ea8ff4b7..29c37294 100644
--- a/src/main/scala/firrtl/Utils.scala
+++ b/src/main/scala/firrtl/Utils.scala
@@ -662,3 +662,56 @@ class MemoizedHash[T](val t: T) {
case _ => false
}
}
+
+/**
+ * Maintains a one to many graph of each modules instantiated child module.
+ * This graph can be searched for a path from a child module back to one of
+ * it's parents. If one is found a recursive loop has happened
+ * The graph is a map between the name of a node to set of names of that nodes children
+ */
+class ModuleGraph {
+ val nodes = HashMap[String, HashSet[String]]()
+
+ /**
+ * Add a child to a parent node
+ * A parent node is created if it does not already exist
+ *
+ * @param parent module that instantiates another module
+ * @param child module instantiated by parent
+ * @return a list indicating a path from child to parent, empty if no such path
+ */
+ def add(parent: String, child: String): List[String] = {
+ val childSet = nodes.getOrElseUpdate(parent, new HashSet[String])
+ childSet += child
+ pathExists(child, parent, List(child, parent))
+ }
+
+ /**
+ * Starting at the name of a given child explore the tree of all children in depth first manner.
+ * Return the first path (a list of strings) that goes from child to parent,
+ * or an empty list of no such path is found.
+ *
+ * @param child starting name
+ * @param parent name to find in children (recursively)
+ * @param path
+ * @return
+ */
+ def pathExists(child: String, parent: String, path: List[String] = Nil): List[String] = {
+ nodes.get(child) match {
+ case Some(children) =>
+ if(children(parent)) {
+ parent :: path
+ }
+ else {
+ children.foreach { grandchild =>
+ val newPath = pathExists(grandchild, parent, grandchild :: path)
+ if(newPath.nonEmpty) {
+ return newPath
+ }
+ }
+ Nil
+ }
+ case _ => Nil
+ }
+ }
+}
diff --git a/src/main/scala/firrtl/passes/CheckChirrtl.scala b/src/main/scala/firrtl/passes/CheckChirrtl.scala
index e0e7c57a..2ab8749b 100644
--- a/src/main/scala/firrtl/passes/CheckChirrtl.scala
+++ b/src/main/scala/firrtl/passes/CheckChirrtl.scala
@@ -30,8 +30,7 @@ package firrtl.passes
import com.typesafe.scalalogging.LazyLogging
// Datastructures
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashSet
import firrtl._
import firrtl.ir._
@@ -44,122 +43,120 @@ import firrtl.WrappedType._
object CheckChirrtl extends Pass with LazyLogging {
def name = "Chirrtl Check"
+ class NotUniqueException(info: Info, mname: String, name: String) extends PassException(
+ s"${info}: [module ${mname}] Reference ${name} does not have a unique name.")
+ class InvalidLOCException(info: Info, mname: String) extends PassException(
+ s"${info}: [module ${mname}] Invalid connect to an expression that is not a reference or a WritePort.")
+ class UndeclaredReferenceException(info: Info, mname: String, name: String) extends PassException(
+ s"${info}: [module ${mname}] Reference ${name} is not declared.")
+ class MemWithFlipException(info: Info, mname: String, name: String) extends PassException(
+ s"${info}: [module ${mname}] Memory ${name} cannot be a bundle type with flips.")
+ class InvalidAccessException(info: Info, mname: String) extends PassException(
+ s"${info}: [module ${mname}] Invalid access to non-reference.")
+ class ModuleNotDefinedException(info: Info, mname: String, name: String) extends PassException(
+ s"${info}: Module ${name} is not defined.")
+ class NegWidthException(info: Info, mname: String) extends PassException(
+ s"${info}: [module ${mname}] Width cannot be negative or zero.")
+ class NegVecSizeException(info: Info, mname: String) extends PassException(
+ s"${info}: [module ${mname}] Vector type size cannot be negative.")
+ class NegMemSizeException(info: Info, mname: String) extends PassException(
+ s"${info}: [module ${mname}] Memory size cannot be negative or zero.")
+ class NoTopModuleException(info: Info, name: String) extends PassException(
+ s"${info}: A single module must be named ${name}.")
+
// TODO FIXME
// - Do we need to check for uniquness on port names?
def run (c: Circuit): Circuit = {
- var mname: String = ""
- var sinfo: Info = NoInfo
-
- class NotUniqueException(name: String) extends PassException(s"${sinfo}: [module ${mname}] Reference ${name} does not have a unique name.")
- class InvalidLOCException extends PassException(s"${sinfo}: [module ${mname}] Invalid connect to an expression that is not a reference or a WritePort.")
- class UndeclaredReferenceException(name: String) extends PassException(s"${sinfo}: [module ${mname}] Reference ${name} is not declared.")
- class MemWithFlipException(name: String) extends PassException(s"${sinfo}: [module ${mname}] Memory ${name} cannot be a bundle type with flips.")
- class InvalidAccessException extends PassException(s"${sinfo}: [module ${mname}] Invalid access to non-reference.")
- class NoTopModuleException(name: String) extends PassException(s"${sinfo}: A single module must be named ${name}.")
- class ModuleNotDefinedException(name: String) extends PassException(s"${sinfo}: Module ${name} is not defined.")
- class NegWidthException extends PassException(s"${sinfo}: [module ${mname}] Width cannot be negative or zero.")
- class NegVecSizeException extends PassException(s"${sinfo}: [module ${mname}] Vector type size cannot be negative.")
- class NegMemSizeException extends PassException(s"${sinfo}: [module ${mname}] Memory size cannot be negative or zero.")
-
val errors = new Errors()
- def checkValidLoc(e: Expression) = e match {
- case e @ (_: UIntLiteral | _: SIntLiteral | _: DoPrim ) => errors.append(new InvalidLOCException)
+ val moduleNames = (c.modules map (_.name)).toSet
+
+ def checkValidLoc(info: Info, mname: String, e: Expression) = e match {
+ case _: UIntLiteral | _: SIntLiteral | _: DoPrim =>
+ errors append new InvalidLOCException(info, mname)
case _ => // Do Nothing
}
- def checkChirrtlW(w: Width): Width = w match {
- case w: IntWidth if (w.width <= BigInt(0)) =>
- errors.append(new NegWidthException)
+
+ def checkChirrtlW(info: Info, mname: String)(w: Width): Width = w match {
+ case w: IntWidth if w.width <= 0 =>
+ errors append new NegWidthException(info, mname)
w
case _ => w
}
- def checkChirrtlT(t: Type): Type = {
- t map (checkChirrtlT) match {
- case t: VectorType if (t.size < 0) => errors.append(new NegVecSizeException)
+
+ def checkChirrtlT(info: Info, mname: String)(t: Type): Type = {
+ t map checkChirrtlT(info, mname) match {
+ case t: VectorType if t.size < 0 =>
+ errors append new NegVecSizeException(info, mname)
case _ => // Do nothing
}
- t map (checkChirrtlW)
+ t map checkChirrtlW(info, mname) map checkChirrtlT(info, mname)
}
- def checkChirrtlM(m: DefModule): DefModule = {
- val names = HashMap[String, Boolean]()
- val mnames = HashMap[String, Boolean]()
- def checkChirrtlE(e: Expression): Expression = {
- def validSubexp(e: Expression): Expression = e match {
- case (_:Reference|_:SubField|_:SubIndex|_:SubAccess|_:Mux|_:ValidIf) => e // No error
- case _ =>
- errors.append(new InvalidAccessException)
- e
- }
- e map (checkChirrtlE) match {
- case e: Reference if (!names.contains(e.name)) => errors.append(new UndeclaredReferenceException(e.name))
- case e: DoPrim => {}
- case (_:Mux|_:ValidIf) => {}
- case e: SubAccess =>
- validSubexp(e.expr)
- e
- case e: UIntLiteral => {}
- case e => e map (validSubexp)
- }
- e map (checkChirrtlW)
- e map (checkChirrtlT)
- e
- }
- def checkChirrtlS(s: Statement): Statement = {
- sinfo = get_info(s)
- def checkName(name: String): String = {
- if (names.contains(name)) errors.append(new NotUniqueException(name))
- else names(name) = true
- name
- }
-
- s map (checkName)
- s map (checkChirrtlT)
- s map (checkChirrtlE)
- s match {
- case s: DefMemory =>
- if (hasFlip(s.dataType)) errors.append(new MemWithFlipException(s.name))
- if (s.depth <= 0) errors.append(new NegMemSizeException)
- case s: DefInstance =>
- if (!c.modules.map(_.name).contains(s.module))
- errors.append(new ModuleNotDefinedException(s.module))
- case s: Connect => checkValidLoc(s.loc)
- case s: PartialConnect => checkValidLoc(s.loc)
- case s: Print => {}
- case _ => // Do Nothing
- }
-
- s map (checkChirrtlS)
+ def validSubexp(info: Info, mname: String)(e: Expression): Expression = {
+ e match {
+ case _: Reference | _: SubField | _: SubIndex | _: SubAccess |
+ _: Mux | _: ValidIf => // No error
+ case _ => errors append new InvalidAccessException(info, mname)
}
+ e
+ }
- mname = m.name
- for (m <- c.modules) {
- mnames(m.name) = true
- }
- for (p <- m.ports) {
- sinfo = p.info
- names(p.name) = true
- val tpe = p.tpe
- tpe map (checkChirrtlT)
- tpe map (checkChirrtlW)
+ def checkChirrtlE(info: Info, mname: String, names: HashSet[String])(e: Expression): Expression = {
+ e match {
+ case _: DoPrim | _:Mux | _:ValidIf | _: UIntLiteral =>
+ case e: Reference if !names(e.name) =>
+ errors append new UndeclaredReferenceException(info, mname, e.name)
+ case e: SubAccess => validSubexp(info, mname)(e.expr)
+ case e => e map validSubexp(info, mname)
}
+ (e map checkChirrtlW(info, mname)
+ map checkChirrtlT(info, mname)
+ map checkChirrtlE(info, mname, names))
+ }
+
+ def checkName(info: Info, mname: String, names: HashSet[String])(name: String): String = {
+ if (names(name))
+ errors append (new NotUniqueException(info, mname, name))
+ names += name
+ name
+ }
- m match {
- case m: Module => checkChirrtlS(m.body)
- case m: ExtModule => // Do Nothing
+ def checkChirrtlS(minfo: Info, mname: String, names: HashSet[String])(s: Statement): Statement = {
+ val info = get_info(s) match {case NoInfo => minfo case x => x}
+ (s map checkName(info, mname, names)) match {
+ case s: DefMemory =>
+ if (hasFlip(s.dataType)) errors append new MemWithFlipException(info, mname, s.name)
+ if (s.depth <= 0) errors append new NegMemSizeException(info, mname)
+ case s: DefInstance if !moduleNames(s.module) =>
+ errors append new ModuleNotDefinedException(info, mname, s.module)
+ case s: Connect => checkValidLoc(info, mname, s.loc)
+ case s: PartialConnect => checkValidLoc(info, mname, s.loc)
+ case _ => // Do Nothing
}
- m
+ (s map checkChirrtlT(info, mname)
+ map checkChirrtlE(info, mname, names)
+ map checkChirrtlS(info, mname, names))
+ }
+
+ def checkChirrtlP(mname: String, names: HashSet[String])(p: Port): Port = {
+ names += p.name
+ (p.tpe map checkChirrtlT(p.info, mname)
+ map checkChirrtlW(p.info, mname))
+ p
+ }
+
+ def checkChirrtlM(m: DefModule) {
+ val names = HashSet[String]()
+ (m map checkChirrtlP(m.name, names)
+ map checkChirrtlS(m.info, m.name, names))
}
- var numTopM = 0
- for (m <- c.modules) {
- if (m.name == c.main) numTopM = numTopM + 1
- checkChirrtlM(m)
+ c.modules foreach checkChirrtlM
+ (c.modules filter (_.name == c.main)).size match {
+ case 1 =>
+ case _ => errors append new NoTopModuleException(c.info, c.main)
}
- sinfo = c.info
- if (numTopM != 1) errors.append(new NoTopModuleException(c.main))
errors.trigger
c
}
}
-
-// vim: set ts=4 sw=4 et:
diff --git a/src/main/scala/firrtl/passes/CheckInitialization.scala b/src/main/scala/firrtl/passes/CheckInitialization.scala
index 6d69b792..69629bf0 100644
--- a/src/main/scala/firrtl/passes/CheckInitialization.scala
+++ b/src/main/scala/firrtl/passes/CheckInitialization.scala
@@ -60,8 +60,7 @@ object CheckInitialization extends Pass {
}
def run(c: Circuit): Circuit = {
- val errors = collection.mutable.ArrayBuffer[PassException]()
-
+ val errors = new Errors()
def checkInitM(m: Module): Unit = {
val voidExprs = collection.mutable.HashMap[WrappedExpression, VoidExpr]()
@@ -69,19 +68,17 @@ object CheckInitialization extends Pass {
def hasVoidExpr(e: Expression): (Boolean, Seq[Expression]) = {
var void = false
val voidDeps = collection.mutable.ArrayBuffer[Expression]()
- def hasVoid(e: Expression): Expression = {
- e match {
- case e: WVoid =>
+ def hasVoid(e: Expression): Expression = e match {
+ case e: WVoid =>
+ void = true
+ e
+ case (_: WRef | _: WSubField) =>
+ if (voidExprs.contains(e)) {
void = true
- e
- case (_: WRef | _: WSubField) =>
- if (voidExprs.contains(e)) {
- void = true
- voidDeps += e
- }
- e
- case e => e map hasVoid
- }
+ voidDeps += e
+ }
+ e
+ case e => e map hasVoid
}
hasVoid(e)
(void, voidDeps)
@@ -110,19 +107,16 @@ object CheckInitialization extends Pass {
case node: DefNode => // Ignore nodes
case decl: IsDeclaration =>
val trace = getTrace(expr, voidExprs.toMap)
- errors += new RefNotInitializedException(decl.info, m.name, decl.name, trace)
+ errors append new RefNotInitializedException(decl.info, m.name, decl.name, trace)
}
}
}
- c.modules foreach { m =>
- m match {
- case m: Module => checkInitM(m)
- case m => // Do nothing
- }
+ c.modules foreach {
+ case m: Module => checkInitM(m)
+ case m => // Do nothing
}
-
- if (errors.nonEmpty) throw new PassExceptions(errors)
+ errors.trigger
c
}
}
diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala
index 6e49ce93..bba3efe7 100644
--- a/src/main/scala/firrtl/passes/Checks.scala
+++ b/src/main/scala/firrtl/passes/Checks.scala
@@ -29,746 +29,576 @@ package firrtl.passes
import com.typesafe.scalalogging.LazyLogging
-// Datastructures
-import scala.collection.mutable.{HashMap,HashSet}
-import scala.collection.mutable.ArrayBuffer
-
import firrtl._
import firrtl.ir._
+import firrtl.PrimOps._
import firrtl.Utils._
import firrtl.Mappers._
-import firrtl.PrimOps._
import firrtl.WrappedType._
+// Datastructures
+import scala.collection.mutable.{HashMap, HashSet}
+
object CheckHighForm extends Pass with LazyLogging {
def name = "High Form Check"
// Custom Exceptions
- class NotUniqueException(name: String) extends PassException(s"${sinfo}: [module ${mname}] Reference ${name} does not have a unique name.")
- class InvalidLOCException extends PassException(s"${sinfo}: [module ${mname}] Invalid connect to an expression that is not a reference or a WritePort.")
- class NegUIntException extends PassException(s"${sinfo}: [module ${mname}] UIntLiteral cannot be negative.")
- class UndeclaredReferenceException(name: String) extends PassException(s"${sinfo}: [module ${mname}] Reference ${name} is not declared.")
- class PoisonWithFlipException(name: String) extends PassException(s"${sinfo}: [module ${mname}] Poison ${name} cannot be a bundle type with flips.")
- class MemWithFlipException(name: String) extends PassException(s"${sinfo}: [module ${mname}] Memory ${name} cannot be a bundle type with flips.")
- class InvalidAccessException extends PassException(s"${sinfo}: [module ${mname}] Invalid access to non-reference.")
- class NoTopModuleException(name: String) extends PassException(s"${sinfo}: A single module must be named ${name}.")
- class ModuleNotDefinedException(name: String) extends PassException(s"${sinfo}: Module ${name} is not defined.")
- class IncorrectNumArgsException(op: String, n: Int) extends PassException(s"${sinfo}: [module ${mname}] Primop ${op} requires ${n} expression arguments.")
- class IncorrectNumConstsException(op: String, n: Int) extends PassException(s"${sinfo}: [module ${mname}] Primop ${op} requires ${n} integer arguments.")
- class NegWidthException extends PassException(s"${sinfo}: [module ${mname}] Width cannot be negative or zero.")
- class NegVecSizeException extends PassException(s"${sinfo}: [module ${mname}] Vector type size cannot be negative.")
- class NegMemSizeException extends PassException(s"${sinfo}: [module ${mname}] Memory size cannot be negative or zero.")
- class BadPrintfException(x: Char) extends PassException(s"${sinfo}: [module ${mname}] Bad printf format: " + "\"%" + x + "\"")
- class BadPrintfTrailingException extends PassException(s"${sinfo}: [module ${mname}] Bad printf format: trailing " + "\"%\"")
- class BadPrintfIncorrectNumException extends PassException(s"${sinfo}: [module ${mname}] Bad printf format: incorrect number of arguments")
- class InstanceLoop(loop: String) extends PassException(s"${sinfo}: [module ${mname}] Has instance loop $loop")
-
- /**
- * Maintains a one to many graph of each modules instantiated child module.
- * This graph can be searched for a path from a child module back to one of
- * it's parents. If one is found a recursive loop has happened
- * The graph is a map between the name of a node to set of names of that nodes children
- */
- class ModuleGraph {
- val nodes = new HashMap[String, HashSet[String]]
-
- /**
- * Add a child to a parent node
- * A parent node is created if it does not already exist
- *
- * @param parent module that instantiates another module
- * @param child module instantiated by parent
- * @return a list indicating a path from child to parent, empty if no such path
- */
- def add(parent: String, child: String): List[String] = {
- val childSet = nodes.getOrElseUpdate(parent, new HashSet[String])
- childSet += child
- pathExists(child, parent, List(child, parent))
- }
-
- /**
- * Starting at the name of a given child explore the tree of all children in depth first manner.
- * Return the first path (a list of strings) that goes from child to parent,
- * or an empty list of no such path is found.
- *
- * @param child starting name
- * @param parent name to find in children (recursively)
- * @param path
- * @return
- */
- def pathExists(child: String, parent: String, path: List[String] = Nil): List[String] = {
- nodes.get(child) match {
- case Some(children) =>
- if(children.contains(parent)) {
- parent :: path
- }
- else {
- children.foreach { grandchild =>
- val newPath = pathExists(grandchild, parent, grandchild :: path)
- if(newPath.nonEmpty) {
- return newPath
- }
- }
- Nil
- }
- case _ => Nil
- }
- }
- }
+ class NotUniqueException(info: Info, mname: String, name: String) extends PassException(
+ s"${info}: [module ${mname}] Reference ${name} does not have a unique name.")
+ class InvalidLOCException(info: Info, mname: String) extends PassException(
+ s"${info}: [module ${mname}] Invalid connect to an expression that is not a reference or a WritePort.")
+ class NegUIntException(info: Info, mname: String) extends PassException(
+ s"${info}: [module ${mname}] UIntLiteral cannot be negative.")
+ class UndeclaredReferenceException(info: Info, mname: String, name: String) extends PassException(
+ s"${info}: [module ${mname}] Reference ${name} is not declared.")
+ class PoisonWithFlipException(info: Info, mname: String, name: String) extends PassException(
+ s"${info}: [module ${mname}] Poison ${name} cannot be a bundle type with flips.")
+ class MemWithFlipException(info: Info, mname: String, name: String) extends PassException(
+ s"${info}: [module ${mname}] Memory ${name} cannot be a bundle type with flips.")
+ class InvalidAccessException(info: Info, mname: String) extends PassException(
+ s"${info}: [module ${mname}] Invalid access to non-reference.")
+ class ModuleNotDefinedException(info: Info, mname: String, name: String) extends PassException(
+ s"${info}: Module ${name} is not defined.")
+ class IncorrectNumArgsException(info: Info, mname: String, op: String, n: Int) extends PassException(
+ s"${info}: [module ${mname}] Primop ${op} requires ${n} expression arguments.")
+ class IncorrectNumConstsException(info: Info, mname: String, op: String, n: Int) extends PassException(
+ s"${info}: [module ${mname}] Primop ${op} requires ${n} integer arguments.")
+ class NegWidthException(info: Info, mname: String) extends PassException(
+ s"${info}: [module ${mname}] Width cannot be negative or zero.")
+ class NegVecSizeException(info: Info, mname: String) extends PassException(
+ s"${info}: [module ${mname}] Vector type size cannot be negative.")
+ class NegMemSizeException(info: Info, mname: String) extends PassException(
+ s"${info}: [module ${mname}] Memory size cannot be negative or zero.")
+ class BadPrintfException(info: Info, mname: String, x: Char) extends PassException(
+ s"${info}: [module ${mname}] Bad printf format: " + "\"%" + x + "\"")
+ class BadPrintfTrailingException(info: Info, mname: String) extends PassException(
+ s"${info}: [module ${mname}] Bad printf format: trailing " + "\"%\"")
+ class BadPrintfIncorrectNumException(info: Info, mname: String) extends PassException(
+ s"${info}: [module ${mname}] Bad printf format: incorrect number of arguments")
+ class InstanceLoop(info: Info, mname: String, loop: String) extends PassException(
+ s"${info}: [module ${mname}] Has instance loop $loop")
+ class NoTopModuleException(info: Info, name: String) extends PassException(
+ s"${info}: A single module must be named ${name}.")
// TODO FIXME
// - Do we need to check for uniquness on port names?
- // Global Variables
- private var mname: String = ""
- private var sinfo: Info = NoInfo
- def run (c:Circuit): Circuit = {
+ def run(c: Circuit): Circuit = {
val errors = new Errors()
val moduleGraph = new ModuleGraph
+ val moduleNames = (c.modules map (_.name)).toSet
- def checkHighFormPrimop(e: DoPrim) = {
- def correctNum(ne: Option[Int], nc: Int) = {
+ def checkHighFormPrimop(info: Info, mname: String, e: DoPrim) {
+ def correctNum(ne: Option[Int], nc: Int) {
ne match {
- case Some(i) => if(e.args.length != i) errors.append(new IncorrectNumArgsException(e.op.toString, i))
- case None => // Do Nothing
+ case Some(i) if e.args.length != i =>
+ errors append (new IncorrectNumArgsException(info, mname, e.op.toString, i))
+ case _ => // Do Nothing
}
- if (e.consts.length != nc) errors.append(new IncorrectNumConstsException(e.op.toString, nc))
+ if (e.consts.length != nc)
+ errors append new IncorrectNumConstsException(info, mname, e.op.toString, nc)
}
e.op match {
- case Add => correctNum(Option(2),0)
- case Sub => correctNum(Option(2),0)
- case Mul => correctNum(Option(2),0)
- case Div => correctNum(Option(2),0)
- case Rem => correctNum(Option(2),0)
- case Lt => correctNum(Option(2),0)
- case Leq => correctNum(Option(2),0)
- case Gt => correctNum(Option(2),0)
- case Geq => correctNum(Option(2),0)
- case Eq => correctNum(Option(2),0)
- case Neq => correctNum(Option(2),0)
- case Pad => correctNum(Option(1),1)
- case AsUInt => correctNum(Option(1),0)
- case AsSInt => correctNum(Option(1),0)
- case AsClock => correctNum(Option(1),0)
- case Shl => correctNum(Option(1),1)
- case Shr => correctNum(Option(1),1)
- case Dshl => correctNum(Option(2),0)
- case Dshr => correctNum(Option(2),0)
- case Cvt => correctNum(Option(1),0)
- case Neg => correctNum(Option(1),0)
- case Not => correctNum(Option(1),0)
- case And => correctNum(Option(2),0)
- case Or => correctNum(Option(2),0)
- case Xor => correctNum(Option(2),0)
- case Andr => correctNum(None,0)
- case Orr => correctNum(None,0)
- case Xorr => correctNum(None,0)
- case Cat => correctNum(Option(2),0)
- case Bits => correctNum(Option(1),2)
- case Head => correctNum(Option(1),1)
- case Tail => correctNum(Option(1),1)
+ case Add | Sub | Mul | Div | Rem | Lt | Leq | Gt | Geq |
+ Eq | Neq | Dshl | Dshr | And | Or | Xor | Cat =>
+ correctNum(Option(2), 0)
+ case AsUInt | AsSInt | AsClock | Cvt | Neq | Not =>
+ correctNum(Option(1), 0)
+ case Pad | Shl | Shr | Head | Tail =>
+ correctNum(Option(1), 1)
+ case Bits =>
+ correctNum(Option(1), 2)
+ case Andr | Orr | Xorr =>
+ correctNum(None,0)
}
}
- def checkFstring(s: StringLit, i: Int) = {
+ def checkFstring(info: Info, mname: String, s: StringLit, i: Int) {
val validFormats = "bdxc"
- var percent = false
- var npercents = 0
- s.array.foreach { b =>
- if (percent) {
- if (validFormats.contains(b)) npercents += 1
- else if (b != '%') errors.append(new BadPrintfException(b.toChar))
- }
- percent = if (b == '%') !percent else false // %% -> percent = false
- }
- if (percent) errors.append(new BadPrintfTrailingException)
- if (npercents != i) errors.append(new BadPrintfIncorrectNumException)
+ val (percent, npercents) = (s.array foldLeft (false, 0)){
+ case ((percent, n), b) if percent && (validFormats contains b) =>
+ (false, n + 1)
+ case ((percent, n), b) if percent && b != '%' =>
+ errors append new BadPrintfException(info, mname, b.toChar)
+ (false, n)
+ case ((percent, n), b) =>
+ (if (b == '%') !percent else false /* %% -> percent = false */, n)
+ }
+ if (percent) errors append new BadPrintfTrailingException(info, mname)
+ if (npercents != i) errors append new BadPrintfIncorrectNumException(info, mname)
}
- def checkValidLoc(e: Expression) = {
- e match {
- case e @ (_: UIntLiteral | _: SIntLiteral | _: DoPrim ) => errors.append(new InvalidLOCException)
- case _ => // Do Nothing
- }
+
+ def checkValidLoc(info: Info, mname: String, e: Expression) = e match {
+ case _: UIntLiteral | _: SIntLiteral | _: DoPrim =>
+ errors append new InvalidLOCException(info, mname)
+ case _ => // Do Nothing
}
- def checkHighFormW(w: Width): Width = {
+
+ def checkHighFormW(info: Info, mname: String)(w: Width): Width = {
w match {
- case w: IntWidth =>
- if (w.width <= BigInt(0)) errors.append(new NegWidthException)
- case _ => // Do Nothing
+ case w: IntWidth if w.width <= 0 =>
+ errors append new NegWidthException(info, mname)
+ case w => // Do nothing
}
w
}
- def checkHighFormT(t: Type): Type = {
- t map (checkHighFormT) match {
- case t: VectorType =>
- if (t.size < 0) errors.append(new NegVecSizeException)
+
+ def checkHighFormT(info: Info, mname: String)(t: Type): Type = {
+ t match {
+ case t: VectorType if t.size < 0 =>
+ errors append new NegVecSizeException(info, mname)
case _ => // Do nothing
}
- t map (checkHighFormW)
+ t map checkHighFormW(info, mname) map checkHighFormT(info, mname)
}
- def checkHighFormM(m: DefModule): DefModule = {
- val names = HashMap[String, Boolean]()
- val mnames = HashMap[String, Boolean]()
- def checkHighFormE(e: Expression): Expression = {
- def validSubexp(e: Expression): Expression = {
- e match {
- case (_:WRef|_:WSubField|_:WSubIndex|_:WSubAccess|_:Mux|_:ValidIf) => {} // No error
- case _ => errors.append(new InvalidAccessException)
- }
- e
- }
- e map (checkHighFormE) match {
- case e: WRef =>
- if (!names.contains(e.name)) errors.append(new UndeclaredReferenceException(e.name))
- case e: DoPrim => checkHighFormPrimop(e)
- case (_:Mux|_:ValidIf) => {}
- case e: WSubAccess => {
- validSubexp(e.exp)
- e
- }
- case e: UIntLiteral =>
- if (e.value < 0) errors.append(new NegUIntException)
- case e => e map (validSubexp)
- }
- e map (checkHighFormW)
- e map (checkHighFormT)
- e
+ def validSubexp(info: Info, mname: String)(e: Expression): Expression = {
+ e match {
+ case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess | _: Mux | _: ValidIf => // No error
+ case _ => errors append new InvalidAccessException(info, mname)
}
- def checkHighFormS(s: Statement): Statement = {
- def checkName(name: String): String = {
- if (names.contains(name)) errors.append(new NotUniqueException(name))
- else names(name) = true
- name
- }
- sinfo = get_info(s)
-
- s map (checkName)
- s map (checkHighFormT)
- s map (checkHighFormE)
- s match {
- case s: DefMemory => {
- if (hasFlip(s.dataType)) errors.append(new MemWithFlipException(s.name))
- if (s.depth <= 0) errors.append(new NegMemSizeException)
- }
- case s: WDefInstance => {
- if (!c.modules.map(_.name).contains(s.module))
- errors.append(new ModuleNotDefinedException(s.module))
- // Check to see if a recursive module instantiation has occured
- val childToParent = moduleGraph.add(m.name, s.module)
- if(childToParent.nonEmpty) {
- errors.append(new InstanceLoop(childToParent.mkString("->")))
- }
- }
- case s: Connect => checkValidLoc(s.loc)
- case s: PartialConnect => checkValidLoc(s.loc)
- case s: Print => checkFstring(s.string, s.args.length)
- case _ => // Do Nothing
- }
+ e
+ }
- s map (checkHighFormS)
- }
+ def checkHighFormE(info: Info, mname: String, names: HashSet[String])(e: Expression): Expression = {
+ e match {
+ case e: WRef if !names(e.name) =>
+ errors append new UndeclaredReferenceException(info, mname, e.name)
+ case e: UIntLiteral if e.value < 0 =>
+ errors append new NegUIntException(info, mname)
+ case e: DoPrim => checkHighFormPrimop(info, mname, e)
+ case _: WRef | _: UIntLiteral | _: Mux | _: ValidIf =>
+ case e: WSubAccess => validSubexp(info, mname)(e.exp)
+ case e => e map validSubexp(info, mname)
+ }
+ (e map checkHighFormW(info, mname)
+ map checkHighFormT(info, mname)
+ map checkHighFormE(info, mname, names))
+ }
- mname = m.name
- for (m <- c.modules) {
- mnames(m.name) = true
- }
- for (p <- m.ports) {
- // FIXME should we set sinfo here?
- names(p.name) = true
- val tpe = p.tpe
- tpe map (checkHighFormT)
- tpe map (checkHighFormW)
- }
+ def checkName(info: Info, mname: String, names: HashSet[String])(name: String): String = {
+ if (names(name))
+ errors append new NotUniqueException(info, mname, name)
+ names += name
+ name
+ }
- m match {
- case m: Module => checkHighFormS(m.body)
- case m: ExtModule => // Do Nothing
- }
- m
+ def checkHighFormS(minfo: Info, mname: String, names: HashSet[String])(s: Statement): Statement = {
+ val info = get_info(s) match {case NoInfo => minfo case x => x}
+ (s map checkName(info, mname, names)) match {
+ case s: DefMemory =>
+ if (hasFlip(s.dataType))
+ errors append new MemWithFlipException(info, mname, s.name)
+ if (s.depth <= 0)
+ errors append new NegMemSizeException(info, mname)
+ case s: WDefInstance =>
+ if (!moduleNames(s.module))
+ errors append new ModuleNotDefinedException(info, mname, s.module)
+ // Check to see if a recursive module instantiation has occured
+ val childToParent = moduleGraph add (mname, s.module)
+ if (childToParent.nonEmpty)
+ errors append new InstanceLoop(info, mname, childToParent mkString "->")
+ case s: Connect => checkValidLoc(info, mname, s.loc)
+ case s: PartialConnect => checkValidLoc(info, mname, s.loc)
+ case s: Print => checkFstring(info, mname, s.string, s.args.length)
+ case s => // Do Nothing
+ }
+ (s map checkHighFormT(info, mname)
+ map checkHighFormE(info, mname, names)
+ map checkHighFormS(minfo, mname, names))
+ }
+
+ def checkHighFormP(mname: String, names: HashSet[String])(p: Port): Port = {
+ names += p.name
+ (p.tpe map checkHighFormT(p.info, mname)
+ map checkHighFormW(p.info, mname))
+ p
+ }
+
+ def checkHighFormM(m: DefModule) {
+ val names = HashSet[String]()
+ (m map checkHighFormP(m.name, names)
+ map checkHighFormS(m.info, m.name, names))
}
- var numTopM = 0
- for (m <- c.modules) {
- if (m.name == c.main) numTopM = numTopM + 1
- checkHighFormM(m)
+ c.modules foreach checkHighFormM
+ (c.modules filter (_.name == c.main)).size match {
+ case 1 =>
+ case _ => errors append new NoTopModuleException(c.info, c.main)
}
- sinfo = c.info
- if (numTopM != 1) errors.append(new NoTopModuleException(c.main))
errors.trigger
c
}
}
object CheckTypes extends Pass with LazyLogging {
- def name = "Check Types"
- var mname = ""
+ def name = "Check Types"
// Custom Exceptions
- class SubfieldNotInBundle(info:Info, name:String) extends PassException(s"${info}: [module ${mname} ] Subfield ${name} is not in bundle.")
- class SubfieldOnNonBundle(info:Info, name:String) extends PassException(s"${info}: [module ${mname}] Subfield ${name} is accessed on a non-bundle.")
- class IndexTooLarge(info:Info, value:Int) extends PassException(s"${info}: [module ${mname}] Index with value ${value} is too large.")
- class IndexOnNonVector(info:Info) extends PassException(s"${info}: [module ${mname}] Index illegal on non-vector type.")
- class AccessIndexNotUInt(info:Info) extends PassException(s"${info}: [module ${mname}] Access index must be a UInt type.")
- class IndexNotUInt(info:Info) extends PassException(s"${info}: [module ${mname}] Index is not of UIntType.")
- class EnableNotUInt(info:Info) extends PassException(s"${info}: [module ${mname}] Enable is not of UIntType.")
- class InvalidConnect(info:Info, lhs:String, rhs:String) extends PassException(s"${info}: [module ${mname}] Type mismatch. Cannot connect ${lhs} to ${rhs}.")
- class InvalidRegInit(info:Info) extends PassException(s"${info}: [module ${mname}] Type of init must match type of DefRegister.")
- class PrintfArgNotGround(info:Info) extends PassException(s"${info}: [module ${mname}] Printf arguments must be either UIntType or SIntType.")
- class ReqClk(info:Info) extends PassException(s"${info}: [module ${mname}] Requires a clock typed signal.")
- class EnNotUInt(info:Info) extends PassException(s"${info}: [module ${mname}] Enable must be a UIntType typed signal.")
- class PredNotUInt(info:Info) extends PassException(s"${info}: [module ${mname}] Predicate not a UIntType.")
- class OpNotGround(info:Info, op:String) extends PassException(s"${info}: [module ${mname}] Primop ${op} cannot operate on non-ground types.")
- class OpNotUInt(info:Info, op:String,e:String) extends PassException(s"${info}: [module ${mname}] Primop ${op} requires argument ${e} to be a UInt type.")
- class OpNotAllUInt(info:Info, op:String) extends PassException(s"${info}: [module ${mname}] Primop ${op} requires all arguments to be UInt type.")
- class OpNotAllSameType(info:Info, op:String) extends PassException(s"${info}: [module ${mname}] Primop ${op} requires all operands to have the same type.")
- class NodePassiveType(info:Info) extends PassException(s"${info}: [module ${mname}] Node must be a passive type.")
- class MuxSameType(info:Info) extends PassException(s"${info}: [module ${mname}] Must mux between equivalent types.")
- class MuxPassiveTypes(info:Info) extends PassException(s"${info}: [module ${mname}] Must mux between passive types.")
- class MuxCondUInt(info:Info) extends PassException(s"${info}: [module ${mname}] A mux condition must be of type UInt.")
- class ValidIfPassiveTypes(info:Info) extends PassException(s"${info}: [module ${mname}] Must validif a passive type.")
- class ValidIfCondUInt(info:Info) extends PassException(s"${info}: [module ${mname}] A validif condition must be of type UInt.")
- //;---------------- Helper Functions --------------
- def ut () : UIntType = UIntType(UnknownWidth)
- def st () : SIntType = SIntType(UnknownWidth)
+ class SubfieldNotInBundle(info: Info, mname: String, name: String) extends PassException(
+ s"${info}: [module ${mname} ] Subfield ${name} is not in bundle.")
+ class SubfieldOnNonBundle(info: Info, mname: String, name: String) extends PassException(
+ s"${info}: [module ${mname}] Subfield ${name} is accessed on a non-bundle.")
+ class IndexTooLarge(info: Info, mname: String, value: Int) extends PassException(
+ s"${info}: [module ${mname}] Index with value ${value} is too large.")
+ class IndexOnNonVector(info: Info, mname: String) extends PassException(
+ s"${info}: [module ${mname}] Index illegal on non-vector type.")
+ class AccessIndexNotUInt(info: Info, mname: String) extends PassException(
+ s"${info}: [module ${mname}] Access index must be a UInt type.")
+ class IndexNotUInt(info: Info, mname: String) extends PassException(
+ s"${info}: [module ${mname}] Index is not of UIntType.")
+ class EnableNotUInt(info: Info, mname: String) extends PassException(
+ s"${info}: [module ${mname}] Enable is not of UIntType.")
+ class InvalidConnect(info: Info, mname: String, lhs: String, rhs: String) extends PassException(
+ s"${info}: [module ${mname}] Type mismatch. Cannot connect ${lhs} to ${rhs}.")
+ class InvalidRegInit(info: Info, mname: String) extends PassException(
+ s"${info}: [module ${mname}] Type of init must match type of DefRegister.")
+ class PrintfArgNotGround(info: Info, mname: String) extends PassException(
+ s"${info}: [module ${mname}] Printf arguments must be either UIntType or SIntType.")
+ class ReqClk(info: Info, mname: String) extends PassException(
+ s"${info}: [module ${mname}] Requires a clock typed signal.")
+ class EnNotUInt(info: Info, mname: String) extends PassException(
+ s"${info}: [module ${mname}] Enable must be a UIntType typed signal.")
+ class PredNotUInt(info: Info, mname: String) extends PassException(
+ s"${info}: [module ${mname}] Predicate not a UIntType.")
+ class OpNotGround(info: Info, mname: String, op: String) extends PassException(
+ s"${info}: [module ${mname}] Primop ${op} cannot operate on non-ground types.")
+ class OpNotUInt(info: Info, mname: String, op: String, e: String) extends PassException(
+ s"${info}: [module ${mname}] Primop ${op} requires argument ${e} to be a UInt type.")
+ class OpNotAllUInt(info: Info, mname: String, op: String) extends PassException(
+ s"${info}: [module ${mname}] Primop ${op} requires all arguments to be UInt type.")
+ class OpNotAllSameType(info: Info, mname: String, op: String) extends PassException(
+ s"${info}: [module ${mname}] Primop ${op} requires all operands to have the same type.")
+ class NodePassiveType(info: Info, mname: String) extends PassException(
+ s"${info}: [module ${mname}] Node must be a passive type.")
+ class MuxSameType(info: Info, mname: String) extends PassException(
+ s"${info}: [module ${mname}] Must mux between equivalent types.")
+ class MuxPassiveTypes(info: Info, mname: String) extends PassException(
+ s"${info}: [module ${mname}] Must mux between passive types.")
+ class MuxCondUInt(info: Info, mname: String) extends PassException(
+ s"${info}: [module ${mname}] A mux condition must be of type UInt.")
+ class ValidIfPassiveTypes(info: Info, mname: String) extends PassException(
+ s"${info}: [module ${mname}] Must validif a passive type.")
+ class ValidIfCondUInt(info: Info, mname: String) extends PassException(
+ s"${info}: [module ${mname}] A validif condition must be of type UInt.")
+
+ //;---------------- Helper Functions --------------
+ def ut: UIntType = UIntType(UnknownWidth)
+ def st: SIntType = SIntType(UnknownWidth)
- def check_types_primop (e:DoPrim, errors:Errors, info:Info) : Unit = {
- def all_same_type (ls:Seq[Expression]) : Unit = {
- var error = false
- for (x <- ls) {
- if (wt(ls.head.tpe) != wt(x.tpe)) error = true
- }
- if (error) errors.append(new OpNotAllSameType(info,e.op.serialize))
- }
- def all_ground (ls:Seq[Expression]) : Unit = {
- var error = false
- for (x <- ls ) {
- x.tpe match {
- case _: UIntType | _: SIntType =>
- case _ => error = true
- }
- }
- if (error) errors.append(new OpNotGround(info,e.op.serialize))
- }
- def all_uint (ls:Seq[Expression]) : Unit = {
- var error = false
- for (x <- ls ) {
- x.tpe match {
- case _: UIntType =>
- case _ => error = true
- }
- }
- if (error) errors.append(new OpNotAllUInt(info,e.op.serialize))
- }
- def is_uint (x:Expression) : Unit = {
- var error = false
- x.tpe match {
- case _: UIntType =>
- case _ => error = true
- }
- if (error) errors.append(new OpNotUInt(info,e.op.serialize,x.serialize))
+ def run (c:Circuit) : Circuit = {
+ val errors = new Errors()
+
+ def passive(t: Type): Boolean = t match {
+ case (_: UIntType |_: SIntType) => true
+ case (t: VectorType) => passive(t.tpe)
+ case (t: BundleType) => t.fields forall (x => x.flip == Default && passive(x.tpe))
+ case (t) => true
+ }
+
+ def check_types_primop(info: Info, mname: String, e: DoPrim) {
+ def all_same_type (ls:Seq[Expression]) {
+ if (ls exists (x => wt(ls.head.tpe) != wt(e.tpe)))
+ errors append new OpNotAllSameType(info, mname, e.op.serialize)
+ }
+ def all_ground (ls: Seq[Expression]) {
+ if (ls exists (x => x.tpe match {
+ case _: UIntType | _: SIntType => false
+ case _ => true
+ })) errors append new OpNotGround(info, mname, e.op.serialize)
+ }
+ def all_uint (ls: Seq[Expression]) {
+ if (ls exists (x => x.tpe match {
+ case _: UIntType => false
+ case _ => true
+ })) errors append new OpNotAllUInt(info, mname, e.op.serialize)
+ }
+ def is_uint (x:Expression) {
+ if (x.tpe match {
+ case _: UIntType => false
+ case _ => true
+ }) errors append new OpNotUInt(info, mname, e.op.serialize, x.serialize)
}
-
+
e.op match {
- case AsUInt =>
- case AsSInt =>
- case AsClock =>
- case Dshl => is_uint(e.args(1)); all_ground(e.args)
- case Dshr => is_uint(e.args(1)); all_ground(e.args)
- case Add => all_ground(e.args)
- case Sub => all_ground(e.args)
- case Mul => all_ground(e.args)
- case Div => all_ground(e.args)
- case Rem => all_ground(e.args)
- case Lt => all_ground(e.args)
- case Leq => all_ground(e.args)
- case Gt => all_ground(e.args)
- case Geq => all_ground(e.args)
- case Eq => all_ground(e.args)
- case Neq => all_ground(e.args)
- case Pad => all_ground(e.args)
- case Shl => all_ground(e.args)
- case Shr => all_ground(e.args)
- case Cvt => all_ground(e.args)
- case Neg => all_ground(e.args)
- case Not => all_ground(e.args)
- case And => all_ground(e.args)
- case Or => all_ground(e.args)
- case Xor => all_ground(e.args)
- case Andr => all_ground(e.args)
- case Orr => all_ground(e.args)
- case Xorr => all_ground(e.args)
- case Cat => all_ground(e.args)
- case Bits => all_ground(e.args)
- case Head => all_ground(e.args)
- case Tail => all_ground(e.args)
- }
- }
-
- def run (c:Circuit) : Circuit = {
- val errors = new Errors()
- def passive (t:Type) : Boolean = {
- (t) match {
- case (_:UIntType|_:SIntType) => true
- case (t:VectorType) => passive(t.tpe)
- case (t:BundleType) => {
- var p = true
- for (x <- t.fields ) {
- if (x.flip == Flip) p = false
- if (!passive(x.tpe)) p = false
- }
- p
- }
- case (t) => true
- }
- }
- def check_types_e (info:Info)(e:Expression) : Expression = {
- (e map (check_types_e(info))) match {
- case (e:WRef) => e
- case (e:WSubField) => {
- (e.exp.tpe) match {
- case (t:BundleType) => {
- val ft = t.fields.find(p => p.name == e.name)
- if (ft == None) errors.append(new SubfieldNotInBundle(info,e.name))
- }
- case (t) => errors.append(new SubfieldOnNonBundle(info,e.name))
- }
- }
- case (e:WSubIndex) => {
- (e.exp.tpe) match {
- case (t:VectorType) => {
- if (e.value >= t.size) errors.append(new IndexTooLarge(info,e.value))
- }
- case (t) => errors.append(new IndexOnNonVector(info))
- }
- }
- case (e:WSubAccess) => {
- (e.exp.tpe) match {
- case (t:VectorType) => false
- case (t) => errors.append(new IndexOnNonVector(info))
- }
- (e.index.tpe) match {
- case (t:UIntType) => false
- case (t) => errors.append(new AccessIndexNotUInt(info))
- }
- }
- case (e:DoPrim) => check_types_primop(e,errors,info)
- case (e:Mux) => {
- if (wt(e.tval.tpe) != wt(e.fval.tpe)) errors.append(new MuxSameType(info))
- if (!passive(e.tpe)) errors.append(new MuxPassiveTypes(info))
- e.cond.tpe match {
- case _: UIntType =>
- case _ => errors.append(new MuxCondUInt(info))
- }
- }
- case (e:ValidIf) => {
- if (!passive(e.tpe)) errors.append(new ValidIfPassiveTypes(info))
- e.cond.tpe match {
- case _: UIntType =>
- case _ => errors.append(new ValidIfCondUInt(info))
- }
- }
- case (_:UIntLiteral | _:SIntLiteral) => false
- }
- e
+ case AsUInt | AsSInt | AsClock =>
+ case Dshl => is_uint(e.args(1)); all_ground(e.args)
+ case Dshr => is_uint(e.args(1)); all_ground(e.args)
+ case _ => all_ground(e.args)
}
-
- def bulk_equals (t1: Type, t2: Type, flip1: Orientation, flip2: Orientation): Boolean = {
- //;println_all(["Inside with t1:" t1 ",t2:" t2 ",f1:" flip1 ",f2:" flip2])
- (t1,t2) match {
- case (ClockType, ClockType) => flip1 == flip2
- case (t1:UIntType,t2:UIntType) => flip1 == flip2
- case (t1:SIntType,t2:SIntType) => flip1 == flip2
- case (t1:BundleType,t2:BundleType) => {
- var isEqual = true
- for (i <- 0 until t1.fields.size) {
- for (j <- 0 until t2.fields.size) {
- val f1 = t1.fields(i)
- val f2 = t2.fields(j)
- if (f1.name == f2.name) {
- val field_equal = bulk_equals(f1.tpe,f2.tpe,times(flip1, f1.flip),times(flip2, f2.flip))
- if (!field_equal) isEqual = false
- }
- }
- }
- isEqual
- }
- case (t1:VectorType,t2:VectorType) => bulk_equals(t1.tpe,t2.tpe,flip1,flip2)
- }
+ }
+
+ def check_types_e(info:Info, mname: String)(e: Expression): Expression = {
+ e match {
+ case (e: WSubField) => e.exp.tpe match {
+ case (t: BundleType) => t.fields find (_.name == e.name) match {
+ case Some(_) =>
+ case None => errors append new SubfieldNotInBundle(info, mname, e.name)
+ }
+ case _ => errors append new SubfieldOnNonBundle(info, mname, e.name)
+ }
+ case (e: WSubIndex) => e.exp.tpe match {
+ case (t: VectorType) if e.value < t.size =>
+ case (t: VectorType) =>
+ errors append (new IndexTooLarge(info, mname, e.value))
+ case _ =>
+ errors append (new IndexOnNonVector(info, mname))
+ }
+ case (e: WSubAccess) =>
+ e.exp.tpe match {
+ case _: VectorType =>
+ case _ => errors append new IndexOnNonVector(info, mname)
+ }
+ e.index.tpe match {
+ case _: UIntType =>
+ case _ => errors append new AccessIndexNotUInt(info, mname)
+ }
+ case (e: DoPrim) => check_types_primop(info, mname, e)
+ case (e: Mux) =>
+ if (wt(e.tval.tpe) != wt(e.fval.tpe))
+ errors append new MuxSameType(info, mname)
+ if (!passive(e.tpe))
+ errors append new MuxPassiveTypes(info, mname)
+ e.cond.tpe match {
+ case _: UIntType =>
+ case _ => errors append new MuxCondUInt(info, mname)
+ }
+ case (e: ValidIf) =>
+ if (!passive(e.tpe))
+ errors append new ValidIfPassiveTypes(info, mname)
+ e.cond.tpe match {
+ case _: UIntType =>
+ case _ => errors append new ValidIfCondUInt(info, mname)
+ }
+ case _ =>
}
+ e map check_types_e(info, mname)
+ }
- def check_types_s (s:Statement) : Statement = {
- s map (check_types_e(get_info(s))) match {
- case (s:Connect) => if (wt(s.loc.tpe) != wt(s.expr.tpe)) errors.append(new InvalidConnect(s.info, s.loc.serialize, s.expr.serialize))
- case (s:DefRegister) => if (wt(s.tpe) != wt(s.init.tpe)) errors.append(new InvalidRegInit(s.info))
- case (s:PartialConnect) => if (!bulk_equals(s.loc.tpe,s.expr.tpe,Default,Default) ) errors.append(new InvalidConnect(s.info, s.loc.serialize, s.expr.serialize))
- case (s:Stop) => {
- if (wt(s.clk.tpe) != wt(ClockType) ) errors.append(new ReqClk(s.info))
- if (wt(s.en.tpe) != wt(ut()) ) errors.append(new EnNotUInt(s.info))
- }
- case (s:Print)=> {
- for (x <- s.args ) {
- if (wt(x.tpe) != wt(ut()) && wt(x.tpe) != wt(st()) ) errors.append(new PrintfArgNotGround(s.info))
- }
- if (wt(s.clk.tpe) != wt(ClockType) ) errors.append(new ReqClk(s.info))
- if (wt(s.en.tpe) != wt(ut()) ) errors.append(new EnNotUInt(s.info))
+ def bulk_equals(t1: Type, t2: Type, flip1: Orientation, flip2: Orientation): Boolean = {
+ //;println_all(["Inside with t1:" t1 ",t2:" t2 ",f1:" flip1 ",f2:" flip2])
+ (t1, t2) match {
+ case (ClockType, ClockType) => flip1 == flip2
+ case (_: UIntType, _: UIntType) => flip1 == flip2
+ case (_: SIntType, _: SIntType) => flip1 == flip2
+ case (t1: BundleType, t2: BundleType) =>
+ val t1_fields = (t1.fields foldLeft Map[String, (Type, Orientation)]())(
+ (map, f1) => map + (f1.name -> (f1.tpe, f1.flip)))
+ t2.fields forall (f2 =>
+ t1_fields get f2.name match {
+ case None => true
+ case Some((f1_tpe, f1_flip)) =>
+ bulk_equals(f1_tpe, f2.tpe, times(flip1, f1_flip), times(flip2, f2.flip))
}
- case (s:Conditionally) => if (wt(s.pred.tpe) != wt(ut()) ) errors.append(new PredNotUInt(s.info))
- case (s:DefNode) => if (!passive(s.value.tpe) ) errors.append(new NodePassiveType(s.info))
- case (s) => false
- }
- s map (check_types_s)
+ )
+ case (t1: VectorType, t2: VectorType) =>
+ bulk_equals(t1.tpe, t2.tpe, flip1, flip2)
}
-
- for (m <- c.modules ) {
- mname = m.name
- (m) match {
- case (m:ExtModule) => false
- case (m:Module) => check_types_s(m.body)
- }
- }
- errors.trigger
- c
- }
+ }
+
+ def check_types_s(minfo: Info, mname: String)(s: Statement): Statement = {
+ val info = get_info(s) match { case NoInfo => minfo case x => x }
+ s match {
+ case (s: Connect) if wt(s.loc.tpe) != wt(s.expr.tpe) =>
+ errors append new InvalidConnect(info, mname, s.loc.serialize, s.expr.serialize)
+ case (s:PartialConnect) if !bulk_equals(s.loc.tpe, s.expr.tpe, Default, Default) =>
+ errors append new InvalidConnect(info, mname, s.loc.serialize, s.expr.serialize)
+ case (s: DefRegister) if wt(s.tpe) != wt(s.init.tpe) =>
+ errors append new InvalidRegInit(info, mname)
+ case (s: Conditionally) if wt(s.pred.tpe) != wt(ut) =>
+ errors append new PredNotUInt(info, mname)
+ case (s: DefNode) if !passive(s.value.tpe) =>
+ errors append new NodePassiveType(info, mname)
+ case (s: Stop) =>
+ if (wt(s.clk.tpe) != wt(ClockType)) errors append new ReqClk(info, mname)
+ if (wt(s.en.tpe) != wt(ut)) errors append new EnNotUInt(info, mname)
+ case (s: Print) =>
+ if (s.args exists (x => wt(x.tpe) != wt(ut) && wt(x.tpe) != wt(st)))
+ errors append new PrintfArgNotGround(info, mname)
+ if (wt(s.clk.tpe) != wt(ClockType)) errors append new ReqClk(info, mname)
+ if (wt(s.en.tpe) != wt(ut)) errors append new EnNotUInt(info, mname)
+ case _ =>
+ }
+ s map check_types_e(info, mname) map check_types_s(info, mname)
+ }
+
+ c.modules foreach (m => m map check_types_s(m.info, m.name))
+ errors.trigger
+ c
+ }
}
object CheckGenders extends Pass {
- def name = "Check Genders"
- var mname = ""
- class WrongGender (info:Info,expr:String,wrong:String,right:String) extends PassException(s"${info}: [module ${mname}] Expression ${expr} is used as a ${wrong} but can only be used as a ${right}.")
+ def name = "Check Genders"
+
+ implicit def toStr(g: Gender): String = g match {
+ case MALE => "source"
+ case FEMALE => "sink"
+ case UNKNOWNGENDER => "unknown"
+ case BIGENDER => "sourceOrSink"
+ }
- def dir_to_gender (d:Direction) : Gender = {
- d match {
- case Input => MALE
- case Output => FEMALE //BI-GENDER
- }
- }
+ class WrongGender(info:Info, mname: String, expr: String, wrong: Gender, right: Gender) extends PassException(
+ s"${info}: [module ${mname}] Expression ${expr} is used as a ${wrong} but can only be used as a ${right}.")
- def as_srcsnk (g:Gender) : String = {
- g match {
- case MALE => "source"
- case FEMALE => "sink"
- case UNKNOWNGENDER => "unknown"
- case BIGENDER => "sourceOrSink"
+ def run (c:Circuit): Circuit = {
+ val errors = new Errors()
+
+ def get_gender(e: Expression, genders: HashMap[String, Gender]): Gender = e match {
+ case (e: WRef) => genders(e.name)
+ case (e: WSubIndex) => get_gender(e.exp, genders)
+ case (e: WSubAccess) => get_gender(e.exp, genders)
+ case (e: WSubField) => e.exp.tpe match {case t: BundleType =>
+ val f = (t.fields find (_.name == e.name)).get
+ times(get_gender(e.exp, genders), f.flip)
+ }
+ case _ => MALE
+ }
+
+ def flip_q(t: Type): Boolean = {
+ def flip_rec(t: Type, f: Orientation): Boolean = t match {
+ case (t:BundleType) => t.fields exists (
+ field => flip_rec(field.tpe, times(f, field.flip))
+ )
+ case t: VectorType => flip_rec(t.tpe, f)
+ case t => f == Flip
+ }
+ flip_rec(t, Default)
+ }
+
+ def check_gender(info:Info, mname: String,
+ genders: HashMap[String,Gender], desired: Gender)(e:Expression): Expression = {
+ val gender = get_gender(e,genders)
+ (gender, desired) match {
+ case (MALE, FEMALE) =>
+ errors append new WrongGender(info, mname, e.serialize, desired, gender)
+ case (FEMALE, MALE) => kind(e) match {
+ case _: PortKind | _: InstanceKind if !flip_q(e.tpe) => // OK!
+ case _ =>
+ errors append new WrongGender(info, mname, e.serialize, desired, gender)
+ }
+ case _ =>
}
+ e
}
- def run (c:Circuit): Circuit = {
- val errors = new Errors()
- def get_kind (e:Expression) : Kind = {
- (e) match {
- case (e:WRef) => e.kind
- case (e:WSubField) => get_kind(e.exp)
- case (e:WSubIndex) => get_kind(e.exp)
- case (e:WSubAccess) => get_kind(e.exp)
- case (e) => NodeKind()
- }
- }
-
- def check_gender (info:Info,genders:HashMap[String,Gender],desired:Gender)(e:Expression) : Expression = {
- val gender = get_gender(e,genders)
- val kindx = get_kind(e)
- def flipQ (t:Type) : Boolean = {
- var fQ = false
- def flip_rec (t:Type,f:Orientation) : Type = {
- (t) match {
- case (t:BundleType) => {
- for (field <- t.fields) {
- flip_rec(field.tpe,times(f, field.flip))
- }
- }
- case (t:VectorType) => flip_rec(t.tpe,f)
- case (t) => if (f == Flip) fQ = true
- }
- t
- }
- flip_rec(t,Default)
- fQ
- }
-
- val has_flipQ = flipQ(e.tpe)
- //println(e)
- //println(gender)
- //println(desired)
- //println(kindx)
- //println(desired == gender)
- //if gender != desired and gender != BI-GENDER:
- (gender,desired) match {
- case (MALE, FEMALE) => errors.append(new WrongGender(info,e.serialize,as_srcsnk(desired),as_srcsnk(gender)))
- case (FEMALE, MALE) =>
- if ((kindx == PortKind() || kindx == InstanceKind()) && has_flipQ == false) {
- //; OK!
- false
- } else {
- //; Not Ok!
- errors.append(new WrongGender(info,e.serialize,as_srcsnk(desired),as_srcsnk(gender)))
- }
- case _ => false
- }
- e
- }
-
- def get_gender (e:Expression,genders:HashMap[String,Gender]) : Gender = {
- (e) match {
- case (e:WRef) => genders(e.name)
- case (e:WSubField) =>
- val f = e.exp.tpe.asInstanceOf[BundleType].fields.find(f => f.name == e.name).get
- times(get_gender(e.exp,genders),f.flip)
- case (e:WSubIndex) => get_gender(e.exp,genders)
- case (e:WSubAccess) => get_gender(e.exp,genders)
- case (e:DoPrim) => MALE
- case (e:UIntLiteral) => MALE
- case (e:SIntLiteral) => MALE
- case (e:Mux) => MALE
- case (e:ValidIf) => MALE
- }
- }
-
- def check_genders_e (info:Info,genders:HashMap[String,Gender])(e:Expression) : Expression = {
- e map (check_genders_e(info,genders))
- (e) match {
- case (e:WRef) => false
- case (e:WSubField) => false
- case (e:WSubIndex) => false
- case (e:WSubAccess) => false
- case (e:DoPrim) => for (e <- e.args ) { check_gender(info,genders,MALE)(e) }
- case (e:Mux) => e map (check_gender(info,genders,MALE))
- case (e:ValidIf) => e map (check_gender(info,genders,MALE))
- case (e:UIntLiteral) => false
- case (e:SIntLiteral) => false
- }
- e
+ def check_genders_e (info:Info, mname: String, genders: HashMap[String,Gender])(e:Expression): Expression = {
+ e match {
+ case e: Mux => e map check_gender(info, mname, genders, MALE)
+ case e: DoPrim => e.args map check_gender(info, mname, genders, MALE)
+ case _ =>
}
+ e map check_genders_e(info, mname, genders)
+ }
- def check_genders_s (genders:HashMap[String,Gender])(s:Statement) : Statement = {
- s map (check_genders_e(get_info(s),genders))
- s map (check_genders_s(genders))
- (s) match {
- case (s:DefWire) => genders(s.name) = BIGENDER
- case (s:DefRegister) => genders(s.name) = BIGENDER
- case (s:DefNode) => {
- check_gender(s.info,genders,MALE)(s.value)
- genders(s.name) = MALE
- }
- case (s:DefMemory) => genders(s.name) = MALE
- case (s:WDefInstance) => genders(s.name) = MALE
- case (s:Connect) => {
- check_gender(s.info,genders,FEMALE)(s.loc)
- check_gender(s.info,genders,MALE)(s.expr)
- }
- case (s:Print) => {
- for (x <- s.args ) {
- check_gender(s.info,genders,MALE)(x)
- }
- check_gender(s.info,genders,MALE)(s.en)
- check_gender(s.info,genders,MALE)(s.clk)
- }
- case (s:PartialConnect) => {
- check_gender(s.info,genders,FEMALE)(s.loc)
- check_gender(s.info,genders,MALE)(s.expr)
- }
- case (s:Conditionally) => {
- check_gender(s.info,genders,MALE)(s.pred)
- }
- case EmptyStmt => false
- case (s:Stop) => {
- check_gender(s.info,genders,MALE)(s.en)
- check_gender(s.info,genders,MALE)(s.clk)
- }
- case (_:Block | _:IsInvalid) => false
- }
- s
- }
+ def check_genders_s(minfo: Info, mname: String, genders: HashMap[String,Gender])(s: Statement): Statement = {
+ val info = get_info(s) match { case NoInfo => minfo case x => x }
+ s match {
+ case (s: DefWire) => genders(s.name) = BIGENDER
+ case (s: DefRegister) => genders(s.name) = BIGENDER
+ case (s: DefMemory) => genders(s.name) = MALE
+ case (s: WDefInstance) => genders(s.name) = MALE
+ case (s: DefNode) =>
+ check_gender(info, mname, genders, MALE)(s.value)
+ genders(s.name) = MALE
+ case (s: Connect) =>
+ check_gender(info, mname, genders, FEMALE)(s.loc)
+ check_gender(info, mname, genders, MALE)(s.expr)
+ case (s: Print) =>
+ s.args map check_gender(info, mname, genders, MALE)
+ check_gender(info, mname, genders, MALE)(s.en)
+ check_gender(info, mname, genders, MALE)(s.clk)
+ case (s: PartialConnect) =>
+ check_gender(info, mname, genders, FEMALE)(s.loc)
+ check_gender(info, mname, genders, MALE)(s.expr)
+ case (s: Conditionally) =>
+ check_gender(info, mname, genders, MALE)(s.pred)
+ case (s: Stop) =>
+ check_gender(info, mname, genders, MALE)(s.en)
+ check_gender(info, mname, genders, MALE)(s.clk)
+ case _ =>
+ }
+ s map check_genders_e(info, mname, genders) map check_genders_s(minfo, mname, genders)
+ }
- for (m <- c.modules ) {
- mname = m.name
- val genders = HashMap[String,Gender]()
- for (p <- m.ports) {
- genders(p.name) = dir_to_gender(p.direction)
- }
- (m) match {
- case (m:ExtModule) => false
- case (m:Module) => check_genders_s(genders)(m.body)
- }
- }
- errors.trigger
- c
- }
+ for (m <- c.modules) {
+ val genders = HashMap[String, Gender]()
+ genders ++= (m.ports map (p => p.name -> to_gender(p.direction)))
+ m map check_genders_s(m.info, m.name, genders)
+ }
+ errors.trigger
+ c
+ }
}
object CheckWidths extends Pass {
- def name = "Width Check"
- var mname = ""
- class UninferredWidth (info:Info) extends PassException(s"${info} : [module ${mname}] Uninferred width.")
- class WidthTooSmall(info: Info, b: BigInt) extends PassException(
- s"$info : [module $mname] Width too small for constant " +
- serialize(b) + ".")
- class NegWidthException(info:Info) extends PassException(s"${info}: [module ${mname}] Width cannot be negative or zero.")
- class BitsWidthException(info: Info, hi: BigInt, width: BigInt) extends PassException(s"${info}: [module ${mname}] High bit $hi in bits operator is larger than input width $width.")
- class HeadWidthException(info: Info, n: BigInt, width: BigInt) extends PassException(s"${info}: [module ${mname}] Parameter $n in head operator is larger than input width $width.")
- class TailWidthException(info: Info, n: BigInt, width: BigInt) extends PassException(s"${info}: [module ${mname}] Parameter $n in tail operator is larger than input width $width.")
- def run (c:Circuit): Circuit = {
- val errors = new Errors()
- def check_width_m (m:DefModule) : Unit = {
- def check_width_w (info:Info)(w:Width) : Width = {
- (w) match {
- case (w:IntWidth)=> if (w.width <= 0) errors.append(new NegWidthException(info))
- case (w) => errors.append(new UninferredWidth(info))
- }
- w
- }
- def check_width_e (info:Info)(e:Expression) : Expression = {
- (e map (check_width_e(info))) match {
- case (e:UIntLiteral) => {
- (e.width) match {
- case (w:IntWidth) =>
- if (scala.math.max(1,e.value.bitLength) > w.width) {
- errors.append(new WidthTooSmall(info, e.value))
- }
- case (w) => errors.append(new UninferredWidth(info))
- }
- check_width_w(info)(e.width)
- }
- case (e:SIntLiteral) => {
- (e.width) match {
- case (w:IntWidth) =>
- if (e.value.bitLength + 1 > w.width) errors.append(new WidthTooSmall(info, e.value))
- case (w) => errors.append(new UninferredWidth(info))
- }
- check_width_w(info)(e.width)
- }
- case DoPrim(Bits, Seq(a), Seq(hi, lo), _) if(long_BANG(a.tpe) <= hi) =>
- errors.append(new BitsWidthException(info, hi, long_BANG(a.tpe)))
- case DoPrim(Head, Seq(a), Seq(n), _) if(long_BANG(a.tpe) < n) =>
- errors.append(new HeadWidthException(info, n, long_BANG(a.tpe)))
- case DoPrim(Tail, Seq(a), Seq(n), _) if(long_BANG(a.tpe) <= n) =>
- errors.append(new TailWidthException(info, n, long_BANG(a.tpe)))
- case (e:DoPrim) => false
- case (e) => false
- }
- e
- }
- def check_width_s (s:Statement) : Statement = {
- s map (check_width_s) map (check_width_e(get_info(s)))
- def tm (t:Type) : Type = mapr(check_width_w(get_info(s)) _,t)
- s map (tm)
- }
-
- for (p <- m.ports) {
- mapr(check_width_w(p.info) _,p.tpe)
- }
-
- (m) match {
- case (m:ExtModule) => {}
- case (m:Module) => check_width_s(m.body)
- }
- }
-
- for (m <- c.modules) {
- mname = m.name
- check_width_m(m)
+ def name = "Width Check"
+ class UninferredWidth (info: Info, mname: String) extends PassException(
+ s"${info} : [module ${mname}] Uninferred width.")
+ class WidthTooSmall(info: Info, mname: String, b: BigInt) extends PassException(
+ s"$info : [module $mname] Width too small for constant ${serialize(b)}.")
+ class NegWidthException(info:Info, mname: String) extends PassException(
+ s"${info}: [module ${mname}] Width cannot be negative or zero.")
+ class BitsWidthException(info: Info, mname: String, hi: BigInt, width: BigInt) extends PassException(
+ s"${info}: [module ${mname}] High bit $hi in bits operator is larger than input width $width.")
+ class HeadWidthException(info: Info, mname: String, n: BigInt, width: BigInt) extends PassException(
+ s"${info}: [module ${mname}] Parameter $n in head operator is larger than input width $width.")
+ class TailWidthException(info: Info, mname: String, n: BigInt, width: BigInt) extends PassException(
+ s"${info}: [module ${mname}] Parameter $n in tail operator is larger than input width $width.")
+
+ def run(c: Circuit): Circuit = {
+ val errors = new Errors()
+
+ def check_width_w(info: Info, mname: String)(w: Width): Width = {
+ w match {
+ case w: IntWidth if w.width > 0 =>
+ case _: IntWidth =>
+ errors append new NegWidthException(info, mname)
+ case _ =>
+ errors append new UninferredWidth(info, mname)
}
- errors.trigger
- c
- }
+ w
+ }
+
+ def check_width_e(info: Info, mname: String)(e: Expression): Expression = {
+ e match {
+ case e: UIntLiteral => e.width match {
+ case w: IntWidth if math.max(1, e.value.bitLength) > w.width =>
+ errors append new WidthTooSmall(info, mname, e.value)
+ case _ =>
+ }
+ case e: SIntLiteral => e.width match {
+ case w: IntWidth if e.value.bitLength + 1 > w.width =>
+ errors append new WidthTooSmall(info, mname, e.value)
+ case _ =>
+ }
+ case DoPrim(Bits, Seq(a), Seq(hi, lo), _) if long_BANG(a.tpe) <= hi =>
+ errors append new BitsWidthException(info, mname, hi, long_BANG(a.tpe))
+ case DoPrim(Head, Seq(a), Seq(n), _) if long_BANG(a.tpe) < n =>
+ errors append new HeadWidthException(info, mname, n, long_BANG(a.tpe))
+ case DoPrim(Tail, Seq(a), Seq(n), _) if long_BANG(a.tpe) <= n =>
+ errors append new TailWidthException(info, mname, n, long_BANG(a.tpe))
+ case _ =>
+ }
+ e map check_width_w(info, mname) map check_width_e(info, mname)
+ }
+
+ def check_width_s(minfo: Info, mname: String)(s: Statement): Statement = {
+ val info = get_info(s) match { case NoInfo => minfo case x => x }
+ s map check_width_e(info, mname) map check_width_s(info, mname)
+ }
+
+ def check_width_p(minfo: Info, mname: String)(p: Port): Port = {
+ p.tpe map check_width_w(p.info, mname)
+ p
+ }
+
+ def check_width_m(m: DefModule) {
+ m map check_width_p(m.info, m.name) map check_width_s(m.info, m.name)
+ }
+
+ c.modules foreach check_width_m
+ errors.trigger
+ c
+ }
}
diff --git a/src/test/scala/firrtlTests/CheckInitializationSpec.scala b/src/test/scala/firrtlTests/CheckInitializationSpec.scala
index 515bbfc8..e2eaf690 100644
--- a/src/test/scala/firrtlTests/CheckInitializationSpec.scala
+++ b/src/test/scala/firrtlTests/CheckInitializationSpec.scala
@@ -60,7 +60,7 @@ class CheckInitializationSpec extends FirrtlFlatSpec {
| wire x : UInt<32>
| when p :
| x <= UInt(1)""".stripMargin
- intercept[PassExceptions] {
+ intercept[CheckInitialization.RefNotInitializedException] {
passes.foldLeft(parse(input)) {
(c: Circuit, p: Pass) => p.run(c)
}
@@ -75,7 +75,7 @@ class CheckInitializationSpec extends FirrtlFlatSpec {
| when p :
| else :
| x <= UInt(1)""".stripMargin
- intercept[PassExceptions] {
+ intercept[CheckInitialization.RefNotInitializedException] {
passes.foldLeft(parse(input)) {
(c: Circuit, p: Pass) => p.run(c)
}