aboutsummaryrefslogtreecommitdiff
path: root/src/main
diff options
context:
space:
mode:
authorAlbert Chen2019-02-22 15:30:27 -0800
committermergify[bot]2019-02-22 23:30:27 +0000
commit5608aa8f42c1d69b59bee158d14fc6cef9b19a47 (patch)
tree86b7bad9c5f164d12aba9f324bde223e7ff5e9f3 /src/main
parent0ace0218d3151df2d102463dd682128a88ae7be6 (diff)
Add Width Constraints with Annotations (#956)
* refactor InferWidths to allow for extra contraints, add InferWidthsWithAnnos * add test cases * add ResolvedAnnotationPaths trait to InferWidthsWithAnnos * remove println * cleanup tests * remove extraneous constraints * use foreachStmt instead of mapStmt * remove support for aggregates * fold InferWidthsWithAnnos into InferWidths * throw exception if ref not found, check for annos before AST walk
Diffstat (limited to 'src/main')
-rw-r--r--src/main/scala/firrtl/LoweringCompilers.scala6
-rw-r--r--src/main/scala/firrtl/annotations/Target.scala21
-rw-r--r--src/main/scala/firrtl/passes/InferWidths.scala121
-rw-r--r--src/main/scala/firrtl/passes/Passes.scala2
4 files changed, 122 insertions, 28 deletions
diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala
index 9969150d..262caeea 100644
--- a/src/main/scala/firrtl/LoweringCompilers.scala
+++ b/src/main/scala/firrtl/LoweringCompilers.scala
@@ -44,7 +44,7 @@ class ResolveAndCheck extends CoreTransform {
passes.InferTypes,
passes.ResolveGenders,
passes.CheckGenders,
- passes.InferWidths,
+ new passes.InferWidths,
passes.CheckWidths)
}
@@ -68,7 +68,7 @@ class HighFirrtlToMiddleFirrtl extends CoreTransform {
passes.InferTypes,
passes.CheckTypes,
passes.ResolveGenders,
- passes.InferWidths,
+ new passes.InferWidths,
passes.CheckWidths,
passes.ConvertFixedToSInt,
passes.ZeroWidth,
@@ -87,7 +87,7 @@ class MiddleFirrtlToLowFirrtl extends CoreTransform {
passes.ResolveKinds,
passes.InferTypes,
passes.ResolveGenders,
- passes.InferWidths,
+ new passes.InferWidths,
passes.Legalize,
new firrtl.transforms.RemoveReset,
new firrtl.transforms.CheckCombLoops,
diff --git a/src/main/scala/firrtl/annotations/Target.scala b/src/main/scala/firrtl/annotations/Target.scala
index 8a9d68e8..0247b66c 100644
--- a/src/main/scala/firrtl/annotations/Target.scala
+++ b/src/main/scala/firrtl/annotations/Target.scala
@@ -3,7 +3,8 @@
package firrtl
package annotations
-import firrtl.ir.Expression
+import firrtl.ir.{Expression, Type}
+import firrtl.Utils.{sub_type, field_type}
import AnnotationUtils.{toExp, validComponentName, validModuleName}
import TargetToken._
@@ -553,6 +554,24 @@ case class ReferenceTarget(circuit: String,
/** @return The clock signal of this reference, must be to a [[firrtl.ir.DefRegister]] */
def clock: ReferenceTarget = ReferenceTarget(circuit, module, path, ref, component :+ Clock)
+ /** @param the type of this target's ref
+ * @return the type of the subcomponent specified by this target's component
+ */
+ def componentType(baseType: Type): Type = componentType(baseType, tokens)
+
+ private def componentType(baseType: Type, tokens: Seq[TargetToken]): Type = {
+ if (tokens.isEmpty) {
+ baseType
+ } else {
+ val headType = tokens.head match {
+ case Index(idx) => sub_type(baseType)
+ case Field(field) => field_type(baseType, field)
+ case _: Ref => baseType
+ }
+ componentType(headType, tokens.tail)
+ }
+ }
+
override def circuitOpt: Option[String] = Some(circuit)
override def moduleOpt: Option[String] = Some(module)
diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala
index 6652c1fe..06833bc0 100644
--- a/src/main/scala/firrtl/passes/InferWidths.scala
+++ b/src/main/scala/firrtl/passes/InferWidths.scala
@@ -7,12 +7,38 @@ import scala.collection.mutable.ArrayBuffer
import scala.collection.immutable.ListMap
import firrtl._
+import firrtl.annotations.{Annotation, ReferenceTarget, TargetToken}
import firrtl.ir._
import firrtl.Utils._
import firrtl.Mappers._
import firrtl.traversals.Foreachers._
-object InferWidths extends Pass {
+case class WidthGeqConstraintAnnotation(loc: ReferenceTarget, exp: ReferenceTarget) extends Annotation {
+ def update(renameMap: RenameMap): Seq[WidthGeqConstraintAnnotation] = {
+ val newLoc :: newExp :: Nil = Seq(loc, exp).map { target =>
+ renameMap.get(target) match {
+ case None => Some(target)
+ case Some(Seq()) => None
+ case Some(Seq(one)) => Some(one)
+ case Some(many) =>
+ throw new Exception(s"Target below is an AggregateType, which " +
+ "is not supported by WidthGeqConstraintAnnotation\n" + target.prettyPrint())
+ }
+ }
+
+ (newLoc, newExp) match {
+ case (Some(l: ReferenceTarget), Some(e: ReferenceTarget)) => Seq(WidthGeqConstraintAnnotation(l, e))
+ case _ => Seq.empty
+ }
+ }
+}
+
+class InferWidths extends Transform with ResolvedAnnotationPaths {
+ def inputForm: CircuitForm = UnknownForm
+ def outputForm: CircuitForm = UnknownForm
+
+ val annotationClasses = Seq(classOf[WidthGeqConstraintAnnotation])
+
type ConstraintMap = collection.mutable.LinkedHashMap[String, Width]
def solve_constraints(l: Seq[WGeq]): ConstraintMap = {
@@ -220,26 +246,26 @@ object InferWidths extends Pass {
}
b
}
-
- def run (c: Circuit): Circuit = {
- val v = ArrayBuffer[WGeq]()
-
- def get_constraints_t(t1: Type, t2: Type): Seq[WGeq] = (t1,t2) match {
- case (t1: UIntType, t2: UIntType) => Seq(WGeq(t1.width, t2.width))
- case (t1: SIntType, t2: SIntType) => Seq(WGeq(t1.width, t2.width))
- case (ClockType, ClockType) => Nil
- case (AsyncResetType, AsyncResetType) => Nil
- case (FixedType(w1, p1), FixedType(w2, p2)) => Seq(WGeq(w1,w2), WGeq(p1,p2))
- case (AnalogType(w1), AnalogType(w2)) => Seq(WGeq(w1,w2), WGeq(w2,w1))
- case (t1: BundleType, t2: BundleType) =>
- (t1.fields zip t2.fields foldLeft Seq[WGeq]()){case (res, (f1, f2)) =>
- res ++ (f1.flip match {
- case Default => get_constraints_t(f1.tpe, f2.tpe)
- case Flip => get_constraints_t(f2.tpe, f1.tpe)
- })
- }
- case (t1: VectorType, t2: VectorType) => get_constraints_t(t1.tpe, t2.tpe)
- }
+
+ def get_constraints_t(t1: Type, t2: Type): Seq[WGeq] = (t1,t2) match {
+ case (t1: UIntType, t2: UIntType) => Seq(WGeq(t1.width, t2.width))
+ case (t1: SIntType, t2: SIntType) => Seq(WGeq(t1.width, t2.width))
+ case (ClockType, ClockType) => Nil
+ case (AsyncResetType, AsyncResetType) => Nil
+ case (FixedType(w1, p1), FixedType(w2, p2)) => Seq(WGeq(w1,w2), WGeq(p1,p2))
+ case (AnalogType(w1), AnalogType(w2)) => Seq(WGeq(w1,w2), WGeq(w2,w1))
+ case (t1: BundleType, t2: BundleType) =>
+ (t1.fields zip t2.fields foldLeft Seq[WGeq]()){case (res, (f1, f2)) =>
+ res ++ (f1.flip match {
+ case Default => get_constraints_t(f1.tpe, f2.tpe)
+ case Flip => get_constraints_t(f2.tpe, f1.tpe)
+ })
+ }
+ case (t1: VectorType, t2: VectorType) => get_constraints_t(t1.tpe, t2.tpe)
+ }
+
+ def run(c: Circuit, extra: Seq[WGeq]): Circuit = {
+ val v = ArrayBuffer[WGeq]() ++ extra
def get_constraints_e(e: Expression): Unit = {
e match {
@@ -364,10 +390,59 @@ object InferWidths extends Pass {
def reduce_var_widths_p(p: Port): Port = {
Port(p.info, p.name, p.direction, reduce_var_widths_t(p.tpe))
- }
-
+ }
+
InferTypes.run(c.copy(modules = c.modules map (_
map reduce_var_widths_p
map reduce_var_widths_s)))
}
+
+ def execute(state: CircuitState): CircuitState = {
+ val circuitName = state.circuit.main
+ val typeMap = new collection.mutable.HashMap[ReferenceTarget, Type]
+
+ def getDeclTypes(modName: String)(stmt: Statement): Unit = {
+ val pairOpt = stmt match {
+ case w: DefWire => Some(w.name -> w.tpe)
+ case r: DefRegister => Some(r.name -> r.tpe)
+ case n: DefNode => Some(n.name -> n.value.tpe)
+ case i: WDefInstance => Some(i.name -> i.tpe)
+ case m: DefMemory => Some(m.name -> MemPortUtils.memType(m))
+ case other => None
+ }
+ pairOpt.foreach { case (ref, tpe) =>
+ typeMap += (ReferenceTarget(circuitName, modName, Nil, ref, Nil) -> tpe)
+ }
+ stmt.foreachStmt(getDeclTypes(modName))
+ }
+
+ if (state.annotations.exists(_.isInstanceOf[WidthGeqConstraintAnnotation])) {
+ state.circuit.modules.foreach { mod =>
+ mod.ports.foreach { port =>
+ typeMap += (ReferenceTarget(circuitName, mod.name, Nil, port.name, Nil) -> port.tpe)
+ }
+ mod.foreachStmt(getDeclTypes(mod.name))
+ }
+ }
+
+ val extraConstraints = state.annotations.flatMap {
+ case anno: WidthGeqConstraintAnnotation if anno.loc.isLocal && anno.exp.isLocal =>
+ val locType :: expType :: Nil = Seq(anno.loc, anno.exp) map { target =>
+ val baseType = typeMap.getOrElse(target.copy(component = Seq.empty),
+ throw new Exception(s"Target below from WidthGeqConstraintAnnotation was not found\n" + target.prettyPrint()))
+ val leafType = target.componentType(baseType)
+ if (leafType.isInstanceOf[AggregateType]) {
+ throw new Exception(s"Target below is an AggregateType, which " +
+ "is not supported by WidthGeqConstraintAnnotation\n" + target.prettyPrint())
+ }
+
+ leafType
+ }
+
+ get_constraints_t(locType, expType)
+ case other => Seq.empty
+ }
+
+ state.copy(circuit = run(state.circuit, extraConstraints))
+ }
}
diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala
index 04bfb19c..7f7f6e40 100644
--- a/src/main/scala/firrtl/passes/Passes.scala
+++ b/src/main/scala/firrtl/passes/Passes.scala
@@ -33,7 +33,7 @@ trait Pass extends Transform {
// Error handling
class PassException(message: String) extends Exception(message)
-class PassExceptions(exceptions: Seq[PassException]) extends Exception("\n" + exceptions.mkString("\n"))
+class PassExceptions(val exceptions: Seq[PassException]) extends Exception("\n" + exceptions.mkString("\n"))
class Errors {
val errors = collection.mutable.ArrayBuffer[PassException]()
def append(pe: PassException) = errors.append(pe)