aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/InferWidths.scala
diff options
context:
space:
mode:
authorAlbert Chen2019-02-22 15:30:27 -0800
committermergify[bot]2019-02-22 23:30:27 +0000
commit5608aa8f42c1d69b59bee158d14fc6cef9b19a47 (patch)
tree86b7bad9c5f164d12aba9f324bde223e7ff5e9f3 /src/main/scala/firrtl/passes/InferWidths.scala
parent0ace0218d3151df2d102463dd682128a88ae7be6 (diff)
Add Width Constraints with Annotations (#956)
* refactor InferWidths to allow for extra contraints, add InferWidthsWithAnnos * add test cases * add ResolvedAnnotationPaths trait to InferWidthsWithAnnos * remove println * cleanup tests * remove extraneous constraints * use foreachStmt instead of mapStmt * remove support for aggregates * fold InferWidthsWithAnnos into InferWidths * throw exception if ref not found, check for annos before AST walk
Diffstat (limited to 'src/main/scala/firrtl/passes/InferWidths.scala')
-rw-r--r--src/main/scala/firrtl/passes/InferWidths.scala121
1 files changed, 98 insertions, 23 deletions
diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala
index 6652c1fe..06833bc0 100644
--- a/src/main/scala/firrtl/passes/InferWidths.scala
+++ b/src/main/scala/firrtl/passes/InferWidths.scala
@@ -7,12 +7,38 @@ import scala.collection.mutable.ArrayBuffer
import scala.collection.immutable.ListMap
import firrtl._
+import firrtl.annotations.{Annotation, ReferenceTarget, TargetToken}
import firrtl.ir._
import firrtl.Utils._
import firrtl.Mappers._
import firrtl.traversals.Foreachers._
-object InferWidths extends Pass {
+case class WidthGeqConstraintAnnotation(loc: ReferenceTarget, exp: ReferenceTarget) extends Annotation {
+ def update(renameMap: RenameMap): Seq[WidthGeqConstraintAnnotation] = {
+ val newLoc :: newExp :: Nil = Seq(loc, exp).map { target =>
+ renameMap.get(target) match {
+ case None => Some(target)
+ case Some(Seq()) => None
+ case Some(Seq(one)) => Some(one)
+ case Some(many) =>
+ throw new Exception(s"Target below is an AggregateType, which " +
+ "is not supported by WidthGeqConstraintAnnotation\n" + target.prettyPrint())
+ }
+ }
+
+ (newLoc, newExp) match {
+ case (Some(l: ReferenceTarget), Some(e: ReferenceTarget)) => Seq(WidthGeqConstraintAnnotation(l, e))
+ case _ => Seq.empty
+ }
+ }
+}
+
+class InferWidths extends Transform with ResolvedAnnotationPaths {
+ def inputForm: CircuitForm = UnknownForm
+ def outputForm: CircuitForm = UnknownForm
+
+ val annotationClasses = Seq(classOf[WidthGeqConstraintAnnotation])
+
type ConstraintMap = collection.mutable.LinkedHashMap[String, Width]
def solve_constraints(l: Seq[WGeq]): ConstraintMap = {
@@ -220,26 +246,26 @@ object InferWidths extends Pass {
}
b
}
-
- def run (c: Circuit): Circuit = {
- val v = ArrayBuffer[WGeq]()
-
- def get_constraints_t(t1: Type, t2: Type): Seq[WGeq] = (t1,t2) match {
- case (t1: UIntType, t2: UIntType) => Seq(WGeq(t1.width, t2.width))
- case (t1: SIntType, t2: SIntType) => Seq(WGeq(t1.width, t2.width))
- case (ClockType, ClockType) => Nil
- case (AsyncResetType, AsyncResetType) => Nil
- case (FixedType(w1, p1), FixedType(w2, p2)) => Seq(WGeq(w1,w2), WGeq(p1,p2))
- case (AnalogType(w1), AnalogType(w2)) => Seq(WGeq(w1,w2), WGeq(w2,w1))
- case (t1: BundleType, t2: BundleType) =>
- (t1.fields zip t2.fields foldLeft Seq[WGeq]()){case (res, (f1, f2)) =>
- res ++ (f1.flip match {
- case Default => get_constraints_t(f1.tpe, f2.tpe)
- case Flip => get_constraints_t(f2.tpe, f1.tpe)
- })
- }
- case (t1: VectorType, t2: VectorType) => get_constraints_t(t1.tpe, t2.tpe)
- }
+
+ def get_constraints_t(t1: Type, t2: Type): Seq[WGeq] = (t1,t2) match {
+ case (t1: UIntType, t2: UIntType) => Seq(WGeq(t1.width, t2.width))
+ case (t1: SIntType, t2: SIntType) => Seq(WGeq(t1.width, t2.width))
+ case (ClockType, ClockType) => Nil
+ case (AsyncResetType, AsyncResetType) => Nil
+ case (FixedType(w1, p1), FixedType(w2, p2)) => Seq(WGeq(w1,w2), WGeq(p1,p2))
+ case (AnalogType(w1), AnalogType(w2)) => Seq(WGeq(w1,w2), WGeq(w2,w1))
+ case (t1: BundleType, t2: BundleType) =>
+ (t1.fields zip t2.fields foldLeft Seq[WGeq]()){case (res, (f1, f2)) =>
+ res ++ (f1.flip match {
+ case Default => get_constraints_t(f1.tpe, f2.tpe)
+ case Flip => get_constraints_t(f2.tpe, f1.tpe)
+ })
+ }
+ case (t1: VectorType, t2: VectorType) => get_constraints_t(t1.tpe, t2.tpe)
+ }
+
+ def run(c: Circuit, extra: Seq[WGeq]): Circuit = {
+ val v = ArrayBuffer[WGeq]() ++ extra
def get_constraints_e(e: Expression): Unit = {
e match {
@@ -364,10 +390,59 @@ object InferWidths extends Pass {
def reduce_var_widths_p(p: Port): Port = {
Port(p.info, p.name, p.direction, reduce_var_widths_t(p.tpe))
- }
-
+ }
+
InferTypes.run(c.copy(modules = c.modules map (_
map reduce_var_widths_p
map reduce_var_widths_s)))
}
+
+ def execute(state: CircuitState): CircuitState = {
+ val circuitName = state.circuit.main
+ val typeMap = new collection.mutable.HashMap[ReferenceTarget, Type]
+
+ def getDeclTypes(modName: String)(stmt: Statement): Unit = {
+ val pairOpt = stmt match {
+ case w: DefWire => Some(w.name -> w.tpe)
+ case r: DefRegister => Some(r.name -> r.tpe)
+ case n: DefNode => Some(n.name -> n.value.tpe)
+ case i: WDefInstance => Some(i.name -> i.tpe)
+ case m: DefMemory => Some(m.name -> MemPortUtils.memType(m))
+ case other => None
+ }
+ pairOpt.foreach { case (ref, tpe) =>
+ typeMap += (ReferenceTarget(circuitName, modName, Nil, ref, Nil) -> tpe)
+ }
+ stmt.foreachStmt(getDeclTypes(modName))
+ }
+
+ if (state.annotations.exists(_.isInstanceOf[WidthGeqConstraintAnnotation])) {
+ state.circuit.modules.foreach { mod =>
+ mod.ports.foreach { port =>
+ typeMap += (ReferenceTarget(circuitName, mod.name, Nil, port.name, Nil) -> port.tpe)
+ }
+ mod.foreachStmt(getDeclTypes(mod.name))
+ }
+ }
+
+ val extraConstraints = state.annotations.flatMap {
+ case anno: WidthGeqConstraintAnnotation if anno.loc.isLocal && anno.exp.isLocal =>
+ val locType :: expType :: Nil = Seq(anno.loc, anno.exp) map { target =>
+ val baseType = typeMap.getOrElse(target.copy(component = Seq.empty),
+ throw new Exception(s"Target below from WidthGeqConstraintAnnotation was not found\n" + target.prettyPrint()))
+ val leafType = target.componentType(baseType)
+ if (leafType.isInstanceOf[AggregateType]) {
+ throw new Exception(s"Target below is an AggregateType, which " +
+ "is not supported by WidthGeqConstraintAnnotation\n" + target.prettyPrint())
+ }
+
+ leafType
+ }
+
+ get_constraints_t(locType, expType)
+ case other => Seq.empty
+ }
+
+ state.copy(circuit = run(state.circuit, extraConstraints))
+ }
}