diff options
| author | Albert Chen | 2019-02-22 15:30:27 -0800 |
|---|---|---|
| committer | mergify[bot] | 2019-02-22 23:30:27 +0000 |
| commit | 5608aa8f42c1d69b59bee158d14fc6cef9b19a47 (patch) | |
| tree | 86b7bad9c5f164d12aba9f324bde223e7ff5e9f3 /src/main | |
| parent | 0ace0218d3151df2d102463dd682128a88ae7be6 (diff) | |
Add Width Constraints with Annotations (#956)
* refactor InferWidths to allow for extra contraints, add InferWidthsWithAnnos
* add test cases
* add ResolvedAnnotationPaths trait to InferWidthsWithAnnos
* remove println
* cleanup tests
* remove extraneous constraints
* use foreachStmt instead of mapStmt
* remove support for aggregates
* fold InferWidthsWithAnnos into InferWidths
* throw exception if ref not found, check for annos before AST walk
Diffstat (limited to 'src/main')
| -rw-r--r-- | src/main/scala/firrtl/LoweringCompilers.scala | 6 | ||||
| -rw-r--r-- | src/main/scala/firrtl/annotations/Target.scala | 21 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/InferWidths.scala | 121 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Passes.scala | 2 |
4 files changed, 122 insertions, 28 deletions
diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala index 9969150d..262caeea 100644 --- a/src/main/scala/firrtl/LoweringCompilers.scala +++ b/src/main/scala/firrtl/LoweringCompilers.scala @@ -44,7 +44,7 @@ class ResolveAndCheck extends CoreTransform { passes.InferTypes, passes.ResolveGenders, passes.CheckGenders, - passes.InferWidths, + new passes.InferWidths, passes.CheckWidths) } @@ -68,7 +68,7 @@ class HighFirrtlToMiddleFirrtl extends CoreTransform { passes.InferTypes, passes.CheckTypes, passes.ResolveGenders, - passes.InferWidths, + new passes.InferWidths, passes.CheckWidths, passes.ConvertFixedToSInt, passes.ZeroWidth, @@ -87,7 +87,7 @@ class MiddleFirrtlToLowFirrtl extends CoreTransform { passes.ResolveKinds, passes.InferTypes, passes.ResolveGenders, - passes.InferWidths, + new passes.InferWidths, passes.Legalize, new firrtl.transforms.RemoveReset, new firrtl.transforms.CheckCombLoops, diff --git a/src/main/scala/firrtl/annotations/Target.scala b/src/main/scala/firrtl/annotations/Target.scala index 8a9d68e8..0247b66c 100644 --- a/src/main/scala/firrtl/annotations/Target.scala +++ b/src/main/scala/firrtl/annotations/Target.scala @@ -3,7 +3,8 @@ package firrtl package annotations -import firrtl.ir.Expression +import firrtl.ir.{Expression, Type} +import firrtl.Utils.{sub_type, field_type} import AnnotationUtils.{toExp, validComponentName, validModuleName} import TargetToken._ @@ -553,6 +554,24 @@ case class ReferenceTarget(circuit: String, /** @return The clock signal of this reference, must be to a [[firrtl.ir.DefRegister]] */ def clock: ReferenceTarget = ReferenceTarget(circuit, module, path, ref, component :+ Clock) + /** @param the type of this target's ref + * @return the type of the subcomponent specified by this target's component + */ + def componentType(baseType: Type): Type = componentType(baseType, tokens) + + private def componentType(baseType: Type, tokens: Seq[TargetToken]): Type = { + if (tokens.isEmpty) { + baseType + } else { + val headType = tokens.head match { + case Index(idx) => sub_type(baseType) + case Field(field) => field_type(baseType, field) + case _: Ref => baseType + } + componentType(headType, tokens.tail) + } + } + override def circuitOpt: Option[String] = Some(circuit) override def moduleOpt: Option[String] = Some(module) diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala index 6652c1fe..06833bc0 100644 --- a/src/main/scala/firrtl/passes/InferWidths.scala +++ b/src/main/scala/firrtl/passes/InferWidths.scala @@ -7,12 +7,38 @@ import scala.collection.mutable.ArrayBuffer import scala.collection.immutable.ListMap import firrtl._ +import firrtl.annotations.{Annotation, ReferenceTarget, TargetToken} import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ import firrtl.traversals.Foreachers._ -object InferWidths extends Pass { +case class WidthGeqConstraintAnnotation(loc: ReferenceTarget, exp: ReferenceTarget) extends Annotation { + def update(renameMap: RenameMap): Seq[WidthGeqConstraintAnnotation] = { + val newLoc :: newExp :: Nil = Seq(loc, exp).map { target => + renameMap.get(target) match { + case None => Some(target) + case Some(Seq()) => None + case Some(Seq(one)) => Some(one) + case Some(many) => + throw new Exception(s"Target below is an AggregateType, which " + + "is not supported by WidthGeqConstraintAnnotation\n" + target.prettyPrint()) + } + } + + (newLoc, newExp) match { + case (Some(l: ReferenceTarget), Some(e: ReferenceTarget)) => Seq(WidthGeqConstraintAnnotation(l, e)) + case _ => Seq.empty + } + } +} + +class InferWidths extends Transform with ResolvedAnnotationPaths { + def inputForm: CircuitForm = UnknownForm + def outputForm: CircuitForm = UnknownForm + + val annotationClasses = Seq(classOf[WidthGeqConstraintAnnotation]) + type ConstraintMap = collection.mutable.LinkedHashMap[String, Width] def solve_constraints(l: Seq[WGeq]): ConstraintMap = { @@ -220,26 +246,26 @@ object InferWidths extends Pass { } b } - - def run (c: Circuit): Circuit = { - val v = ArrayBuffer[WGeq]() - - def get_constraints_t(t1: Type, t2: Type): Seq[WGeq] = (t1,t2) match { - case (t1: UIntType, t2: UIntType) => Seq(WGeq(t1.width, t2.width)) - case (t1: SIntType, t2: SIntType) => Seq(WGeq(t1.width, t2.width)) - case (ClockType, ClockType) => Nil - case (AsyncResetType, AsyncResetType) => Nil - case (FixedType(w1, p1), FixedType(w2, p2)) => Seq(WGeq(w1,w2), WGeq(p1,p2)) - case (AnalogType(w1), AnalogType(w2)) => Seq(WGeq(w1,w2), WGeq(w2,w1)) - case (t1: BundleType, t2: BundleType) => - (t1.fields zip t2.fields foldLeft Seq[WGeq]()){case (res, (f1, f2)) => - res ++ (f1.flip match { - case Default => get_constraints_t(f1.tpe, f2.tpe) - case Flip => get_constraints_t(f2.tpe, f1.tpe) - }) - } - case (t1: VectorType, t2: VectorType) => get_constraints_t(t1.tpe, t2.tpe) - } + + def get_constraints_t(t1: Type, t2: Type): Seq[WGeq] = (t1,t2) match { + case (t1: UIntType, t2: UIntType) => Seq(WGeq(t1.width, t2.width)) + case (t1: SIntType, t2: SIntType) => Seq(WGeq(t1.width, t2.width)) + case (ClockType, ClockType) => Nil + case (AsyncResetType, AsyncResetType) => Nil + case (FixedType(w1, p1), FixedType(w2, p2)) => Seq(WGeq(w1,w2), WGeq(p1,p2)) + case (AnalogType(w1), AnalogType(w2)) => Seq(WGeq(w1,w2), WGeq(w2,w1)) + case (t1: BundleType, t2: BundleType) => + (t1.fields zip t2.fields foldLeft Seq[WGeq]()){case (res, (f1, f2)) => + res ++ (f1.flip match { + case Default => get_constraints_t(f1.tpe, f2.tpe) + case Flip => get_constraints_t(f2.tpe, f1.tpe) + }) + } + case (t1: VectorType, t2: VectorType) => get_constraints_t(t1.tpe, t2.tpe) + } + + def run(c: Circuit, extra: Seq[WGeq]): Circuit = { + val v = ArrayBuffer[WGeq]() ++ extra def get_constraints_e(e: Expression): Unit = { e match { @@ -364,10 +390,59 @@ object InferWidths extends Pass { def reduce_var_widths_p(p: Port): Port = { Port(p.info, p.name, p.direction, reduce_var_widths_t(p.tpe)) - } - + } + InferTypes.run(c.copy(modules = c.modules map (_ map reduce_var_widths_p map reduce_var_widths_s))) } + + def execute(state: CircuitState): CircuitState = { + val circuitName = state.circuit.main + val typeMap = new collection.mutable.HashMap[ReferenceTarget, Type] + + def getDeclTypes(modName: String)(stmt: Statement): Unit = { + val pairOpt = stmt match { + case w: DefWire => Some(w.name -> w.tpe) + case r: DefRegister => Some(r.name -> r.tpe) + case n: DefNode => Some(n.name -> n.value.tpe) + case i: WDefInstance => Some(i.name -> i.tpe) + case m: DefMemory => Some(m.name -> MemPortUtils.memType(m)) + case other => None + } + pairOpt.foreach { case (ref, tpe) => + typeMap += (ReferenceTarget(circuitName, modName, Nil, ref, Nil) -> tpe) + } + stmt.foreachStmt(getDeclTypes(modName)) + } + + if (state.annotations.exists(_.isInstanceOf[WidthGeqConstraintAnnotation])) { + state.circuit.modules.foreach { mod => + mod.ports.foreach { port => + typeMap += (ReferenceTarget(circuitName, mod.name, Nil, port.name, Nil) -> port.tpe) + } + mod.foreachStmt(getDeclTypes(mod.name)) + } + } + + val extraConstraints = state.annotations.flatMap { + case anno: WidthGeqConstraintAnnotation if anno.loc.isLocal && anno.exp.isLocal => + val locType :: expType :: Nil = Seq(anno.loc, anno.exp) map { target => + val baseType = typeMap.getOrElse(target.copy(component = Seq.empty), + throw new Exception(s"Target below from WidthGeqConstraintAnnotation was not found\n" + target.prettyPrint())) + val leafType = target.componentType(baseType) + if (leafType.isInstanceOf[AggregateType]) { + throw new Exception(s"Target below is an AggregateType, which " + + "is not supported by WidthGeqConstraintAnnotation\n" + target.prettyPrint()) + } + + leafType + } + + get_constraints_t(locType, expType) + case other => Seq.empty + } + + state.copy(circuit = run(state.circuit, extraConstraints)) + } } diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala index 04bfb19c..7f7f6e40 100644 --- a/src/main/scala/firrtl/passes/Passes.scala +++ b/src/main/scala/firrtl/passes/Passes.scala @@ -33,7 +33,7 @@ trait Pass extends Transform { // Error handling class PassException(message: String) extends Exception(message) -class PassExceptions(exceptions: Seq[PassException]) extends Exception("\n" + exceptions.mkString("\n")) +class PassExceptions(val exceptions: Seq[PassException]) extends Exception("\n" + exceptions.mkString("\n")) class Errors { val errors = collection.mutable.ArrayBuffer[PassException]() def append(pe: PassException) = errors.append(pe) |
