diff options
| author | Donggyu Kim | 2016-08-30 18:51:17 -0700 |
|---|---|---|
| committer | Donggyu Kim | 2016-09-07 16:58:06 -0700 |
| commit | d7bf6fb7b415d35f967d247119b8975c3dc885a3 (patch) | |
| tree | be6afdc75f2f209f4a412d5aafae5015da98cc2a /src | |
| parent | 296a65ebb895d100c3cbde6df7c0303d6942e5d5 (diff) | |
refactor checks
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/Utils.scala | 53 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/CheckChirrtl.scala | 189 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/CheckInitialization.scala | 38 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Checks.scala | 1168 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/CheckInitializationSpec.scala | 4 |
5 files changed, 663 insertions, 789 deletions
diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index ea8ff4b7..29c37294 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -662,3 +662,56 @@ class MemoizedHash[T](val t: T) { case _ => false } } + +/** + * Maintains a one to many graph of each modules instantiated child module. + * This graph can be searched for a path from a child module back to one of + * it's parents. If one is found a recursive loop has happened + * The graph is a map between the name of a node to set of names of that nodes children + */ +class ModuleGraph { + val nodes = HashMap[String, HashSet[String]]() + + /** + * Add a child to a parent node + * A parent node is created if it does not already exist + * + * @param parent module that instantiates another module + * @param child module instantiated by parent + * @return a list indicating a path from child to parent, empty if no such path + */ + def add(parent: String, child: String): List[String] = { + val childSet = nodes.getOrElseUpdate(parent, new HashSet[String]) + childSet += child + pathExists(child, parent, List(child, parent)) + } + + /** + * Starting at the name of a given child explore the tree of all children in depth first manner. + * Return the first path (a list of strings) that goes from child to parent, + * or an empty list of no such path is found. + * + * @param child starting name + * @param parent name to find in children (recursively) + * @param path + * @return + */ + def pathExists(child: String, parent: String, path: List[String] = Nil): List[String] = { + nodes.get(child) match { + case Some(children) => + if(children(parent)) { + parent :: path + } + else { + children.foreach { grandchild => + val newPath = pathExists(grandchild, parent, grandchild :: path) + if(newPath.nonEmpty) { + return newPath + } + } + Nil + } + case _ => Nil + } + } +} diff --git a/src/main/scala/firrtl/passes/CheckChirrtl.scala b/src/main/scala/firrtl/passes/CheckChirrtl.scala index e0e7c57a..2ab8749b 100644 --- a/src/main/scala/firrtl/passes/CheckChirrtl.scala +++ b/src/main/scala/firrtl/passes/CheckChirrtl.scala @@ -30,8 +30,7 @@ package firrtl.passes import com.typesafe.scalalogging.LazyLogging // Datastructures -import scala.collection.mutable.HashMap -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashSet import firrtl._ import firrtl.ir._ @@ -44,122 +43,120 @@ import firrtl.WrappedType._ object CheckChirrtl extends Pass with LazyLogging { def name = "Chirrtl Check" + class NotUniqueException(info: Info, mname: String, name: String) extends PassException( + s"${info}: [module ${mname}] Reference ${name} does not have a unique name.") + class InvalidLOCException(info: Info, mname: String) extends PassException( + s"${info}: [module ${mname}] Invalid connect to an expression that is not a reference or a WritePort.") + class UndeclaredReferenceException(info: Info, mname: String, name: String) extends PassException( + s"${info}: [module ${mname}] Reference ${name} is not declared.") + class MemWithFlipException(info: Info, mname: String, name: String) extends PassException( + s"${info}: [module ${mname}] Memory ${name} cannot be a bundle type with flips.") + class InvalidAccessException(info: Info, mname: String) extends PassException( + s"${info}: [module ${mname}] Invalid access to non-reference.") + class ModuleNotDefinedException(info: Info, mname: String, name: String) extends PassException( + s"${info}: Module ${name} is not defined.") + class NegWidthException(info: Info, mname: String) extends PassException( + s"${info}: [module ${mname}] Width cannot be negative or zero.") + class NegVecSizeException(info: Info, mname: String) extends PassException( + s"${info}: [module ${mname}] Vector type size cannot be negative.") + class NegMemSizeException(info: Info, mname: String) extends PassException( + s"${info}: [module ${mname}] Memory size cannot be negative or zero.") + class NoTopModuleException(info: Info, name: String) extends PassException( + s"${info}: A single module must be named ${name}.") + // TODO FIXME // - Do we need to check for uniquness on port names? def run (c: Circuit): Circuit = { - var mname: String = "" - var sinfo: Info = NoInfo - - class NotUniqueException(name: String) extends PassException(s"${sinfo}: [module ${mname}] Reference ${name} does not have a unique name.") - class InvalidLOCException extends PassException(s"${sinfo}: [module ${mname}] Invalid connect to an expression that is not a reference or a WritePort.") - class UndeclaredReferenceException(name: String) extends PassException(s"${sinfo}: [module ${mname}] Reference ${name} is not declared.") - class MemWithFlipException(name: String) extends PassException(s"${sinfo}: [module ${mname}] Memory ${name} cannot be a bundle type with flips.") - class InvalidAccessException extends PassException(s"${sinfo}: [module ${mname}] Invalid access to non-reference.") - class NoTopModuleException(name: String) extends PassException(s"${sinfo}: A single module must be named ${name}.") - class ModuleNotDefinedException(name: String) extends PassException(s"${sinfo}: Module ${name} is not defined.") - class NegWidthException extends PassException(s"${sinfo}: [module ${mname}] Width cannot be negative or zero.") - class NegVecSizeException extends PassException(s"${sinfo}: [module ${mname}] Vector type size cannot be negative.") - class NegMemSizeException extends PassException(s"${sinfo}: [module ${mname}] Memory size cannot be negative or zero.") - val errors = new Errors() - def checkValidLoc(e: Expression) = e match { - case e @ (_: UIntLiteral | _: SIntLiteral | _: DoPrim ) => errors.append(new InvalidLOCException) + val moduleNames = (c.modules map (_.name)).toSet + + def checkValidLoc(info: Info, mname: String, e: Expression) = e match { + case _: UIntLiteral | _: SIntLiteral | _: DoPrim => + errors append new InvalidLOCException(info, mname) case _ => // Do Nothing } - def checkChirrtlW(w: Width): Width = w match { - case w: IntWidth if (w.width <= BigInt(0)) => - errors.append(new NegWidthException) + + def checkChirrtlW(info: Info, mname: String)(w: Width): Width = w match { + case w: IntWidth if w.width <= 0 => + errors append new NegWidthException(info, mname) w case _ => w } - def checkChirrtlT(t: Type): Type = { - t map (checkChirrtlT) match { - case t: VectorType if (t.size < 0) => errors.append(new NegVecSizeException) + + def checkChirrtlT(info: Info, mname: String)(t: Type): Type = { + t map checkChirrtlT(info, mname) match { + case t: VectorType if t.size < 0 => + errors append new NegVecSizeException(info, mname) case _ => // Do nothing } - t map (checkChirrtlW) + t map checkChirrtlW(info, mname) map checkChirrtlT(info, mname) } - def checkChirrtlM(m: DefModule): DefModule = { - val names = HashMap[String, Boolean]() - val mnames = HashMap[String, Boolean]() - def checkChirrtlE(e: Expression): Expression = { - def validSubexp(e: Expression): Expression = e match { - case (_:Reference|_:SubField|_:SubIndex|_:SubAccess|_:Mux|_:ValidIf) => e // No error - case _ => - errors.append(new InvalidAccessException) - e - } - e map (checkChirrtlE) match { - case e: Reference if (!names.contains(e.name)) => errors.append(new UndeclaredReferenceException(e.name)) - case e: DoPrim => {} - case (_:Mux|_:ValidIf) => {} - case e: SubAccess => - validSubexp(e.expr) - e - case e: UIntLiteral => {} - case e => e map (validSubexp) - } - e map (checkChirrtlW) - e map (checkChirrtlT) - e - } - def checkChirrtlS(s: Statement): Statement = { - sinfo = get_info(s) - def checkName(name: String): String = { - if (names.contains(name)) errors.append(new NotUniqueException(name)) - else names(name) = true - name - } - - s map (checkName) - s map (checkChirrtlT) - s map (checkChirrtlE) - s match { - case s: DefMemory => - if (hasFlip(s.dataType)) errors.append(new MemWithFlipException(s.name)) - if (s.depth <= 0) errors.append(new NegMemSizeException) - case s: DefInstance => - if (!c.modules.map(_.name).contains(s.module)) - errors.append(new ModuleNotDefinedException(s.module)) - case s: Connect => checkValidLoc(s.loc) - case s: PartialConnect => checkValidLoc(s.loc) - case s: Print => {} - case _ => // Do Nothing - } - - s map (checkChirrtlS) + def validSubexp(info: Info, mname: String)(e: Expression): Expression = { + e match { + case _: Reference | _: SubField | _: SubIndex | _: SubAccess | + _: Mux | _: ValidIf => // No error + case _ => errors append new InvalidAccessException(info, mname) } + e + } - mname = m.name - for (m <- c.modules) { - mnames(m.name) = true - } - for (p <- m.ports) { - sinfo = p.info - names(p.name) = true - val tpe = p.tpe - tpe map (checkChirrtlT) - tpe map (checkChirrtlW) + def checkChirrtlE(info: Info, mname: String, names: HashSet[String])(e: Expression): Expression = { + e match { + case _: DoPrim | _:Mux | _:ValidIf | _: UIntLiteral => + case e: Reference if !names(e.name) => + errors append new UndeclaredReferenceException(info, mname, e.name) + case e: SubAccess => validSubexp(info, mname)(e.expr) + case e => e map validSubexp(info, mname) } + (e map checkChirrtlW(info, mname) + map checkChirrtlT(info, mname) + map checkChirrtlE(info, mname, names)) + } + + def checkName(info: Info, mname: String, names: HashSet[String])(name: String): String = { + if (names(name)) + errors append (new NotUniqueException(info, mname, name)) + names += name + name + } - m match { - case m: Module => checkChirrtlS(m.body) - case m: ExtModule => // Do Nothing + def checkChirrtlS(minfo: Info, mname: String, names: HashSet[String])(s: Statement): Statement = { + val info = get_info(s) match {case NoInfo => minfo case x => x} + (s map checkName(info, mname, names)) match { + case s: DefMemory => + if (hasFlip(s.dataType)) errors append new MemWithFlipException(info, mname, s.name) + if (s.depth <= 0) errors append new NegMemSizeException(info, mname) + case s: DefInstance if !moduleNames(s.module) => + errors append new ModuleNotDefinedException(info, mname, s.module) + case s: Connect => checkValidLoc(info, mname, s.loc) + case s: PartialConnect => checkValidLoc(info, mname, s.loc) + case _ => // Do Nothing } - m + (s map checkChirrtlT(info, mname) + map checkChirrtlE(info, mname, names) + map checkChirrtlS(info, mname, names)) + } + + def checkChirrtlP(mname: String, names: HashSet[String])(p: Port): Port = { + names += p.name + (p.tpe map checkChirrtlT(p.info, mname) + map checkChirrtlW(p.info, mname)) + p + } + + def checkChirrtlM(m: DefModule) { + val names = HashSet[String]() + (m map checkChirrtlP(m.name, names) + map checkChirrtlS(m.info, m.name, names)) } - var numTopM = 0 - for (m <- c.modules) { - if (m.name == c.main) numTopM = numTopM + 1 - checkChirrtlM(m) + c.modules foreach checkChirrtlM + (c.modules filter (_.name == c.main)).size match { + case 1 => + case _ => errors append new NoTopModuleException(c.info, c.main) } - sinfo = c.info - if (numTopM != 1) errors.append(new NoTopModuleException(c.main)) errors.trigger c } } - -// vim: set ts=4 sw=4 et: diff --git a/src/main/scala/firrtl/passes/CheckInitialization.scala b/src/main/scala/firrtl/passes/CheckInitialization.scala index 6d69b792..69629bf0 100644 --- a/src/main/scala/firrtl/passes/CheckInitialization.scala +++ b/src/main/scala/firrtl/passes/CheckInitialization.scala @@ -60,8 +60,7 @@ object CheckInitialization extends Pass { } def run(c: Circuit): Circuit = { - val errors = collection.mutable.ArrayBuffer[PassException]() - + val errors = new Errors() def checkInitM(m: Module): Unit = { val voidExprs = collection.mutable.HashMap[WrappedExpression, VoidExpr]() @@ -69,19 +68,17 @@ object CheckInitialization extends Pass { def hasVoidExpr(e: Expression): (Boolean, Seq[Expression]) = { var void = false val voidDeps = collection.mutable.ArrayBuffer[Expression]() - def hasVoid(e: Expression): Expression = { - e match { - case e: WVoid => + def hasVoid(e: Expression): Expression = e match { + case e: WVoid => + void = true + e + case (_: WRef | _: WSubField) => + if (voidExprs.contains(e)) { void = true - e - case (_: WRef | _: WSubField) => - if (voidExprs.contains(e)) { - void = true - voidDeps += e - } - e - case e => e map hasVoid - } + voidDeps += e + } + e + case e => e map hasVoid } hasVoid(e) (void, voidDeps) @@ -110,19 +107,16 @@ object CheckInitialization extends Pass { case node: DefNode => // Ignore nodes case decl: IsDeclaration => val trace = getTrace(expr, voidExprs.toMap) - errors += new RefNotInitializedException(decl.info, m.name, decl.name, trace) + errors append new RefNotInitializedException(decl.info, m.name, decl.name, trace) } } } - c.modules foreach { m => - m match { - case m: Module => checkInitM(m) - case m => // Do nothing - } + c.modules foreach { + case m: Module => checkInitM(m) + case m => // Do nothing } - - if (errors.nonEmpty) throw new PassExceptions(errors) + errors.trigger c } } diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala index 6e49ce93..bba3efe7 100644 --- a/src/main/scala/firrtl/passes/Checks.scala +++ b/src/main/scala/firrtl/passes/Checks.scala @@ -29,746 +29,576 @@ package firrtl.passes import com.typesafe.scalalogging.LazyLogging -// Datastructures -import scala.collection.mutable.{HashMap,HashSet} -import scala.collection.mutable.ArrayBuffer - import firrtl._ import firrtl.ir._ +import firrtl.PrimOps._ import firrtl.Utils._ import firrtl.Mappers._ -import firrtl.PrimOps._ import firrtl.WrappedType._ +// Datastructures +import scala.collection.mutable.{HashMap, HashSet} + object CheckHighForm extends Pass with LazyLogging { def name = "High Form Check" // Custom Exceptions - class NotUniqueException(name: String) extends PassException(s"${sinfo}: [module ${mname}] Reference ${name} does not have a unique name.") - class InvalidLOCException extends PassException(s"${sinfo}: [module ${mname}] Invalid connect to an expression that is not a reference or a WritePort.") - class NegUIntException extends PassException(s"${sinfo}: [module ${mname}] UIntLiteral cannot be negative.") - class UndeclaredReferenceException(name: String) extends PassException(s"${sinfo}: [module ${mname}] Reference ${name} is not declared.") - class PoisonWithFlipException(name: String) extends PassException(s"${sinfo}: [module ${mname}] Poison ${name} cannot be a bundle type with flips.") - class MemWithFlipException(name: String) extends PassException(s"${sinfo}: [module ${mname}] Memory ${name} cannot be a bundle type with flips.") - class InvalidAccessException extends PassException(s"${sinfo}: [module ${mname}] Invalid access to non-reference.") - class NoTopModuleException(name: String) extends PassException(s"${sinfo}: A single module must be named ${name}.") - class ModuleNotDefinedException(name: String) extends PassException(s"${sinfo}: Module ${name} is not defined.") - class IncorrectNumArgsException(op: String, n: Int) extends PassException(s"${sinfo}: [module ${mname}] Primop ${op} requires ${n} expression arguments.") - class IncorrectNumConstsException(op: String, n: Int) extends PassException(s"${sinfo}: [module ${mname}] Primop ${op} requires ${n} integer arguments.") - class NegWidthException extends PassException(s"${sinfo}: [module ${mname}] Width cannot be negative or zero.") - class NegVecSizeException extends PassException(s"${sinfo}: [module ${mname}] Vector type size cannot be negative.") - class NegMemSizeException extends PassException(s"${sinfo}: [module ${mname}] Memory size cannot be negative or zero.") - class BadPrintfException(x: Char) extends PassException(s"${sinfo}: [module ${mname}] Bad printf format: " + "\"%" + x + "\"") - class BadPrintfTrailingException extends PassException(s"${sinfo}: [module ${mname}] Bad printf format: trailing " + "\"%\"") - class BadPrintfIncorrectNumException extends PassException(s"${sinfo}: [module ${mname}] Bad printf format: incorrect number of arguments") - class InstanceLoop(loop: String) extends PassException(s"${sinfo}: [module ${mname}] Has instance loop $loop") - - /** - * Maintains a one to many graph of each modules instantiated child module. - * This graph can be searched for a path from a child module back to one of - * it's parents. If one is found a recursive loop has happened - * The graph is a map between the name of a node to set of names of that nodes children - */ - class ModuleGraph { - val nodes = new HashMap[String, HashSet[String]] - - /** - * Add a child to a parent node - * A parent node is created if it does not already exist - * - * @param parent module that instantiates another module - * @param child module instantiated by parent - * @return a list indicating a path from child to parent, empty if no such path - */ - def add(parent: String, child: String): List[String] = { - val childSet = nodes.getOrElseUpdate(parent, new HashSet[String]) - childSet += child - pathExists(child, parent, List(child, parent)) - } - - /** - * Starting at the name of a given child explore the tree of all children in depth first manner. - * Return the first path (a list of strings) that goes from child to parent, - * or an empty list of no such path is found. - * - * @param child starting name - * @param parent name to find in children (recursively) - * @param path - * @return - */ - def pathExists(child: String, parent: String, path: List[String] = Nil): List[String] = { - nodes.get(child) match { - case Some(children) => - if(children.contains(parent)) { - parent :: path - } - else { - children.foreach { grandchild => - val newPath = pathExists(grandchild, parent, grandchild :: path) - if(newPath.nonEmpty) { - return newPath - } - } - Nil - } - case _ => Nil - } - } - } + class NotUniqueException(info: Info, mname: String, name: String) extends PassException( + s"${info}: [module ${mname}] Reference ${name} does not have a unique name.") + class InvalidLOCException(info: Info, mname: String) extends PassException( + s"${info}: [module ${mname}] Invalid connect to an expression that is not a reference or a WritePort.") + class NegUIntException(info: Info, mname: String) extends PassException( + s"${info}: [module ${mname}] UIntLiteral cannot be negative.") + class UndeclaredReferenceException(info: Info, mname: String, name: String) extends PassException( + s"${info}: [module ${mname}] Reference ${name} is not declared.") + class PoisonWithFlipException(info: Info, mname: String, name: String) extends PassException( + s"${info}: [module ${mname}] Poison ${name} cannot be a bundle type with flips.") + class MemWithFlipException(info: Info, mname: String, name: String) extends PassException( + s"${info}: [module ${mname}] Memory ${name} cannot be a bundle type with flips.") + class InvalidAccessException(info: Info, mname: String) extends PassException( + s"${info}: [module ${mname}] Invalid access to non-reference.") + class ModuleNotDefinedException(info: Info, mname: String, name: String) extends PassException( + s"${info}: Module ${name} is not defined.") + class IncorrectNumArgsException(info: Info, mname: String, op: String, n: Int) extends PassException( + s"${info}: [module ${mname}] Primop ${op} requires ${n} expression arguments.") + class IncorrectNumConstsException(info: Info, mname: String, op: String, n: Int) extends PassException( + s"${info}: [module ${mname}] Primop ${op} requires ${n} integer arguments.") + class NegWidthException(info: Info, mname: String) extends PassException( + s"${info}: [module ${mname}] Width cannot be negative or zero.") + class NegVecSizeException(info: Info, mname: String) extends PassException( + s"${info}: [module ${mname}] Vector type size cannot be negative.") + class NegMemSizeException(info: Info, mname: String) extends PassException( + s"${info}: [module ${mname}] Memory size cannot be negative or zero.") + class BadPrintfException(info: Info, mname: String, x: Char) extends PassException( + s"${info}: [module ${mname}] Bad printf format: " + "\"%" + x + "\"") + class BadPrintfTrailingException(info: Info, mname: String) extends PassException( + s"${info}: [module ${mname}] Bad printf format: trailing " + "\"%\"") + class BadPrintfIncorrectNumException(info: Info, mname: String) extends PassException( + s"${info}: [module ${mname}] Bad printf format: incorrect number of arguments") + class InstanceLoop(info: Info, mname: String, loop: String) extends PassException( + s"${info}: [module ${mname}] Has instance loop $loop") + class NoTopModuleException(info: Info, name: String) extends PassException( + s"${info}: A single module must be named ${name}.") // TODO FIXME // - Do we need to check for uniquness on port names? - // Global Variables - private var mname: String = "" - private var sinfo: Info = NoInfo - def run (c:Circuit): Circuit = { + def run(c: Circuit): Circuit = { val errors = new Errors() val moduleGraph = new ModuleGraph + val moduleNames = (c.modules map (_.name)).toSet - def checkHighFormPrimop(e: DoPrim) = { - def correctNum(ne: Option[Int], nc: Int) = { + def checkHighFormPrimop(info: Info, mname: String, e: DoPrim) { + def correctNum(ne: Option[Int], nc: Int) { ne match { - case Some(i) => if(e.args.length != i) errors.append(new IncorrectNumArgsException(e.op.toString, i)) - case None => // Do Nothing + case Some(i) if e.args.length != i => + errors append (new IncorrectNumArgsException(info, mname, e.op.toString, i)) + case _ => // Do Nothing } - if (e.consts.length != nc) errors.append(new IncorrectNumConstsException(e.op.toString, nc)) + if (e.consts.length != nc) + errors append new IncorrectNumConstsException(info, mname, e.op.toString, nc) } e.op match { - case Add => correctNum(Option(2),0) - case Sub => correctNum(Option(2),0) - case Mul => correctNum(Option(2),0) - case Div => correctNum(Option(2),0) - case Rem => correctNum(Option(2),0) - case Lt => correctNum(Option(2),0) - case Leq => correctNum(Option(2),0) - case Gt => correctNum(Option(2),0) - case Geq => correctNum(Option(2),0) - case Eq => correctNum(Option(2),0) - case Neq => correctNum(Option(2),0) - case Pad => correctNum(Option(1),1) - case AsUInt => correctNum(Option(1),0) - case AsSInt => correctNum(Option(1),0) - case AsClock => correctNum(Option(1),0) - case Shl => correctNum(Option(1),1) - case Shr => correctNum(Option(1),1) - case Dshl => correctNum(Option(2),0) - case Dshr => correctNum(Option(2),0) - case Cvt => correctNum(Option(1),0) - case Neg => correctNum(Option(1),0) - case Not => correctNum(Option(1),0) - case And => correctNum(Option(2),0) - case Or => correctNum(Option(2),0) - case Xor => correctNum(Option(2),0) - case Andr => correctNum(None,0) - case Orr => correctNum(None,0) - case Xorr => correctNum(None,0) - case Cat => correctNum(Option(2),0) - case Bits => correctNum(Option(1),2) - case Head => correctNum(Option(1),1) - case Tail => correctNum(Option(1),1) + case Add | Sub | Mul | Div | Rem | Lt | Leq | Gt | Geq | + Eq | Neq | Dshl | Dshr | And | Or | Xor | Cat => + correctNum(Option(2), 0) + case AsUInt | AsSInt | AsClock | Cvt | Neq | Not => + correctNum(Option(1), 0) + case Pad | Shl | Shr | Head | Tail => + correctNum(Option(1), 1) + case Bits => + correctNum(Option(1), 2) + case Andr | Orr | Xorr => + correctNum(None,0) } } - def checkFstring(s: StringLit, i: Int) = { + def checkFstring(info: Info, mname: String, s: StringLit, i: Int) { val validFormats = "bdxc" - var percent = false - var npercents = 0 - s.array.foreach { b => - if (percent) { - if (validFormats.contains(b)) npercents += 1 - else if (b != '%') errors.append(new BadPrintfException(b.toChar)) - } - percent = if (b == '%') !percent else false // %% -> percent = false - } - if (percent) errors.append(new BadPrintfTrailingException) - if (npercents != i) errors.append(new BadPrintfIncorrectNumException) + val (percent, npercents) = (s.array foldLeft (false, 0)){ + case ((percent, n), b) if percent && (validFormats contains b) => + (false, n + 1) + case ((percent, n), b) if percent && b != '%' => + errors append new BadPrintfException(info, mname, b.toChar) + (false, n) + case ((percent, n), b) => + (if (b == '%') !percent else false /* %% -> percent = false */, n) + } + if (percent) errors append new BadPrintfTrailingException(info, mname) + if (npercents != i) errors append new BadPrintfIncorrectNumException(info, mname) } - def checkValidLoc(e: Expression) = { - e match { - case e @ (_: UIntLiteral | _: SIntLiteral | _: DoPrim ) => errors.append(new InvalidLOCException) - case _ => // Do Nothing - } + + def checkValidLoc(info: Info, mname: String, e: Expression) = e match { + case _: UIntLiteral | _: SIntLiteral | _: DoPrim => + errors append new InvalidLOCException(info, mname) + case _ => // Do Nothing } - def checkHighFormW(w: Width): Width = { + + def checkHighFormW(info: Info, mname: String)(w: Width): Width = { w match { - case w: IntWidth => - if (w.width <= BigInt(0)) errors.append(new NegWidthException) - case _ => // Do Nothing + case w: IntWidth if w.width <= 0 => + errors append new NegWidthException(info, mname) + case w => // Do nothing } w } - def checkHighFormT(t: Type): Type = { - t map (checkHighFormT) match { - case t: VectorType => - if (t.size < 0) errors.append(new NegVecSizeException) + + def checkHighFormT(info: Info, mname: String)(t: Type): Type = { + t match { + case t: VectorType if t.size < 0 => + errors append new NegVecSizeException(info, mname) case _ => // Do nothing } - t map (checkHighFormW) + t map checkHighFormW(info, mname) map checkHighFormT(info, mname) } - def checkHighFormM(m: DefModule): DefModule = { - val names = HashMap[String, Boolean]() - val mnames = HashMap[String, Boolean]() - def checkHighFormE(e: Expression): Expression = { - def validSubexp(e: Expression): Expression = { - e match { - case (_:WRef|_:WSubField|_:WSubIndex|_:WSubAccess|_:Mux|_:ValidIf) => {} // No error - case _ => errors.append(new InvalidAccessException) - } - e - } - e map (checkHighFormE) match { - case e: WRef => - if (!names.contains(e.name)) errors.append(new UndeclaredReferenceException(e.name)) - case e: DoPrim => checkHighFormPrimop(e) - case (_:Mux|_:ValidIf) => {} - case e: WSubAccess => { - validSubexp(e.exp) - e - } - case e: UIntLiteral => - if (e.value < 0) errors.append(new NegUIntException) - case e => e map (validSubexp) - } - e map (checkHighFormW) - e map (checkHighFormT) - e + def validSubexp(info: Info, mname: String)(e: Expression): Expression = { + e match { + case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess | _: Mux | _: ValidIf => // No error + case _ => errors append new InvalidAccessException(info, mname) } - def checkHighFormS(s: Statement): Statement = { - def checkName(name: String): String = { - if (names.contains(name)) errors.append(new NotUniqueException(name)) - else names(name) = true - name - } - sinfo = get_info(s) - - s map (checkName) - s map (checkHighFormT) - s map (checkHighFormE) - s match { - case s: DefMemory => { - if (hasFlip(s.dataType)) errors.append(new MemWithFlipException(s.name)) - if (s.depth <= 0) errors.append(new NegMemSizeException) - } - case s: WDefInstance => { - if (!c.modules.map(_.name).contains(s.module)) - errors.append(new ModuleNotDefinedException(s.module)) - // Check to see if a recursive module instantiation has occured - val childToParent = moduleGraph.add(m.name, s.module) - if(childToParent.nonEmpty) { - errors.append(new InstanceLoop(childToParent.mkString("->"))) - } - } - case s: Connect => checkValidLoc(s.loc) - case s: PartialConnect => checkValidLoc(s.loc) - case s: Print => checkFstring(s.string, s.args.length) - case _ => // Do Nothing - } + e + } - s map (checkHighFormS) - } + def checkHighFormE(info: Info, mname: String, names: HashSet[String])(e: Expression): Expression = { + e match { + case e: WRef if !names(e.name) => + errors append new UndeclaredReferenceException(info, mname, e.name) + case e: UIntLiteral if e.value < 0 => + errors append new NegUIntException(info, mname) + case e: DoPrim => checkHighFormPrimop(info, mname, e) + case _: WRef | _: UIntLiteral | _: Mux | _: ValidIf => + case e: WSubAccess => validSubexp(info, mname)(e.exp) + case e => e map validSubexp(info, mname) + } + (e map checkHighFormW(info, mname) + map checkHighFormT(info, mname) + map checkHighFormE(info, mname, names)) + } - mname = m.name - for (m <- c.modules) { - mnames(m.name) = true - } - for (p <- m.ports) { - // FIXME should we set sinfo here? - names(p.name) = true - val tpe = p.tpe - tpe map (checkHighFormT) - tpe map (checkHighFormW) - } + def checkName(info: Info, mname: String, names: HashSet[String])(name: String): String = { + if (names(name)) + errors append new NotUniqueException(info, mname, name) + names += name + name + } - m match { - case m: Module => checkHighFormS(m.body) - case m: ExtModule => // Do Nothing - } - m + def checkHighFormS(minfo: Info, mname: String, names: HashSet[String])(s: Statement): Statement = { + val info = get_info(s) match {case NoInfo => minfo case x => x} + (s map checkName(info, mname, names)) match { + case s: DefMemory => + if (hasFlip(s.dataType)) + errors append new MemWithFlipException(info, mname, s.name) + if (s.depth <= 0) + errors append new NegMemSizeException(info, mname) + case s: WDefInstance => + if (!moduleNames(s.module)) + errors append new ModuleNotDefinedException(info, mname, s.module) + // Check to see if a recursive module instantiation has occured + val childToParent = moduleGraph add (mname, s.module) + if (childToParent.nonEmpty) + errors append new InstanceLoop(info, mname, childToParent mkString "->") + case s: Connect => checkValidLoc(info, mname, s.loc) + case s: PartialConnect => checkValidLoc(info, mname, s.loc) + case s: Print => checkFstring(info, mname, s.string, s.args.length) + case s => // Do Nothing + } + (s map checkHighFormT(info, mname) + map checkHighFormE(info, mname, names) + map checkHighFormS(minfo, mname, names)) + } + + def checkHighFormP(mname: String, names: HashSet[String])(p: Port): Port = { + names += p.name + (p.tpe map checkHighFormT(p.info, mname) + map checkHighFormW(p.info, mname)) + p + } + + def checkHighFormM(m: DefModule) { + val names = HashSet[String]() + (m map checkHighFormP(m.name, names) + map checkHighFormS(m.info, m.name, names)) } - var numTopM = 0 - for (m <- c.modules) { - if (m.name == c.main) numTopM = numTopM + 1 - checkHighFormM(m) + c.modules foreach checkHighFormM + (c.modules filter (_.name == c.main)).size match { + case 1 => + case _ => errors append new NoTopModuleException(c.info, c.main) } - sinfo = c.info - if (numTopM != 1) errors.append(new NoTopModuleException(c.main)) errors.trigger c } } object CheckTypes extends Pass with LazyLogging { - def name = "Check Types" - var mname = "" + def name = "Check Types" // Custom Exceptions - class SubfieldNotInBundle(info:Info, name:String) extends PassException(s"${info}: [module ${mname} ] Subfield ${name} is not in bundle.") - class SubfieldOnNonBundle(info:Info, name:String) extends PassException(s"${info}: [module ${mname}] Subfield ${name} is accessed on a non-bundle.") - class IndexTooLarge(info:Info, value:Int) extends PassException(s"${info}: [module ${mname}] Index with value ${value} is too large.") - class IndexOnNonVector(info:Info) extends PassException(s"${info}: [module ${mname}] Index illegal on non-vector type.") - class AccessIndexNotUInt(info:Info) extends PassException(s"${info}: [module ${mname}] Access index must be a UInt type.") - class IndexNotUInt(info:Info) extends PassException(s"${info}: [module ${mname}] Index is not of UIntType.") - class EnableNotUInt(info:Info) extends PassException(s"${info}: [module ${mname}] Enable is not of UIntType.") - class InvalidConnect(info:Info, lhs:String, rhs:String) extends PassException(s"${info}: [module ${mname}] Type mismatch. Cannot connect ${lhs} to ${rhs}.") - class InvalidRegInit(info:Info) extends PassException(s"${info}: [module ${mname}] Type of init must match type of DefRegister.") - class PrintfArgNotGround(info:Info) extends PassException(s"${info}: [module ${mname}] Printf arguments must be either UIntType or SIntType.") - class ReqClk(info:Info) extends PassException(s"${info}: [module ${mname}] Requires a clock typed signal.") - class EnNotUInt(info:Info) extends PassException(s"${info}: [module ${mname}] Enable must be a UIntType typed signal.") - class PredNotUInt(info:Info) extends PassException(s"${info}: [module ${mname}] Predicate not a UIntType.") - class OpNotGround(info:Info, op:String) extends PassException(s"${info}: [module ${mname}] Primop ${op} cannot operate on non-ground types.") - class OpNotUInt(info:Info, op:String,e:String) extends PassException(s"${info}: [module ${mname}] Primop ${op} requires argument ${e} to be a UInt type.") - class OpNotAllUInt(info:Info, op:String) extends PassException(s"${info}: [module ${mname}] Primop ${op} requires all arguments to be UInt type.") - class OpNotAllSameType(info:Info, op:String) extends PassException(s"${info}: [module ${mname}] Primop ${op} requires all operands to have the same type.") - class NodePassiveType(info:Info) extends PassException(s"${info}: [module ${mname}] Node must be a passive type.") - class MuxSameType(info:Info) extends PassException(s"${info}: [module ${mname}] Must mux between equivalent types.") - class MuxPassiveTypes(info:Info) extends PassException(s"${info}: [module ${mname}] Must mux between passive types.") - class MuxCondUInt(info:Info) extends PassException(s"${info}: [module ${mname}] A mux condition must be of type UInt.") - class ValidIfPassiveTypes(info:Info) extends PassException(s"${info}: [module ${mname}] Must validif a passive type.") - class ValidIfCondUInt(info:Info) extends PassException(s"${info}: [module ${mname}] A validif condition must be of type UInt.") - //;---------------- Helper Functions -------------- - def ut () : UIntType = UIntType(UnknownWidth) - def st () : SIntType = SIntType(UnknownWidth) + class SubfieldNotInBundle(info: Info, mname: String, name: String) extends PassException( + s"${info}: [module ${mname} ] Subfield ${name} is not in bundle.") + class SubfieldOnNonBundle(info: Info, mname: String, name: String) extends PassException( + s"${info}: [module ${mname}] Subfield ${name} is accessed on a non-bundle.") + class IndexTooLarge(info: Info, mname: String, value: Int) extends PassException( + s"${info}: [module ${mname}] Index with value ${value} is too large.") + class IndexOnNonVector(info: Info, mname: String) extends PassException( + s"${info}: [module ${mname}] Index illegal on non-vector type.") + class AccessIndexNotUInt(info: Info, mname: String) extends PassException( + s"${info}: [module ${mname}] Access index must be a UInt type.") + class IndexNotUInt(info: Info, mname: String) extends PassException( + s"${info}: [module ${mname}] Index is not of UIntType.") + class EnableNotUInt(info: Info, mname: String) extends PassException( + s"${info}: [module ${mname}] Enable is not of UIntType.") + class InvalidConnect(info: Info, mname: String, lhs: String, rhs: String) extends PassException( + s"${info}: [module ${mname}] Type mismatch. Cannot connect ${lhs} to ${rhs}.") + class InvalidRegInit(info: Info, mname: String) extends PassException( + s"${info}: [module ${mname}] Type of init must match type of DefRegister.") + class PrintfArgNotGround(info: Info, mname: String) extends PassException( + s"${info}: [module ${mname}] Printf arguments must be either UIntType or SIntType.") + class ReqClk(info: Info, mname: String) extends PassException( + s"${info}: [module ${mname}] Requires a clock typed signal.") + class EnNotUInt(info: Info, mname: String) extends PassException( + s"${info}: [module ${mname}] Enable must be a UIntType typed signal.") + class PredNotUInt(info: Info, mname: String) extends PassException( + s"${info}: [module ${mname}] Predicate not a UIntType.") + class OpNotGround(info: Info, mname: String, op: String) extends PassException( + s"${info}: [module ${mname}] Primop ${op} cannot operate on non-ground types.") + class OpNotUInt(info: Info, mname: String, op: String, e: String) extends PassException( + s"${info}: [module ${mname}] Primop ${op} requires argument ${e} to be a UInt type.") + class OpNotAllUInt(info: Info, mname: String, op: String) extends PassException( + s"${info}: [module ${mname}] Primop ${op} requires all arguments to be UInt type.") + class OpNotAllSameType(info: Info, mname: String, op: String) extends PassException( + s"${info}: [module ${mname}] Primop ${op} requires all operands to have the same type.") + class NodePassiveType(info: Info, mname: String) extends PassException( + s"${info}: [module ${mname}] Node must be a passive type.") + class MuxSameType(info: Info, mname: String) extends PassException( + s"${info}: [module ${mname}] Must mux between equivalent types.") + class MuxPassiveTypes(info: Info, mname: String) extends PassException( + s"${info}: [module ${mname}] Must mux between passive types.") + class MuxCondUInt(info: Info, mname: String) extends PassException( + s"${info}: [module ${mname}] A mux condition must be of type UInt.") + class ValidIfPassiveTypes(info: Info, mname: String) extends PassException( + s"${info}: [module ${mname}] Must validif a passive type.") + class ValidIfCondUInt(info: Info, mname: String) extends PassException( + s"${info}: [module ${mname}] A validif condition must be of type UInt.") + + //;---------------- Helper Functions -------------- + def ut: UIntType = UIntType(UnknownWidth) + def st: SIntType = SIntType(UnknownWidth) - def check_types_primop (e:DoPrim, errors:Errors, info:Info) : Unit = { - def all_same_type (ls:Seq[Expression]) : Unit = { - var error = false - for (x <- ls) { - if (wt(ls.head.tpe) != wt(x.tpe)) error = true - } - if (error) errors.append(new OpNotAllSameType(info,e.op.serialize)) - } - def all_ground (ls:Seq[Expression]) : Unit = { - var error = false - for (x <- ls ) { - x.tpe match { - case _: UIntType | _: SIntType => - case _ => error = true - } - } - if (error) errors.append(new OpNotGround(info,e.op.serialize)) - } - def all_uint (ls:Seq[Expression]) : Unit = { - var error = false - for (x <- ls ) { - x.tpe match { - case _: UIntType => - case _ => error = true - } - } - if (error) errors.append(new OpNotAllUInt(info,e.op.serialize)) - } - def is_uint (x:Expression) : Unit = { - var error = false - x.tpe match { - case _: UIntType => - case _ => error = true - } - if (error) errors.append(new OpNotUInt(info,e.op.serialize,x.serialize)) + def run (c:Circuit) : Circuit = { + val errors = new Errors() + + def passive(t: Type): Boolean = t match { + case (_: UIntType |_: SIntType) => true + case (t: VectorType) => passive(t.tpe) + case (t: BundleType) => t.fields forall (x => x.flip == Default && passive(x.tpe)) + case (t) => true + } + + def check_types_primop(info: Info, mname: String, e: DoPrim) { + def all_same_type (ls:Seq[Expression]) { + if (ls exists (x => wt(ls.head.tpe) != wt(e.tpe))) + errors append new OpNotAllSameType(info, mname, e.op.serialize) + } + def all_ground (ls: Seq[Expression]) { + if (ls exists (x => x.tpe match { + case _: UIntType | _: SIntType => false + case _ => true + })) errors append new OpNotGround(info, mname, e.op.serialize) + } + def all_uint (ls: Seq[Expression]) { + if (ls exists (x => x.tpe match { + case _: UIntType => false + case _ => true + })) errors append new OpNotAllUInt(info, mname, e.op.serialize) + } + def is_uint (x:Expression) { + if (x.tpe match { + case _: UIntType => false + case _ => true + }) errors append new OpNotUInt(info, mname, e.op.serialize, x.serialize) } - + e.op match { - case AsUInt => - case AsSInt => - case AsClock => - case Dshl => is_uint(e.args(1)); all_ground(e.args) - case Dshr => is_uint(e.args(1)); all_ground(e.args) - case Add => all_ground(e.args) - case Sub => all_ground(e.args) - case Mul => all_ground(e.args) - case Div => all_ground(e.args) - case Rem => all_ground(e.args) - case Lt => all_ground(e.args) - case Leq => all_ground(e.args) - case Gt => all_ground(e.args) - case Geq => all_ground(e.args) - case Eq => all_ground(e.args) - case Neq => all_ground(e.args) - case Pad => all_ground(e.args) - case Shl => all_ground(e.args) - case Shr => all_ground(e.args) - case Cvt => all_ground(e.args) - case Neg => all_ground(e.args) - case Not => all_ground(e.args) - case And => all_ground(e.args) - case Or => all_ground(e.args) - case Xor => all_ground(e.args) - case Andr => all_ground(e.args) - case Orr => all_ground(e.args) - case Xorr => all_ground(e.args) - case Cat => all_ground(e.args) - case Bits => all_ground(e.args) - case Head => all_ground(e.args) - case Tail => all_ground(e.args) - } - } - - def run (c:Circuit) : Circuit = { - val errors = new Errors() - def passive (t:Type) : Boolean = { - (t) match { - case (_:UIntType|_:SIntType) => true - case (t:VectorType) => passive(t.tpe) - case (t:BundleType) => { - var p = true - for (x <- t.fields ) { - if (x.flip == Flip) p = false - if (!passive(x.tpe)) p = false - } - p - } - case (t) => true - } - } - def check_types_e (info:Info)(e:Expression) : Expression = { - (e map (check_types_e(info))) match { - case (e:WRef) => e - case (e:WSubField) => { - (e.exp.tpe) match { - case (t:BundleType) => { - val ft = t.fields.find(p => p.name == e.name) - if (ft == None) errors.append(new SubfieldNotInBundle(info,e.name)) - } - case (t) => errors.append(new SubfieldOnNonBundle(info,e.name)) - } - } - case (e:WSubIndex) => { - (e.exp.tpe) match { - case (t:VectorType) => { - if (e.value >= t.size) errors.append(new IndexTooLarge(info,e.value)) - } - case (t) => errors.append(new IndexOnNonVector(info)) - } - } - case (e:WSubAccess) => { - (e.exp.tpe) match { - case (t:VectorType) => false - case (t) => errors.append(new IndexOnNonVector(info)) - } - (e.index.tpe) match { - case (t:UIntType) => false - case (t) => errors.append(new AccessIndexNotUInt(info)) - } - } - case (e:DoPrim) => check_types_primop(e,errors,info) - case (e:Mux) => { - if (wt(e.tval.tpe) != wt(e.fval.tpe)) errors.append(new MuxSameType(info)) - if (!passive(e.tpe)) errors.append(new MuxPassiveTypes(info)) - e.cond.tpe match { - case _: UIntType => - case _ => errors.append(new MuxCondUInt(info)) - } - } - case (e:ValidIf) => { - if (!passive(e.tpe)) errors.append(new ValidIfPassiveTypes(info)) - e.cond.tpe match { - case _: UIntType => - case _ => errors.append(new ValidIfCondUInt(info)) - } - } - case (_:UIntLiteral | _:SIntLiteral) => false - } - e + case AsUInt | AsSInt | AsClock => + case Dshl => is_uint(e.args(1)); all_ground(e.args) + case Dshr => is_uint(e.args(1)); all_ground(e.args) + case _ => all_ground(e.args) } - - def bulk_equals (t1: Type, t2: Type, flip1: Orientation, flip2: Orientation): Boolean = { - //;println_all(["Inside with t1:" t1 ",t2:" t2 ",f1:" flip1 ",f2:" flip2]) - (t1,t2) match { - case (ClockType, ClockType) => flip1 == flip2 - case (t1:UIntType,t2:UIntType) => flip1 == flip2 - case (t1:SIntType,t2:SIntType) => flip1 == flip2 - case (t1:BundleType,t2:BundleType) => { - var isEqual = true - for (i <- 0 until t1.fields.size) { - for (j <- 0 until t2.fields.size) { - val f1 = t1.fields(i) - val f2 = t2.fields(j) - if (f1.name == f2.name) { - val field_equal = bulk_equals(f1.tpe,f2.tpe,times(flip1, f1.flip),times(flip2, f2.flip)) - if (!field_equal) isEqual = false - } - } - } - isEqual - } - case (t1:VectorType,t2:VectorType) => bulk_equals(t1.tpe,t2.tpe,flip1,flip2) - } + } + + def check_types_e(info:Info, mname: String)(e: Expression): Expression = { + e match { + case (e: WSubField) => e.exp.tpe match { + case (t: BundleType) => t.fields find (_.name == e.name) match { + case Some(_) => + case None => errors append new SubfieldNotInBundle(info, mname, e.name) + } + case _ => errors append new SubfieldOnNonBundle(info, mname, e.name) + } + case (e: WSubIndex) => e.exp.tpe match { + case (t: VectorType) if e.value < t.size => + case (t: VectorType) => + errors append (new IndexTooLarge(info, mname, e.value)) + case _ => + errors append (new IndexOnNonVector(info, mname)) + } + case (e: WSubAccess) => + e.exp.tpe match { + case _: VectorType => + case _ => errors append new IndexOnNonVector(info, mname) + } + e.index.tpe match { + case _: UIntType => + case _ => errors append new AccessIndexNotUInt(info, mname) + } + case (e: DoPrim) => check_types_primop(info, mname, e) + case (e: Mux) => + if (wt(e.tval.tpe) != wt(e.fval.tpe)) + errors append new MuxSameType(info, mname) + if (!passive(e.tpe)) + errors append new MuxPassiveTypes(info, mname) + e.cond.tpe match { + case _: UIntType => + case _ => errors append new MuxCondUInt(info, mname) + } + case (e: ValidIf) => + if (!passive(e.tpe)) + errors append new ValidIfPassiveTypes(info, mname) + e.cond.tpe match { + case _: UIntType => + case _ => errors append new ValidIfCondUInt(info, mname) + } + case _ => } + e map check_types_e(info, mname) + } - def check_types_s (s:Statement) : Statement = { - s map (check_types_e(get_info(s))) match { - case (s:Connect) => if (wt(s.loc.tpe) != wt(s.expr.tpe)) errors.append(new InvalidConnect(s.info, s.loc.serialize, s.expr.serialize)) - case (s:DefRegister) => if (wt(s.tpe) != wt(s.init.tpe)) errors.append(new InvalidRegInit(s.info)) - case (s:PartialConnect) => if (!bulk_equals(s.loc.tpe,s.expr.tpe,Default,Default) ) errors.append(new InvalidConnect(s.info, s.loc.serialize, s.expr.serialize)) - case (s:Stop) => { - if (wt(s.clk.tpe) != wt(ClockType) ) errors.append(new ReqClk(s.info)) - if (wt(s.en.tpe) != wt(ut()) ) errors.append(new EnNotUInt(s.info)) - } - case (s:Print)=> { - for (x <- s.args ) { - if (wt(x.tpe) != wt(ut()) && wt(x.tpe) != wt(st()) ) errors.append(new PrintfArgNotGround(s.info)) - } - if (wt(s.clk.tpe) != wt(ClockType) ) errors.append(new ReqClk(s.info)) - if (wt(s.en.tpe) != wt(ut()) ) errors.append(new EnNotUInt(s.info)) + def bulk_equals(t1: Type, t2: Type, flip1: Orientation, flip2: Orientation): Boolean = { + //;println_all(["Inside with t1:" t1 ",t2:" t2 ",f1:" flip1 ",f2:" flip2]) + (t1, t2) match { + case (ClockType, ClockType) => flip1 == flip2 + case (_: UIntType, _: UIntType) => flip1 == flip2 + case (_: SIntType, _: SIntType) => flip1 == flip2 + case (t1: BundleType, t2: BundleType) => + val t1_fields = (t1.fields foldLeft Map[String, (Type, Orientation)]())( + (map, f1) => map + (f1.name -> (f1.tpe, f1.flip))) + t2.fields forall (f2 => + t1_fields get f2.name match { + case None => true + case Some((f1_tpe, f1_flip)) => + bulk_equals(f1_tpe, f2.tpe, times(flip1, f1_flip), times(flip2, f2.flip)) } - case (s:Conditionally) => if (wt(s.pred.tpe) != wt(ut()) ) errors.append(new PredNotUInt(s.info)) - case (s:DefNode) => if (!passive(s.value.tpe) ) errors.append(new NodePassiveType(s.info)) - case (s) => false - } - s map (check_types_s) + ) + case (t1: VectorType, t2: VectorType) => + bulk_equals(t1.tpe, t2.tpe, flip1, flip2) } - - for (m <- c.modules ) { - mname = m.name - (m) match { - case (m:ExtModule) => false - case (m:Module) => check_types_s(m.body) - } - } - errors.trigger - c - } + } + + def check_types_s(minfo: Info, mname: String)(s: Statement): Statement = { + val info = get_info(s) match { case NoInfo => minfo case x => x } + s match { + case (s: Connect) if wt(s.loc.tpe) != wt(s.expr.tpe) => + errors append new InvalidConnect(info, mname, s.loc.serialize, s.expr.serialize) + case (s:PartialConnect) if !bulk_equals(s.loc.tpe, s.expr.tpe, Default, Default) => + errors append new InvalidConnect(info, mname, s.loc.serialize, s.expr.serialize) + case (s: DefRegister) if wt(s.tpe) != wt(s.init.tpe) => + errors append new InvalidRegInit(info, mname) + case (s: Conditionally) if wt(s.pred.tpe) != wt(ut) => + errors append new PredNotUInt(info, mname) + case (s: DefNode) if !passive(s.value.tpe) => + errors append new NodePassiveType(info, mname) + case (s: Stop) => + if (wt(s.clk.tpe) != wt(ClockType)) errors append new ReqClk(info, mname) + if (wt(s.en.tpe) != wt(ut)) errors append new EnNotUInt(info, mname) + case (s: Print) => + if (s.args exists (x => wt(x.tpe) != wt(ut) && wt(x.tpe) != wt(st))) + errors append new PrintfArgNotGround(info, mname) + if (wt(s.clk.tpe) != wt(ClockType)) errors append new ReqClk(info, mname) + if (wt(s.en.tpe) != wt(ut)) errors append new EnNotUInt(info, mname) + case _ => + } + s map check_types_e(info, mname) map check_types_s(info, mname) + } + + c.modules foreach (m => m map check_types_s(m.info, m.name)) + errors.trigger + c + } } object CheckGenders extends Pass { - def name = "Check Genders" - var mname = "" - class WrongGender (info:Info,expr:String,wrong:String,right:String) extends PassException(s"${info}: [module ${mname}] Expression ${expr} is used as a ${wrong} but can only be used as a ${right}.") + def name = "Check Genders" + + implicit def toStr(g: Gender): String = g match { + case MALE => "source" + case FEMALE => "sink" + case UNKNOWNGENDER => "unknown" + case BIGENDER => "sourceOrSink" + } - def dir_to_gender (d:Direction) : Gender = { - d match { - case Input => MALE - case Output => FEMALE //BI-GENDER - } - } + class WrongGender(info:Info, mname: String, expr: String, wrong: Gender, right: Gender) extends PassException( + s"${info}: [module ${mname}] Expression ${expr} is used as a ${wrong} but can only be used as a ${right}.") - def as_srcsnk (g:Gender) : String = { - g match { - case MALE => "source" - case FEMALE => "sink" - case UNKNOWNGENDER => "unknown" - case BIGENDER => "sourceOrSink" + def run (c:Circuit): Circuit = { + val errors = new Errors() + + def get_gender(e: Expression, genders: HashMap[String, Gender]): Gender = e match { + case (e: WRef) => genders(e.name) + case (e: WSubIndex) => get_gender(e.exp, genders) + case (e: WSubAccess) => get_gender(e.exp, genders) + case (e: WSubField) => e.exp.tpe match {case t: BundleType => + val f = (t.fields find (_.name == e.name)).get + times(get_gender(e.exp, genders), f.flip) + } + case _ => MALE + } + + def flip_q(t: Type): Boolean = { + def flip_rec(t: Type, f: Orientation): Boolean = t match { + case (t:BundleType) => t.fields exists ( + field => flip_rec(field.tpe, times(f, field.flip)) + ) + case t: VectorType => flip_rec(t.tpe, f) + case t => f == Flip + } + flip_rec(t, Default) + } + + def check_gender(info:Info, mname: String, + genders: HashMap[String,Gender], desired: Gender)(e:Expression): Expression = { + val gender = get_gender(e,genders) + (gender, desired) match { + case (MALE, FEMALE) => + errors append new WrongGender(info, mname, e.serialize, desired, gender) + case (FEMALE, MALE) => kind(e) match { + case _: PortKind | _: InstanceKind if !flip_q(e.tpe) => // OK! + case _ => + errors append new WrongGender(info, mname, e.serialize, desired, gender) + } + case _ => } + e } - def run (c:Circuit): Circuit = { - val errors = new Errors() - def get_kind (e:Expression) : Kind = { - (e) match { - case (e:WRef) => e.kind - case (e:WSubField) => get_kind(e.exp) - case (e:WSubIndex) => get_kind(e.exp) - case (e:WSubAccess) => get_kind(e.exp) - case (e) => NodeKind() - } - } - - def check_gender (info:Info,genders:HashMap[String,Gender],desired:Gender)(e:Expression) : Expression = { - val gender = get_gender(e,genders) - val kindx = get_kind(e) - def flipQ (t:Type) : Boolean = { - var fQ = false - def flip_rec (t:Type,f:Orientation) : Type = { - (t) match { - case (t:BundleType) => { - for (field <- t.fields) { - flip_rec(field.tpe,times(f, field.flip)) - } - } - case (t:VectorType) => flip_rec(t.tpe,f) - case (t) => if (f == Flip) fQ = true - } - t - } - flip_rec(t,Default) - fQ - } - - val has_flipQ = flipQ(e.tpe) - //println(e) - //println(gender) - //println(desired) - //println(kindx) - //println(desired == gender) - //if gender != desired and gender != BI-GENDER: - (gender,desired) match { - case (MALE, FEMALE) => errors.append(new WrongGender(info,e.serialize,as_srcsnk(desired),as_srcsnk(gender))) - case (FEMALE, MALE) => - if ((kindx == PortKind() || kindx == InstanceKind()) && has_flipQ == false) { - //; OK! - false - } else { - //; Not Ok! - errors.append(new WrongGender(info,e.serialize,as_srcsnk(desired),as_srcsnk(gender))) - } - case _ => false - } - e - } - - def get_gender (e:Expression,genders:HashMap[String,Gender]) : Gender = { - (e) match { - case (e:WRef) => genders(e.name) - case (e:WSubField) => - val f = e.exp.tpe.asInstanceOf[BundleType].fields.find(f => f.name == e.name).get - times(get_gender(e.exp,genders),f.flip) - case (e:WSubIndex) => get_gender(e.exp,genders) - case (e:WSubAccess) => get_gender(e.exp,genders) - case (e:DoPrim) => MALE - case (e:UIntLiteral) => MALE - case (e:SIntLiteral) => MALE - case (e:Mux) => MALE - case (e:ValidIf) => MALE - } - } - - def check_genders_e (info:Info,genders:HashMap[String,Gender])(e:Expression) : Expression = { - e map (check_genders_e(info,genders)) - (e) match { - case (e:WRef) => false - case (e:WSubField) => false - case (e:WSubIndex) => false - case (e:WSubAccess) => false - case (e:DoPrim) => for (e <- e.args ) { check_gender(info,genders,MALE)(e) } - case (e:Mux) => e map (check_gender(info,genders,MALE)) - case (e:ValidIf) => e map (check_gender(info,genders,MALE)) - case (e:UIntLiteral) => false - case (e:SIntLiteral) => false - } - e + def check_genders_e (info:Info, mname: String, genders: HashMap[String,Gender])(e:Expression): Expression = { + e match { + case e: Mux => e map check_gender(info, mname, genders, MALE) + case e: DoPrim => e.args map check_gender(info, mname, genders, MALE) + case _ => } + e map check_genders_e(info, mname, genders) + } - def check_genders_s (genders:HashMap[String,Gender])(s:Statement) : Statement = { - s map (check_genders_e(get_info(s),genders)) - s map (check_genders_s(genders)) - (s) match { - case (s:DefWire) => genders(s.name) = BIGENDER - case (s:DefRegister) => genders(s.name) = BIGENDER - case (s:DefNode) => { - check_gender(s.info,genders,MALE)(s.value) - genders(s.name) = MALE - } - case (s:DefMemory) => genders(s.name) = MALE - case (s:WDefInstance) => genders(s.name) = MALE - case (s:Connect) => { - check_gender(s.info,genders,FEMALE)(s.loc) - check_gender(s.info,genders,MALE)(s.expr) - } - case (s:Print) => { - for (x <- s.args ) { - check_gender(s.info,genders,MALE)(x) - } - check_gender(s.info,genders,MALE)(s.en) - check_gender(s.info,genders,MALE)(s.clk) - } - case (s:PartialConnect) => { - check_gender(s.info,genders,FEMALE)(s.loc) - check_gender(s.info,genders,MALE)(s.expr) - } - case (s:Conditionally) => { - check_gender(s.info,genders,MALE)(s.pred) - } - case EmptyStmt => false - case (s:Stop) => { - check_gender(s.info,genders,MALE)(s.en) - check_gender(s.info,genders,MALE)(s.clk) - } - case (_:Block | _:IsInvalid) => false - } - s - } + def check_genders_s(minfo: Info, mname: String, genders: HashMap[String,Gender])(s: Statement): Statement = { + val info = get_info(s) match { case NoInfo => minfo case x => x } + s match { + case (s: DefWire) => genders(s.name) = BIGENDER + case (s: DefRegister) => genders(s.name) = BIGENDER + case (s: DefMemory) => genders(s.name) = MALE + case (s: WDefInstance) => genders(s.name) = MALE + case (s: DefNode) => + check_gender(info, mname, genders, MALE)(s.value) + genders(s.name) = MALE + case (s: Connect) => + check_gender(info, mname, genders, FEMALE)(s.loc) + check_gender(info, mname, genders, MALE)(s.expr) + case (s: Print) => + s.args map check_gender(info, mname, genders, MALE) + check_gender(info, mname, genders, MALE)(s.en) + check_gender(info, mname, genders, MALE)(s.clk) + case (s: PartialConnect) => + check_gender(info, mname, genders, FEMALE)(s.loc) + check_gender(info, mname, genders, MALE)(s.expr) + case (s: Conditionally) => + check_gender(info, mname, genders, MALE)(s.pred) + case (s: Stop) => + check_gender(info, mname, genders, MALE)(s.en) + check_gender(info, mname, genders, MALE)(s.clk) + case _ => + } + s map check_genders_e(info, mname, genders) map check_genders_s(minfo, mname, genders) + } - for (m <- c.modules ) { - mname = m.name - val genders = HashMap[String,Gender]() - for (p <- m.ports) { - genders(p.name) = dir_to_gender(p.direction) - } - (m) match { - case (m:ExtModule) => false - case (m:Module) => check_genders_s(genders)(m.body) - } - } - errors.trigger - c - } + for (m <- c.modules) { + val genders = HashMap[String, Gender]() + genders ++= (m.ports map (p => p.name -> to_gender(p.direction))) + m map check_genders_s(m.info, m.name, genders) + } + errors.trigger + c + } } object CheckWidths extends Pass { - def name = "Width Check" - var mname = "" - class UninferredWidth (info:Info) extends PassException(s"${info} : [module ${mname}] Uninferred width.") - class WidthTooSmall(info: Info, b: BigInt) extends PassException( - s"$info : [module $mname] Width too small for constant " + - serialize(b) + ".") - class NegWidthException(info:Info) extends PassException(s"${info}: [module ${mname}] Width cannot be negative or zero.") - class BitsWidthException(info: Info, hi: BigInt, width: BigInt) extends PassException(s"${info}: [module ${mname}] High bit $hi in bits operator is larger than input width $width.") - class HeadWidthException(info: Info, n: BigInt, width: BigInt) extends PassException(s"${info}: [module ${mname}] Parameter $n in head operator is larger than input width $width.") - class TailWidthException(info: Info, n: BigInt, width: BigInt) extends PassException(s"${info}: [module ${mname}] Parameter $n in tail operator is larger than input width $width.") - def run (c:Circuit): Circuit = { - val errors = new Errors() - def check_width_m (m:DefModule) : Unit = { - def check_width_w (info:Info)(w:Width) : Width = { - (w) match { - case (w:IntWidth)=> if (w.width <= 0) errors.append(new NegWidthException(info)) - case (w) => errors.append(new UninferredWidth(info)) - } - w - } - def check_width_e (info:Info)(e:Expression) : Expression = { - (e map (check_width_e(info))) match { - case (e:UIntLiteral) => { - (e.width) match { - case (w:IntWidth) => - if (scala.math.max(1,e.value.bitLength) > w.width) { - errors.append(new WidthTooSmall(info, e.value)) - } - case (w) => errors.append(new UninferredWidth(info)) - } - check_width_w(info)(e.width) - } - case (e:SIntLiteral) => { - (e.width) match { - case (w:IntWidth) => - if (e.value.bitLength + 1 > w.width) errors.append(new WidthTooSmall(info, e.value)) - case (w) => errors.append(new UninferredWidth(info)) - } - check_width_w(info)(e.width) - } - case DoPrim(Bits, Seq(a), Seq(hi, lo), _) if(long_BANG(a.tpe) <= hi) => - errors.append(new BitsWidthException(info, hi, long_BANG(a.tpe))) - case DoPrim(Head, Seq(a), Seq(n), _) if(long_BANG(a.tpe) < n) => - errors.append(new HeadWidthException(info, n, long_BANG(a.tpe))) - case DoPrim(Tail, Seq(a), Seq(n), _) if(long_BANG(a.tpe) <= n) => - errors.append(new TailWidthException(info, n, long_BANG(a.tpe))) - case (e:DoPrim) => false - case (e) => false - } - e - } - def check_width_s (s:Statement) : Statement = { - s map (check_width_s) map (check_width_e(get_info(s))) - def tm (t:Type) : Type = mapr(check_width_w(get_info(s)) _,t) - s map (tm) - } - - for (p <- m.ports) { - mapr(check_width_w(p.info) _,p.tpe) - } - - (m) match { - case (m:ExtModule) => {} - case (m:Module) => check_width_s(m.body) - } - } - - for (m <- c.modules) { - mname = m.name - check_width_m(m) + def name = "Width Check" + class UninferredWidth (info: Info, mname: String) extends PassException( + s"${info} : [module ${mname}] Uninferred width.") + class WidthTooSmall(info: Info, mname: String, b: BigInt) extends PassException( + s"$info : [module $mname] Width too small for constant ${serialize(b)}.") + class NegWidthException(info:Info, mname: String) extends PassException( + s"${info}: [module ${mname}] Width cannot be negative or zero.") + class BitsWidthException(info: Info, mname: String, hi: BigInt, width: BigInt) extends PassException( + s"${info}: [module ${mname}] High bit $hi in bits operator is larger than input width $width.") + class HeadWidthException(info: Info, mname: String, n: BigInt, width: BigInt) extends PassException( + s"${info}: [module ${mname}] Parameter $n in head operator is larger than input width $width.") + class TailWidthException(info: Info, mname: String, n: BigInt, width: BigInt) extends PassException( + s"${info}: [module ${mname}] Parameter $n in tail operator is larger than input width $width.") + + def run(c: Circuit): Circuit = { + val errors = new Errors() + + def check_width_w(info: Info, mname: String)(w: Width): Width = { + w match { + case w: IntWidth if w.width > 0 => + case _: IntWidth => + errors append new NegWidthException(info, mname) + case _ => + errors append new UninferredWidth(info, mname) } - errors.trigger - c - } + w + } + + def check_width_e(info: Info, mname: String)(e: Expression): Expression = { + e match { + case e: UIntLiteral => e.width match { + case w: IntWidth if math.max(1, e.value.bitLength) > w.width => + errors append new WidthTooSmall(info, mname, e.value) + case _ => + } + case e: SIntLiteral => e.width match { + case w: IntWidth if e.value.bitLength + 1 > w.width => + errors append new WidthTooSmall(info, mname, e.value) + case _ => + } + case DoPrim(Bits, Seq(a), Seq(hi, lo), _) if long_BANG(a.tpe) <= hi => + errors append new BitsWidthException(info, mname, hi, long_BANG(a.tpe)) + case DoPrim(Head, Seq(a), Seq(n), _) if long_BANG(a.tpe) < n => + errors append new HeadWidthException(info, mname, n, long_BANG(a.tpe)) + case DoPrim(Tail, Seq(a), Seq(n), _) if long_BANG(a.tpe) <= n => + errors append new TailWidthException(info, mname, n, long_BANG(a.tpe)) + case _ => + } + e map check_width_w(info, mname) map check_width_e(info, mname) + } + + def check_width_s(minfo: Info, mname: String)(s: Statement): Statement = { + val info = get_info(s) match { case NoInfo => minfo case x => x } + s map check_width_e(info, mname) map check_width_s(info, mname) + } + + def check_width_p(minfo: Info, mname: String)(p: Port): Port = { + p.tpe map check_width_w(p.info, mname) + p + } + + def check_width_m(m: DefModule) { + m map check_width_p(m.info, m.name) map check_width_s(m.info, m.name) + } + + c.modules foreach check_width_m + errors.trigger + c + } } diff --git a/src/test/scala/firrtlTests/CheckInitializationSpec.scala b/src/test/scala/firrtlTests/CheckInitializationSpec.scala index 515bbfc8..e2eaf690 100644 --- a/src/test/scala/firrtlTests/CheckInitializationSpec.scala +++ b/src/test/scala/firrtlTests/CheckInitializationSpec.scala @@ -60,7 +60,7 @@ class CheckInitializationSpec extends FirrtlFlatSpec { | wire x : UInt<32> | when p : | x <= UInt(1)""".stripMargin - intercept[PassExceptions] { + intercept[CheckInitialization.RefNotInitializedException] { passes.foldLeft(parse(input)) { (c: Circuit, p: Pass) => p.run(c) } @@ -75,7 +75,7 @@ class CheckInitializationSpec extends FirrtlFlatSpec { | when p : | else : | x <= UInt(1)""".stripMargin - intercept[PassExceptions] { + intercept[CheckInitialization.RefNotInitializedException] { passes.foldLeft(parse(input)) { (c: Circuit, p: Pass) => p.run(c) } |
