diff options
| author | Albert Chen | 2019-02-22 15:30:27 -0800 |
|---|---|---|
| committer | mergify[bot] | 2019-02-22 23:30:27 +0000 |
| commit | 5608aa8f42c1d69b59bee158d14fc6cef9b19a47 (patch) | |
| tree | 86b7bad9c5f164d12aba9f324bde223e7ff5e9f3 /src | |
| parent | 0ace0218d3151df2d102463dd682128a88ae7be6 (diff) | |
Add Width Constraints with Annotations (#956)
* refactor InferWidths to allow for extra contraints, add InferWidthsWithAnnos
* add test cases
* add ResolvedAnnotationPaths trait to InferWidthsWithAnnos
* remove println
* cleanup tests
* remove extraneous constraints
* use foreachStmt instead of mapStmt
* remove support for aggregates
* fold InferWidthsWithAnnos into InferWidths
* throw exception if ref not found, check for annos before AST walk
Diffstat (limited to 'src')
20 files changed, 478 insertions, 198 deletions
diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala index 9969150d..262caeea 100644 --- a/src/main/scala/firrtl/LoweringCompilers.scala +++ b/src/main/scala/firrtl/LoweringCompilers.scala @@ -44,7 +44,7 @@ class ResolveAndCheck extends CoreTransform { passes.InferTypes, passes.ResolveGenders, passes.CheckGenders, - passes.InferWidths, + new passes.InferWidths, passes.CheckWidths) } @@ -68,7 +68,7 @@ class HighFirrtlToMiddleFirrtl extends CoreTransform { passes.InferTypes, passes.CheckTypes, passes.ResolveGenders, - passes.InferWidths, + new passes.InferWidths, passes.CheckWidths, passes.ConvertFixedToSInt, passes.ZeroWidth, @@ -87,7 +87,7 @@ class MiddleFirrtlToLowFirrtl extends CoreTransform { passes.ResolveKinds, passes.InferTypes, passes.ResolveGenders, - passes.InferWidths, + new passes.InferWidths, passes.Legalize, new firrtl.transforms.RemoveReset, new firrtl.transforms.CheckCombLoops, diff --git a/src/main/scala/firrtl/annotations/Target.scala b/src/main/scala/firrtl/annotations/Target.scala index 8a9d68e8..0247b66c 100644 --- a/src/main/scala/firrtl/annotations/Target.scala +++ b/src/main/scala/firrtl/annotations/Target.scala @@ -3,7 +3,8 @@ package firrtl package annotations -import firrtl.ir.Expression +import firrtl.ir.{Expression, Type} +import firrtl.Utils.{sub_type, field_type} import AnnotationUtils.{toExp, validComponentName, validModuleName} import TargetToken._ @@ -553,6 +554,24 @@ case class ReferenceTarget(circuit: String, /** @return The clock signal of this reference, must be to a [[firrtl.ir.DefRegister]] */ def clock: ReferenceTarget = ReferenceTarget(circuit, module, path, ref, component :+ Clock) + /** @param the type of this target's ref + * @return the type of the subcomponent specified by this target's component + */ + def componentType(baseType: Type): Type = componentType(baseType, tokens) + + private def componentType(baseType: Type, tokens: Seq[TargetToken]): Type = { + if (tokens.isEmpty) { + baseType + } else { + val headType = tokens.head match { + case Index(idx) => sub_type(baseType) + case Field(field) => field_type(baseType, field) + case _: Ref => baseType + } + componentType(headType, tokens.tail) + } + } + override def circuitOpt: Option[String] = Some(circuit) override def moduleOpt: Option[String] = Some(module) diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala index 6652c1fe..06833bc0 100644 --- a/src/main/scala/firrtl/passes/InferWidths.scala +++ b/src/main/scala/firrtl/passes/InferWidths.scala @@ -7,12 +7,38 @@ import scala.collection.mutable.ArrayBuffer import scala.collection.immutable.ListMap import firrtl._ +import firrtl.annotations.{Annotation, ReferenceTarget, TargetToken} import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ import firrtl.traversals.Foreachers._ -object InferWidths extends Pass { +case class WidthGeqConstraintAnnotation(loc: ReferenceTarget, exp: ReferenceTarget) extends Annotation { + def update(renameMap: RenameMap): Seq[WidthGeqConstraintAnnotation] = { + val newLoc :: newExp :: Nil = Seq(loc, exp).map { target => + renameMap.get(target) match { + case None => Some(target) + case Some(Seq()) => None + case Some(Seq(one)) => Some(one) + case Some(many) => + throw new Exception(s"Target below is an AggregateType, which " + + "is not supported by WidthGeqConstraintAnnotation\n" + target.prettyPrint()) + } + } + + (newLoc, newExp) match { + case (Some(l: ReferenceTarget), Some(e: ReferenceTarget)) => Seq(WidthGeqConstraintAnnotation(l, e)) + case _ => Seq.empty + } + } +} + +class InferWidths extends Transform with ResolvedAnnotationPaths { + def inputForm: CircuitForm = UnknownForm + def outputForm: CircuitForm = UnknownForm + + val annotationClasses = Seq(classOf[WidthGeqConstraintAnnotation]) + type ConstraintMap = collection.mutable.LinkedHashMap[String, Width] def solve_constraints(l: Seq[WGeq]): ConstraintMap = { @@ -220,26 +246,26 @@ object InferWidths extends Pass { } b } - - def run (c: Circuit): Circuit = { - val v = ArrayBuffer[WGeq]() - - def get_constraints_t(t1: Type, t2: Type): Seq[WGeq] = (t1,t2) match { - case (t1: UIntType, t2: UIntType) => Seq(WGeq(t1.width, t2.width)) - case (t1: SIntType, t2: SIntType) => Seq(WGeq(t1.width, t2.width)) - case (ClockType, ClockType) => Nil - case (AsyncResetType, AsyncResetType) => Nil - case (FixedType(w1, p1), FixedType(w2, p2)) => Seq(WGeq(w1,w2), WGeq(p1,p2)) - case (AnalogType(w1), AnalogType(w2)) => Seq(WGeq(w1,w2), WGeq(w2,w1)) - case (t1: BundleType, t2: BundleType) => - (t1.fields zip t2.fields foldLeft Seq[WGeq]()){case (res, (f1, f2)) => - res ++ (f1.flip match { - case Default => get_constraints_t(f1.tpe, f2.tpe) - case Flip => get_constraints_t(f2.tpe, f1.tpe) - }) - } - case (t1: VectorType, t2: VectorType) => get_constraints_t(t1.tpe, t2.tpe) - } + + def get_constraints_t(t1: Type, t2: Type): Seq[WGeq] = (t1,t2) match { + case (t1: UIntType, t2: UIntType) => Seq(WGeq(t1.width, t2.width)) + case (t1: SIntType, t2: SIntType) => Seq(WGeq(t1.width, t2.width)) + case (ClockType, ClockType) => Nil + case (AsyncResetType, AsyncResetType) => Nil + case (FixedType(w1, p1), FixedType(w2, p2)) => Seq(WGeq(w1,w2), WGeq(p1,p2)) + case (AnalogType(w1), AnalogType(w2)) => Seq(WGeq(w1,w2), WGeq(w2,w1)) + case (t1: BundleType, t2: BundleType) => + (t1.fields zip t2.fields foldLeft Seq[WGeq]()){case (res, (f1, f2)) => + res ++ (f1.flip match { + case Default => get_constraints_t(f1.tpe, f2.tpe) + case Flip => get_constraints_t(f2.tpe, f1.tpe) + }) + } + case (t1: VectorType, t2: VectorType) => get_constraints_t(t1.tpe, t2.tpe) + } + + def run(c: Circuit, extra: Seq[WGeq]): Circuit = { + val v = ArrayBuffer[WGeq]() ++ extra def get_constraints_e(e: Expression): Unit = { e match { @@ -364,10 +390,59 @@ object InferWidths extends Pass { def reduce_var_widths_p(p: Port): Port = { Port(p.info, p.name, p.direction, reduce_var_widths_t(p.tpe)) - } - + } + InferTypes.run(c.copy(modules = c.modules map (_ map reduce_var_widths_p map reduce_var_widths_s))) } + + def execute(state: CircuitState): CircuitState = { + val circuitName = state.circuit.main + val typeMap = new collection.mutable.HashMap[ReferenceTarget, Type] + + def getDeclTypes(modName: String)(stmt: Statement): Unit = { + val pairOpt = stmt match { + case w: DefWire => Some(w.name -> w.tpe) + case r: DefRegister => Some(r.name -> r.tpe) + case n: DefNode => Some(n.name -> n.value.tpe) + case i: WDefInstance => Some(i.name -> i.tpe) + case m: DefMemory => Some(m.name -> MemPortUtils.memType(m)) + case other => None + } + pairOpt.foreach { case (ref, tpe) => + typeMap += (ReferenceTarget(circuitName, modName, Nil, ref, Nil) -> tpe) + } + stmt.foreachStmt(getDeclTypes(modName)) + } + + if (state.annotations.exists(_.isInstanceOf[WidthGeqConstraintAnnotation])) { + state.circuit.modules.foreach { mod => + mod.ports.foreach { port => + typeMap += (ReferenceTarget(circuitName, mod.name, Nil, port.name, Nil) -> port.tpe) + } + mod.foreachStmt(getDeclTypes(mod.name)) + } + } + + val extraConstraints = state.annotations.flatMap { + case anno: WidthGeqConstraintAnnotation if anno.loc.isLocal && anno.exp.isLocal => + val locType :: expType :: Nil = Seq(anno.loc, anno.exp) map { target => + val baseType = typeMap.getOrElse(target.copy(component = Seq.empty), + throw new Exception(s"Target below from WidthGeqConstraintAnnotation was not found\n" + target.prettyPrint())) + val leafType = target.componentType(baseType) + if (leafType.isInstanceOf[AggregateType]) { + throw new Exception(s"Target below is an AggregateType, which " + + "is not supported by WidthGeqConstraintAnnotation\n" + target.prettyPrint()) + } + + leafType + } + + get_constraints_t(locType, expType) + case other => Seq.empty + } + + state.copy(circuit = run(state.circuit, extraConstraints)) + } } diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala index 04bfb19c..7f7f6e40 100644 --- a/src/main/scala/firrtl/passes/Passes.scala +++ b/src/main/scala/firrtl/passes/Passes.scala @@ -33,7 +33,7 @@ trait Pass extends Transform { // Error handling class PassException(message: String) extends Exception(message) -class PassExceptions(exceptions: Seq[PassException]) extends Exception("\n" + exceptions.mkString("\n")) +class PassExceptions(val exceptions: Seq[PassException]) extends Exception("\n" + exceptions.mkString("\n")) class Errors { val errors = collection.mutable.ArrayBuffer[PassException]() def append(pe: PassException) = errors.append(pe) diff --git a/src/test/scala/firrtlTests/AttachSpec.scala b/src/test/scala/firrtlTests/AttachSpec.scala index 9bf5fefd..1acb0d8b 100644 --- a/src/test/scala/firrtlTests/AttachSpec.scala +++ b/src/test/scala/firrtlTests/AttachSpec.scala @@ -411,7 +411,7 @@ class AttachAnalogSpec extends FirrtlFlatSpec { ResolveKinds, InferTypes, CheckTypes, - InferWidths, + new InferWidths, CheckWidths) val input = """circuit Unit : @@ -422,8 +422,8 @@ class AttachAnalogSpec extends FirrtlFlatSpec { | extmodule A : | output o: Analog<2> """.stripMargin intercept[CheckWidths.AttachWidthsNotEqual] { - passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) + passes.foldLeft(CircuitState(parse(input), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) } } } diff --git a/src/test/scala/firrtlTests/CheckInitializationSpec.scala b/src/test/scala/firrtlTests/CheckInitializationSpec.scala index ef966ca0..9ccff256 100644 --- a/src/test/scala/firrtlTests/CheckInitializationSpec.scala +++ b/src/test/scala/firrtlTests/CheckInitializationSpec.scala @@ -5,8 +5,7 @@ package firrtlTests import java.io._ import org.scalatest._ import org.scalatest.prop._ -import firrtl.Parser -import firrtl.ir.Circuit +import firrtl.{Parser, CircuitState, UnknownForm, Transform} import firrtl.Parser.IgnoreInfo import firrtl.passes._ @@ -19,7 +18,7 @@ class CheckInitializationSpec extends FirrtlFlatSpec { CheckTypes, ResolveGenders, CheckGenders, - InferWidths, + new InferWidths, CheckWidths, PullMuxes, ExpandConnects, @@ -35,8 +34,8 @@ class CheckInitializationSpec extends FirrtlFlatSpec { | when p : | x <= UInt(1)""".stripMargin intercept[CheckInitialization.RefNotInitializedException] { - passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) + passes.foldLeft(CircuitState(parse(input), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) } } } @@ -50,8 +49,8 @@ class CheckInitializationSpec extends FirrtlFlatSpec { | else : | x <= UInt(1)""".stripMargin intercept[CheckInitialization.RefNotInitializedException] { - passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) + passes.foldLeft(CircuitState(parse(input), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) } } } @@ -66,8 +65,8 @@ class CheckInitializationSpec extends FirrtlFlatSpec { | when p : | c.in <= UInt(1)""".stripMargin intercept[CheckInitialization.RefNotInitializedException] { - passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) + passes.foldLeft(CircuitState(parse(input), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) } } } diff --git a/src/test/scala/firrtlTests/CheckSpec.scala b/src/test/scala/firrtlTests/CheckSpec.scala index 767f2392..3e6b19f9 100644 --- a/src/test/scala/firrtlTests/CheckSpec.scala +++ b/src/test/scala/firrtlTests/CheckSpec.scala @@ -5,7 +5,7 @@ package firrtlTests import java.io._ import org.scalatest._ import org.scalatest.prop._ -import firrtl.Parser +import firrtl.{Parser, CircuitState, UnknownForm, Transform} import firrtl.ir.Circuit import firrtl.passes.{Pass,ToWorkingIR,CheckHighForm,ResolveKinds,InferTypes,CheckTypes,PassException,InferWidths,CheckWidths,ResolveGenders,CheckGenders} @@ -153,7 +153,7 @@ class CheckSpec extends FlatSpec with Matchers { CheckTypes, ResolveGenders, CheckGenders, - InferWidths, + new InferWidths, CheckWidths) val input = """ @@ -180,8 +180,8 @@ class CheckSpec extends FlatSpec with Matchers { | sub.io.debug_clk <= io.jtag.TCK | |""".stripMargin - passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Pass) => p.run(c) + passes.foldLeft(CircuitState(Parser.parse(input.split("\n").toIterator), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) } } diff --git a/src/test/scala/firrtlTests/ChirrtlSpec.scala b/src/test/scala/firrtlTests/ChirrtlSpec.scala index 774c352b..9344b861 100644 --- a/src/test/scala/firrtlTests/ChirrtlSpec.scala +++ b/src/test/scala/firrtlTests/ChirrtlSpec.scala @@ -5,7 +5,7 @@ package firrtlTests import java.io._ import org.scalatest._ import org.scalatest.prop._ -import firrtl.Parser +import firrtl.{Parser, CircuitState, UnknownForm, Transform} import firrtl.ir.Circuit import firrtl.passes._ import firrtl._ @@ -23,7 +23,7 @@ class ChirrtlSpec extends FirrtlFlatSpec { CheckTypes, ResolveGenders, CheckGenders, - InferWidths, + new InferWidths, CheckWidths, PullMuxes, ExpandConnects, diff --git a/src/test/scala/firrtlTests/ClockListTests.scala b/src/test/scala/firrtlTests/ClockListTests.scala index 48d6dfd3..dde719d5 100644 --- a/src/test/scala/firrtlTests/ClockListTests.scala +++ b/src/test/scala/firrtlTests/ClockListTests.scala @@ -25,7 +25,7 @@ class ClockListTests extends FirrtlFlatSpec { ResolveKinds, InferTypes, ResolveGenders, - InferWidths + new InferWidths ) "Getting clock list" should "work" in { @@ -75,9 +75,9 @@ class ClockListTests extends FirrtlFlatSpec { |Good Origin of h$b.clock is h$clkGen.clk2 |Good Origin of h$c.clock is h$clkGen.clk3 |""".stripMargin - val c = passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) - } + val c = passes.foldLeft(CircuitState(parse(input), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) + }.circuit val writer = new StringWriter() val retC = new ClockList("HTop", writer).run(c) (writer.toString) should be (check) @@ -106,9 +106,9 @@ class ClockListTests extends FirrtlFlatSpec { |Good Origin of b.clock is clkB |Good Origin of b$c.clock is clock |""".stripMargin - val c = passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) - } + val c = passes.foldLeft(CircuitState(parse(input), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) + }.circuit val writer = new StringWriter() val retC = new ClockList("A", writer).run(c) (writer.toString) should be (check) @@ -139,9 +139,9 @@ class ClockListTests extends FirrtlFlatSpec { |Good Origin of clock is clock |Good Origin of c.clock is clkC |""".stripMargin - val c = passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) - } + val c = passes.foldLeft(CircuitState(parse(input), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) + }.circuit val writer = new StringWriter() val retC = new ClockList("B", writer).run(c) (writer.toString) should be (check) diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala index 6fb2ab8d..da3e9b41 100644 --- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala +++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala @@ -14,7 +14,7 @@ class ConstantPropagationSpec extends FirrtlFlatSpec { ResolveKinds, InferTypes, ResolveGenders, - InferWidths, + new InferWidths, new ConstantPropagation) protected def exec(input: String) = { transforms.foldLeft(CircuitState(parse(input), UnknownForm)) { diff --git a/src/test/scala/firrtlTests/ExpandWhensSpec.scala b/src/test/scala/firrtlTests/ExpandWhensSpec.scala index 3532ce00..a1ac8a31 100644 --- a/src/test/scala/firrtlTests/ExpandWhensSpec.scala +++ b/src/test/scala/firrtlTests/ExpandWhensSpec.scala @@ -22,7 +22,7 @@ class ExpandWhensSpec extends FirrtlFlatSpec { InferTypes, ResolveGenders, CheckGenders, - InferWidths, + new InferWidths, CheckWidths, PullMuxes, ExpandConnects, diff --git a/src/test/scala/firrtlTests/LowerTypesSpec.scala b/src/test/scala/firrtlTests/LowerTypesSpec.scala index ab367554..27f2c8a0 100644 --- a/src/test/scala/firrtlTests/LowerTypesSpec.scala +++ b/src/test/scala/firrtlTests/LowerTypesSpec.scala @@ -20,7 +20,7 @@ class LowerTypesSpec extends FirrtlFlatSpec { CheckTypes, ResolveGenders, CheckGenders, - InferWidths, + new InferWidths, CheckWidths, PullMuxes, ExpandConnects, @@ -32,7 +32,7 @@ class LowerTypesSpec extends FirrtlFlatSpec { ResolveKinds, InferTypes, ResolveGenders, - InferWidths, + new InferWidths, LowerTypes) private def executeTest(input: String, expected: Seq[String]) = { diff --git a/src/test/scala/firrtlTests/ReplaceAccessesSpec.scala b/src/test/scala/firrtlTests/ReplaceAccessesSpec.scala index e507e947..64977c7f 100644 --- a/src/test/scala/firrtlTests/ReplaceAccessesSpec.scala +++ b/src/test/scala/firrtlTests/ReplaceAccessesSpec.scala @@ -14,7 +14,7 @@ class ReplaceAccessesSpec extends FirrtlFlatSpec { ResolveKinds, InferTypes, ResolveGenders, - InferWidths, + new InferWidths, ReplaceAccesses) protected def exec(input: String) = { transforms.foldLeft(CircuitState(parse(input), UnknownForm)) { diff --git a/src/test/scala/firrtlTests/UnitTests.scala b/src/test/scala/firrtlTests/UnitTests.scala index b3af920c..0ef4f709 100644 --- a/src/test/scala/firrtlTests/UnitTests.scala +++ b/src/test/scala/firrtlTests/UnitTests.scala @@ -158,7 +158,7 @@ class UnitTests extends FirrtlFlatSpec { ResolveKinds, InferTypes, ResolveGenders, - InferWidths, + new InferWidths, SplitExpressions ) val input = @@ -182,7 +182,7 @@ class UnitTests extends FirrtlFlatSpec { ResolveKinds, InferTypes, ResolveGenders, - InferWidths, + new InferWidths, PadWidths ) val input = @@ -203,7 +203,7 @@ class UnitTests extends FirrtlFlatSpec { ResolveKinds, InferTypes, ResolveGenders, - InferWidths, + new InferWidths, PullMuxes, ExpandConnects, RemoveAccesses, @@ -243,15 +243,15 @@ class UnitTests extends FirrtlFlatSpec { ResolveKinds, InferTypes, ResolveGenders, - InferWidths, + new InferWidths, CheckWidths) val input = """circuit Unit : | module Unit : | node x = bits(UInt(1), 100, 0)""".stripMargin intercept[CheckWidths.BitsWidthException] { - passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Pass) => p.run(c) + passes.foldLeft(CircuitState(Parser.parse(input.split("\n").toIterator), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) } } } @@ -262,15 +262,15 @@ class UnitTests extends FirrtlFlatSpec { ResolveKinds, InferTypes, ResolveGenders, - InferWidths, + new InferWidths, CheckWidths) val input = """circuit Unit : | module Unit : | node x = head(UInt(1), 100)""".stripMargin intercept[CheckWidths.HeadWidthException] { - passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Pass) => p.run(c) + passes.foldLeft(CircuitState(Parser.parse(input.split("\n").toIterator), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) } } } @@ -281,15 +281,15 @@ class UnitTests extends FirrtlFlatSpec { ResolveKinds, InferTypes, ResolveGenders, - InferWidths, + new InferWidths, CheckWidths) val input = """circuit Unit : | module Unit : | node x = tail(UInt(1), 100)""".stripMargin intercept[CheckWidths.TailWidthException] { - passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Pass) => p.run(c) + passes.foldLeft(CircuitState(Parser.parse(input.split("\n").toIterator), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) } } } @@ -308,8 +308,8 @@ class UnitTests extends FirrtlFlatSpec { | bar <- foo |""".stripMargin intercept[PassException] { - passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Pass) => p.run(c) + passes.foldLeft(CircuitState(Parser.parse(input.split("\n").toIterator), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) } } } @@ -389,7 +389,7 @@ class UnitTests extends FirrtlFlatSpec { ResolveKinds, InferTypes, ResolveGenders, - InferWidths, + new InferWidths, PullMuxes, ExpandConnects, RemoveAccesses, diff --git a/src/test/scala/firrtlTests/WidthSpec.scala b/src/test/scala/firrtlTests/WidthSpec.scala index 058cc1fa..96bd249c 100644 --- a/src/test/scala/firrtlTests/WidthSpec.scala +++ b/src/test/scala/firrtlTests/WidthSpec.scala @@ -11,10 +11,10 @@ import firrtl.passes._ import firrtl.Parser.IgnoreInfo class WidthSpec extends FirrtlFlatSpec { - private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = { - val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Pass) => p.run(c) - } + private def executeTest(input: String, expected: Seq[String], passes: Seq[Transform]) = { + val c = passes.foldLeft(CircuitState(Parser.parse(input.split("\n").toIterator), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) + }.circuit val lines = c.serialize.split("\n") map normalized expected foreach { e => @@ -54,7 +54,7 @@ class WidthSpec extends FirrtlFlatSpec { InferTypes, CheckTypes, ResolveGenders, - InferWidths, + new InferWidths, CheckWidths) val input = """circuit Unit : @@ -77,7 +77,7 @@ class WidthSpec extends FirrtlFlatSpec { InferTypes, CheckTypes, ResolveGenders, - InferWidths, + new InferWidths, CheckWidths) val input = s"""circuit Unit : @@ -96,7 +96,7 @@ class WidthSpec extends FirrtlFlatSpec { InferTypes, CheckTypes, ResolveGenders, - InferWidths, + new InferWidths, CheckWidths) val input = """circuit Unit : @@ -121,7 +121,7 @@ class WidthSpec extends FirrtlFlatSpec { InferTypes, CheckTypes, ResolveGenders, - InferWidths) + new InferWidths) val input = """circuit Unit : | module Unit : @@ -143,7 +143,7 @@ class WidthSpec extends FirrtlFlatSpec { InferTypes, CheckTypes, ResolveGenders, - InferWidths) + new InferWidths) val input = """circuit Unit : | module Unit : @@ -166,7 +166,7 @@ class WidthSpec extends FirrtlFlatSpec { InferTypes, CheckTypes, ResolveGenders, - InferWidths, + new InferWidths, CheckWidths) val input = """|circuit Foo : diff --git a/src/test/scala/firrtlTests/WiringTests.scala b/src/test/scala/firrtlTests/WiringTests.scala index 4f8fd9fe..4fe4a46c 100644 --- a/src/test/scala/firrtlTests/WiringTests.scala +++ b/src/test/scala/firrtlTests/WiringTests.scala @@ -14,15 +14,19 @@ import wiring.WiringUtils._ import wiring._ class WiringTests extends FirrtlFlatSpec { - private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = { - val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Pass) => p.run(c) - } - val lines = c.serialize.split("\n") map normalized + private def executeTest(input: String, + expected: String, + passes: Seq[Transform], + annos: Seq[Annotation]): Unit = { + val c = passes.foldLeft(CircuitState(Parser.parse(input.split("\n").toIterator), UnknownForm, annos)) { + (c: CircuitState, p: Transform) => p.runTransform(c) + }.circuit - expected foreach { e => - lines should contain(e) - } + (parse(c.serialize).serialize) should be (parse(expected).serialize) + } + + private def executeTest(input: String, expected: String, passes: Seq[Transform]): Unit = { + executeTest(input, expected, passes, Seq.empty) } def passes = Seq( @@ -30,7 +34,7 @@ class WiringTests extends FirrtlFlatSpec { ResolveKinds, InferTypes, ResolveGenders, - InferWidths + new InferWidths ) it should "wire from a register source (r) to multiple extmodule sinks (X)" in { @@ -114,12 +118,9 @@ class WiringTests extends FirrtlFlatSpec { | input clock: Clock | input pin: UInt<5> |""".stripMargin - val c = passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) - } + val wiringPass = new Wiring(Seq(sas)) - val retC = wiringPass.run(c) - (parse(retC.serialize).serialize) should be (parse(check).serialize) + executeTest(input, check, passes :+ wiringPass) } it should "wire from a register source (r) to multiple module sinks (X)" in { @@ -203,12 +204,9 @@ class WiringTests extends FirrtlFlatSpec { | input clock: Clock | input pin: UInt<5> |""".stripMargin - val c = passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) - } + val wiringPass = new Wiring(Seq(sas)) - val retC = wiringPass.run(c) - (parse(retC.serialize).serialize) should be (parse(check).serialize) + executeTest(input, check, passes :+ wiringPass) } it should "wire from a register sink (r) to a wire source (s) in another module (X)" in { @@ -295,12 +293,9 @@ class WiringTests extends FirrtlFlatSpec { | wire s: UInt<5> | s <= pin |""".stripMargin - val c = passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) - } + val wiringPass = new Wiring(Seq(sas)) - val retC = wiringPass.run(c) - (parse(retC.serialize).serialize) should be (parse(check).serialize) + executeTest(input, check, passes :+ wiringPass) } it should "wire from a SubField source (r.x) to an extmodule sink (X)" in { @@ -339,12 +334,9 @@ class WiringTests extends FirrtlFlatSpec { | input clock: Clock | input pin: UInt<5> |""".stripMargin - val c = passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) - } + val wiringPass = new Wiring(Seq(sas)) - val retC = wiringPass.run(c) - (parse(retC.serialize).serialize) should be (parse(check).serialize) + executeTest(input, check, passes :+ wiringPass) } it should "wire properly with a source as a submodule of a sink" in { @@ -386,12 +378,9 @@ class WiringTests extends FirrtlFlatSpec { | reg r: UInt<5>, clock | r_0 <= r |""".stripMargin - val c = passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) - } + val wiringPass = new Wiring(Seq(sas)) - val retC = wiringPass.run(c) - (parse(retC.serialize).serialize) should be (parse(check).serialize) + executeTest(input, check, passes :+ wiringPass) } it should "wire with source and sink in the same module" in { @@ -415,12 +404,9 @@ class WiringTests extends FirrtlFlatSpec { | s <= pin | pin <= r |""".stripMargin - val c = passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) - } + val wiringPass = new Wiring(Seq(sas)) - val retC = wiringPass.run(c) - (parse(retC.serialize).serialize) should be (parse(check).serialize) + executeTest(input, check, passes :+ wiringPass) } it should "wire multiple sinks in the same module" in { @@ -456,12 +442,9 @@ class WiringTests extends FirrtlFlatSpec { | s <= pin | pin <= r |""".stripMargin - val c = passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) - } + val wiringPass = new Wiring(Seq(sas)) - val retC = wiringPass.run(c) - (parse(retC.serialize).serialize) should be (parse(check).serialize) + executeTest(input, check, passes :+ wiringPass) } it should "wire clocks" in { @@ -498,12 +481,9 @@ class WiringTests extends FirrtlFlatSpec { | input clock: Clock | input pin: Clock |""".stripMargin - val c = passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) - } + val wiringPass = new Wiring(Seq(sas)) - val retC = wiringPass.run(c) - (parse(retC.serialize).serialize) should be (parse(check).serialize) + executeTest(input, check, passes :+ wiringPass) } it should "handle two source instances with clearly defined sinks" in { @@ -544,12 +524,9 @@ class WiringTests extends FirrtlFlatSpec { | input clock: Clock | input pin: Clock |""".stripMargin - val c = passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) - } + val wiringPass = new Wiring(Seq(sas)) - val retC = wiringPass.run(c) - (parse(retC.serialize).serialize) should be (parse(check).serialize) + executeTest(input, check, passes :+ wiringPass) } it should "wire multiple clocks" in { @@ -590,12 +567,9 @@ class WiringTests extends FirrtlFlatSpec { | input clock: Clock | input pin: Clock |""".stripMargin - val c = passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) - } + val wiringPass = new Wiring(Seq(sas)) - val retC = wiringPass.run(c) - (parse(retC.serialize).serialize) should be (parse(check).serialize) + executeTest(input, check, passes :+ wiringPass) } it should "error with WiringException for indeterminate ownership" in { @@ -619,12 +593,10 @@ class WiringTests extends FirrtlFlatSpec { | extmodule X : | input clock: Clock |""".stripMargin + intercept[WiringException] { - val c = passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) - } val wiringPass = new Wiring(Seq(sas)) - val retC = wiringPass.run(c) + executeTest(input, "", passes :+ wiringPass) } } @@ -666,12 +638,9 @@ class WiringTests extends FirrtlFlatSpec { | input clock: Clock | input pin: UInt<2> |""".stripMargin - val c = passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) - } + val wiringPass = new Wiring(Seq(sas)) - val retC = wiringPass.run(c) - (parse(retC.serialize).serialize) should be (parse(check).serialize) + executeTest(input, check, passes :+ wiringPass) } it should "wire using Annotations with a sink module" in { @@ -701,12 +670,9 @@ class WiringTests extends FirrtlFlatSpec { | input clk: Clock | input pin: UInt<5> |""".stripMargin - val c = passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) - } + val wiringXForm = new WiringTransform() - val retC = wiringXForm.execute(CircuitState(c, MidForm, Seq(source, sink))).circuit - (parse(retC.serialize).serialize) should be (parse(check).serialize) + executeTest(input, check, passes :+ wiringXForm, Seq(source, sink)) } it should "wire using Annotations with a sink component" in { @@ -739,12 +705,9 @@ class WiringTests extends FirrtlFlatSpec { | wire s: UInt<5> | s <= pin |""".stripMargin - val c = passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) - } + val wiringXForm = new WiringTransform() - val retC = wiringXForm.execute(CircuitState(c, MidForm, Seq(source, sink))).circuit - (parse(retC.serialize).serialize) should be (parse(check).serialize) + executeTest(input, check, passes :+ wiringXForm, Seq(source, sink)) } it should "wire using annotations with Aggregate source" in { @@ -785,12 +748,9 @@ class WiringTests extends FirrtlFlatSpec { | input clock : Clock | input pin : {x : UInt<1>, y: UInt<1>, z: {zz : UInt<1>} }""" .stripMargin - val c = passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) - } + val wiringXForm = new WiringTransform() - val retC = wiringXForm.execute(CircuitState(c, MidForm, Seq(source, sink))).circuit - (parse(retC.serialize).serialize) should be (parse(check).serialize) + executeTest(input, check, passes :+ wiringXForm, Seq(source, sink)) } it should "wire one sink to multiple, disjoint extmodules" in { @@ -845,11 +805,8 @@ class WiringTests extends FirrtlFlatSpec { | input clock: Clock | input pin: UInt<5> |""".stripMargin - val c = passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) - } + val wiringPass = new Wiring(wiSeq) - val retC = wiringPass.run(c) - (parse(retC.serialize).serialize) should be (parse(check).serialize) + executeTest(input, check, passes :+ wiringPass) } } diff --git a/src/test/scala/firrtlTests/ZeroWidthTests.scala b/src/test/scala/firrtlTests/ZeroWidthTests.scala index 6443e131..eb955f29 100644 --- a/src/test/scala/firrtlTests/ZeroWidthTests.scala +++ b/src/test/scala/firrtlTests/ZeroWidthTests.scala @@ -16,7 +16,7 @@ class ZeroWidthTests extends FirrtlFlatSpec { ResolveKinds, InferTypes, ResolveGenders, - InferWidths, + new InferWidths, ZeroWidth) private def exec (input: String) = { val circuit = parse(input) diff --git a/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala b/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala index a866836f..667db7b0 100644 --- a/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala +++ b/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala @@ -10,10 +10,10 @@ import firrtl.passes._ import firrtl.Parser.IgnoreInfo class FixedTypeInferenceSpec extends FirrtlFlatSpec { - private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = { - val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Pass) => p.run(c) - } + private def executeTest(input: String, expected: Seq[String], passes: Seq[Transform]) = { + val c = passes.foldLeft(CircuitState(Parser.parse(input.split("\n").toIterator), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) + }.circuit val lines = c.serialize.split("\n") map normalized expected foreach { e => @@ -30,7 +30,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { CheckTypes, ResolveGenders, CheckGenders, - InferWidths, + new InferWidths, CheckWidths) val input = """circuit Unit : @@ -60,7 +60,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { CheckTypes, ResolveGenders, CheckGenders, - InferWidths, + new InferWidths, CheckWidths) val input = """circuit Unit : @@ -86,7 +86,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { CheckTypes, ResolveGenders, CheckGenders, - InferWidths, + new InferWidths, CheckWidths) val input = """circuit Unit : @@ -112,7 +112,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { CheckTypes, ResolveGenders, CheckGenders, - InferWidths, + new InferWidths, CheckWidths) val input = """circuit Unit : @@ -138,7 +138,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { CheckTypes, ResolveGenders, CheckGenders, - InferWidths, + new InferWidths, CheckWidths) val input = """circuit Unit : @@ -164,7 +164,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { CheckTypes, ResolveGenders, CheckGenders, - InferWidths, + new InferWidths, CheckWidths) val input = """circuit Unit : @@ -190,7 +190,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { CheckTypes, ResolveGenders, CheckGenders, - InferWidths, + new InferWidths, CheckWidths) val input = """circuit Unit : @@ -230,7 +230,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { CheckTypes, ResolveGenders, CheckGenders, - InferWidths, + new InferWidths, CheckWidths) val input = """circuit Unit : @@ -256,7 +256,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { CheckTypes, ResolveGenders, CheckGenders, - InferWidths, + new InferWidths, CheckWidths, ConvertFixedToSInt) val input = diff --git a/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala b/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala index 8645fa62..21a39e83 100644 --- a/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala +++ b/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala @@ -9,10 +9,10 @@ import firrtl.passes._ import firrtl.Parser.IgnoreInfo class RemoveFixedTypeSpec extends FirrtlFlatSpec { - private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = { - val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Pass) => p.run(c) - } + private def executeTest(input: String, expected: Seq[String], passes: Seq[Transform]) = { + val c = passes.foldLeft(CircuitState(Parser.parse(input.split("\n").toIterator), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) + }.circuit val lines = c.serialize.split("\n") map normalized println(c.serialize) @@ -30,7 +30,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { CheckTypes, ResolveGenders, CheckGenders, - InferWidths, + new InferWidths, CheckWidths, ConvertFixedToSInt) val input = @@ -60,7 +60,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { CheckTypes, ResolveGenders, CheckGenders, - InferWidths, + new InferWidths, CheckWidths, ConvertFixedToSInt) val input = @@ -91,7 +91,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { CheckTypes, ResolveGenders, CheckGenders, - InferWidths, + new InferWidths, CheckWidths, ConvertFixedToSInt) val input = @@ -118,7 +118,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { CheckTypes, ResolveGenders, CheckGenders, - InferWidths, + new InferWidths, CheckWidths, ConvertFixedToSInt) val input = @@ -145,7 +145,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { CheckTypes, ResolveGenders, CheckGenders, - InferWidths, + new InferWidths, CheckWidths, ConvertFixedToSInt) val input = @@ -197,7 +197,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { CheckTypes, ResolveGenders, CheckGenders, - InferWidths, + new InferWidths, CheckWidths, ConvertFixedToSInt) val input = diff --git a/src/test/scala/firrtlTests/transforms/InferWidthsWithAnnosSpec.scala b/src/test/scala/firrtlTests/transforms/InferWidthsWithAnnosSpec.scala new file mode 100644 index 00000000..54a3df40 --- /dev/null +++ b/src/test/scala/firrtlTests/transforms/InferWidthsWithAnnosSpec.scala @@ -0,0 +1,230 @@ +// See LICENSE for license details. + +package firrtlTests.transforms + +import firrtlTests.FirrtlFlatSpec +import org.scalatest._ +import org.scalatest.prop._ +import firrtl._ +import firrtl.passes._ +import firrtl.passes.wiring.{WiringTransform, SourceAnnotation, SinkAnnotation} +import firrtl.ir.Circuit +import firrtl.annotations._ +import firrtl.annotations.TargetToken.{Field, Index} + + +class InferWidthsWithAnnosSpec extends FirrtlFlatSpec { + private def executeTest(input: String, + check: String, + transforms: Seq[Transform], + annotations: Seq[Annotation]) = { + val start = CircuitState(parse(input), ChirrtlForm, annotations) + val end = transforms.foldLeft(start) { + (c: CircuitState, t: Transform) => t.runTransform(c) + } + val resLines = end.circuit.serialize.split("\n") map normalized + val checkLines = parse(check).serialize.split("\n") map normalized + + resLines should be (checkLines) + } + + "CheckWidths on wires with unknown widths" should "result in an error" in { + val transforms = Seq( + ToWorkingIR, + CheckHighForm, + ResolveKinds, + InferTypes, + CheckTypes, + ResolveGenders, + new InferWidths, + CheckWidths) + + val input = + """circuit Top : + | module Top : + | inst b of B + | inst a of A + | + | module B : + | wire x: UInt<3> + | + | module A : + | wire y: UInt""".stripMargin + + // A.y should have uninferred width + intercept[CheckWidths.UninferredWidth] { + executeTest(input, "", transforms, Seq.empty) + } + } + + "InferWidthsWithAnnos" should "infer widths using WidthGeqConstraintAnnotation" in { + val transforms = Seq( + ToWorkingIR, + CheckHighForm, + ResolveKinds, + InferTypes, + CheckTypes, + ResolveGenders, + new InferWidths, + CheckWidths) + + val annos = Seq(WidthGeqConstraintAnnotation( + ReferenceTarget("Top", "A", Nil, "y", Nil), + ReferenceTarget("Top", "B", Nil, "x", Nil))) + + val input = + """circuit Top : + | module Top : + | inst b of B + | inst a of A + | + | module B : + | wire x: UInt<3> + | + | module A : + | wire y: UInt""".stripMargin + + val output = + """circuit Top : + | module Top : + | inst b of B + | inst a of A + | + | module B : + | wire x: UInt<3> + | + | module A : + | wire y: UInt<3>""".stripMargin + + // A.y should have same width as B.x + executeTest(input, output, transforms, annos) + } + + "InferWidthsWithAnnos" should "work with token paths" in { + val transforms = Seq( + ToWorkingIR, + CheckHighForm, + ResolveKinds, + InferTypes, + CheckTypes, + ResolveGenders, + new InferWidths, + CheckWidths) + + val tokenLists = Seq( + Seq(Field("x")), + Seq(Field("y"), Index(0), Field("yy")), + Seq(Field("y"), Index(1), Field("yy")) + ) + + val annos = tokenLists.map { tokens => + WidthGeqConstraintAnnotation( + ReferenceTarget("Top", "A", Nil, "bundle", tokens), + ReferenceTarget("Top", "B", Nil, "bundle", tokens)) + } + + val input = + """circuit Top : + | module Top : + | inst b of B + | inst a of A + | + | module B : + | wire bundle : {x : UInt<1>, y: {yy : UInt<3>}[2] } + | + | module A : + | wire bundle : {x : UInt, y: {yy : UInt}[2] }""".stripMargin + + val output = + """circuit Top : + | module Top : + | inst b of B + | inst a of A + | + | module B : + | wire bundle : {x : UInt<1>, y: {yy : UInt<3>}[2] } + | + | module A : + | wire bundle : {x : UInt<1>, y: {yy : UInt<3>}[2] }""".stripMargin + + // elements of A.bundle should have same width as B.bundle + executeTest(input, output, transforms, annos) + } + + "InferWidthsWithAnnos" should "work with WiringTransform" in { + def transforms = Seq( + ToWorkingIR, + ResolveKinds, + InferTypes, + ResolveGenders, + new InferWidths, + CheckWidths, + new WiringTransform, + new ResolveAndCheck + ) + val sourceTarget = ComponentName("bundle", ModuleName("A", CircuitName("Top"))) + val source = SourceAnnotation(sourceTarget, "pin") + + val sinkTarget = ComponentName("bundle", ModuleName("B", CircuitName("Top"))) + val sink = SinkAnnotation(sinkTarget, "pin") + + val tokenLists = Seq( + Seq(Field("x")), + Seq(Field("y"), Index(0), Field("yy")), + Seq(Field("y"), Index(1), Field("yy")) + ) + + val wgeqAnnos = tokenLists.map { tokens => + WidthGeqConstraintAnnotation( + ReferenceTarget("Top", "A", Nil, "bundle", tokens), + ReferenceTarget("Top", "B", Nil, "bundle", tokens)) + } + + val failAnnos = Seq(source, sink) + val successAnnos = wgeqAnnos ++: failAnnos + + val input = + """circuit Top : + | module Top : + | inst b of B + | inst a of A + | + | module B : + | wire bundle : {x : UInt<1>, y: {yy : UInt<3>}[2] } + | + | module A : + | wire bundle : {x : UInt, y: {yy : UInt}[2] }""".stripMargin + + val output = + """circuit Top : + | module Top : + | wire bundle : {x : UInt<1>, y: {yy : UInt<3>}[2] } + | inst b of B + | inst a of A + | b.pin <= bundle + | bundle <= a.bundle_0 + | + | module B : + | input pin : {x : UInt<1>, y: {yy : UInt<3>}[2] } + | wire bundle : {x : UInt<1>, y: {yy : UInt<3>}[2] } + | bundle <= pin + | + | module A : + | output bundle_0 : {x : UInt<1>, y: {yy : UInt<3>}[2] } + | wire bundle : {x : UInt<1>, y: {yy : UInt<3>}[2] } + | bundle_0 <= bundle""" + .stripMargin + + // should fail without extra constraint annos due to UninferredWidths + val exceptions = intercept[PassExceptions] { + executeTest(input, "", transforms, failAnnos) + }.exceptions.reverse + + val msg = exceptions.head.toString + assert(msg.contains(s"2 errors detected!")) + assert(exceptions.tail.forall(_.isInstanceOf[CheckWidths.UninferredWidth])) + + // should pass with extra constraints + executeTest(input, output, transforms, successAnnos) + } +} |
