aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAlbert Chen2019-02-22 15:30:27 -0800
committermergify[bot]2019-02-22 23:30:27 +0000
commit5608aa8f42c1d69b59bee158d14fc6cef9b19a47 (patch)
tree86b7bad9c5f164d12aba9f324bde223e7ff5e9f3 /src
parent0ace0218d3151df2d102463dd682128a88ae7be6 (diff)
Add Width Constraints with Annotations (#956)
* refactor InferWidths to allow for extra contraints, add InferWidthsWithAnnos * add test cases * add ResolvedAnnotationPaths trait to InferWidthsWithAnnos * remove println * cleanup tests * remove extraneous constraints * use foreachStmt instead of mapStmt * remove support for aggregates * fold InferWidthsWithAnnos into InferWidths * throw exception if ref not found, check for annos before AST walk
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/LoweringCompilers.scala6
-rw-r--r--src/main/scala/firrtl/annotations/Target.scala21
-rw-r--r--src/main/scala/firrtl/passes/InferWidths.scala121
-rw-r--r--src/main/scala/firrtl/passes/Passes.scala2
-rw-r--r--src/test/scala/firrtlTests/AttachSpec.scala6
-rw-r--r--src/test/scala/firrtlTests/CheckInitializationSpec.scala17
-rw-r--r--src/test/scala/firrtlTests/CheckSpec.scala8
-rw-r--r--src/test/scala/firrtlTests/ChirrtlSpec.scala4
-rw-r--r--src/test/scala/firrtlTests/ClockListTests.scala20
-rw-r--r--src/test/scala/firrtlTests/ConstantPropagationTests.scala2
-rw-r--r--src/test/scala/firrtlTests/ExpandWhensSpec.scala2
-rw-r--r--src/test/scala/firrtlTests/LowerTypesSpec.scala4
-rw-r--r--src/test/scala/firrtlTests/ReplaceAccessesSpec.scala2
-rw-r--r--src/test/scala/firrtlTests/UnitTests.scala30
-rw-r--r--src/test/scala/firrtlTests/WidthSpec.scala20
-rw-r--r--src/test/scala/firrtlTests/WiringTests.scala133
-rw-r--r--src/test/scala/firrtlTests/ZeroWidthTests.scala2
-rw-r--r--src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala26
-rw-r--r--src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala20
-rw-r--r--src/test/scala/firrtlTests/transforms/InferWidthsWithAnnosSpec.scala230
20 files changed, 478 insertions, 198 deletions
diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala
index 9969150d..262caeea 100644
--- a/src/main/scala/firrtl/LoweringCompilers.scala
+++ b/src/main/scala/firrtl/LoweringCompilers.scala
@@ -44,7 +44,7 @@ class ResolveAndCheck extends CoreTransform {
passes.InferTypes,
passes.ResolveGenders,
passes.CheckGenders,
- passes.InferWidths,
+ new passes.InferWidths,
passes.CheckWidths)
}
@@ -68,7 +68,7 @@ class HighFirrtlToMiddleFirrtl extends CoreTransform {
passes.InferTypes,
passes.CheckTypes,
passes.ResolveGenders,
- passes.InferWidths,
+ new passes.InferWidths,
passes.CheckWidths,
passes.ConvertFixedToSInt,
passes.ZeroWidth,
@@ -87,7 +87,7 @@ class MiddleFirrtlToLowFirrtl extends CoreTransform {
passes.ResolveKinds,
passes.InferTypes,
passes.ResolveGenders,
- passes.InferWidths,
+ new passes.InferWidths,
passes.Legalize,
new firrtl.transforms.RemoveReset,
new firrtl.transforms.CheckCombLoops,
diff --git a/src/main/scala/firrtl/annotations/Target.scala b/src/main/scala/firrtl/annotations/Target.scala
index 8a9d68e8..0247b66c 100644
--- a/src/main/scala/firrtl/annotations/Target.scala
+++ b/src/main/scala/firrtl/annotations/Target.scala
@@ -3,7 +3,8 @@
package firrtl
package annotations
-import firrtl.ir.Expression
+import firrtl.ir.{Expression, Type}
+import firrtl.Utils.{sub_type, field_type}
import AnnotationUtils.{toExp, validComponentName, validModuleName}
import TargetToken._
@@ -553,6 +554,24 @@ case class ReferenceTarget(circuit: String,
/** @return The clock signal of this reference, must be to a [[firrtl.ir.DefRegister]] */
def clock: ReferenceTarget = ReferenceTarget(circuit, module, path, ref, component :+ Clock)
+ /** @param the type of this target's ref
+ * @return the type of the subcomponent specified by this target's component
+ */
+ def componentType(baseType: Type): Type = componentType(baseType, tokens)
+
+ private def componentType(baseType: Type, tokens: Seq[TargetToken]): Type = {
+ if (tokens.isEmpty) {
+ baseType
+ } else {
+ val headType = tokens.head match {
+ case Index(idx) => sub_type(baseType)
+ case Field(field) => field_type(baseType, field)
+ case _: Ref => baseType
+ }
+ componentType(headType, tokens.tail)
+ }
+ }
+
override def circuitOpt: Option[String] = Some(circuit)
override def moduleOpt: Option[String] = Some(module)
diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala
index 6652c1fe..06833bc0 100644
--- a/src/main/scala/firrtl/passes/InferWidths.scala
+++ b/src/main/scala/firrtl/passes/InferWidths.scala
@@ -7,12 +7,38 @@ import scala.collection.mutable.ArrayBuffer
import scala.collection.immutable.ListMap
import firrtl._
+import firrtl.annotations.{Annotation, ReferenceTarget, TargetToken}
import firrtl.ir._
import firrtl.Utils._
import firrtl.Mappers._
import firrtl.traversals.Foreachers._
-object InferWidths extends Pass {
+case class WidthGeqConstraintAnnotation(loc: ReferenceTarget, exp: ReferenceTarget) extends Annotation {
+ def update(renameMap: RenameMap): Seq[WidthGeqConstraintAnnotation] = {
+ val newLoc :: newExp :: Nil = Seq(loc, exp).map { target =>
+ renameMap.get(target) match {
+ case None => Some(target)
+ case Some(Seq()) => None
+ case Some(Seq(one)) => Some(one)
+ case Some(many) =>
+ throw new Exception(s"Target below is an AggregateType, which " +
+ "is not supported by WidthGeqConstraintAnnotation\n" + target.prettyPrint())
+ }
+ }
+
+ (newLoc, newExp) match {
+ case (Some(l: ReferenceTarget), Some(e: ReferenceTarget)) => Seq(WidthGeqConstraintAnnotation(l, e))
+ case _ => Seq.empty
+ }
+ }
+}
+
+class InferWidths extends Transform with ResolvedAnnotationPaths {
+ def inputForm: CircuitForm = UnknownForm
+ def outputForm: CircuitForm = UnknownForm
+
+ val annotationClasses = Seq(classOf[WidthGeqConstraintAnnotation])
+
type ConstraintMap = collection.mutable.LinkedHashMap[String, Width]
def solve_constraints(l: Seq[WGeq]): ConstraintMap = {
@@ -220,26 +246,26 @@ object InferWidths extends Pass {
}
b
}
-
- def run (c: Circuit): Circuit = {
- val v = ArrayBuffer[WGeq]()
-
- def get_constraints_t(t1: Type, t2: Type): Seq[WGeq] = (t1,t2) match {
- case (t1: UIntType, t2: UIntType) => Seq(WGeq(t1.width, t2.width))
- case (t1: SIntType, t2: SIntType) => Seq(WGeq(t1.width, t2.width))
- case (ClockType, ClockType) => Nil
- case (AsyncResetType, AsyncResetType) => Nil
- case (FixedType(w1, p1), FixedType(w2, p2)) => Seq(WGeq(w1,w2), WGeq(p1,p2))
- case (AnalogType(w1), AnalogType(w2)) => Seq(WGeq(w1,w2), WGeq(w2,w1))
- case (t1: BundleType, t2: BundleType) =>
- (t1.fields zip t2.fields foldLeft Seq[WGeq]()){case (res, (f1, f2)) =>
- res ++ (f1.flip match {
- case Default => get_constraints_t(f1.tpe, f2.tpe)
- case Flip => get_constraints_t(f2.tpe, f1.tpe)
- })
- }
- case (t1: VectorType, t2: VectorType) => get_constraints_t(t1.tpe, t2.tpe)
- }
+
+ def get_constraints_t(t1: Type, t2: Type): Seq[WGeq] = (t1,t2) match {
+ case (t1: UIntType, t2: UIntType) => Seq(WGeq(t1.width, t2.width))
+ case (t1: SIntType, t2: SIntType) => Seq(WGeq(t1.width, t2.width))
+ case (ClockType, ClockType) => Nil
+ case (AsyncResetType, AsyncResetType) => Nil
+ case (FixedType(w1, p1), FixedType(w2, p2)) => Seq(WGeq(w1,w2), WGeq(p1,p2))
+ case (AnalogType(w1), AnalogType(w2)) => Seq(WGeq(w1,w2), WGeq(w2,w1))
+ case (t1: BundleType, t2: BundleType) =>
+ (t1.fields zip t2.fields foldLeft Seq[WGeq]()){case (res, (f1, f2)) =>
+ res ++ (f1.flip match {
+ case Default => get_constraints_t(f1.tpe, f2.tpe)
+ case Flip => get_constraints_t(f2.tpe, f1.tpe)
+ })
+ }
+ case (t1: VectorType, t2: VectorType) => get_constraints_t(t1.tpe, t2.tpe)
+ }
+
+ def run(c: Circuit, extra: Seq[WGeq]): Circuit = {
+ val v = ArrayBuffer[WGeq]() ++ extra
def get_constraints_e(e: Expression): Unit = {
e match {
@@ -364,10 +390,59 @@ object InferWidths extends Pass {
def reduce_var_widths_p(p: Port): Port = {
Port(p.info, p.name, p.direction, reduce_var_widths_t(p.tpe))
- }
-
+ }
+
InferTypes.run(c.copy(modules = c.modules map (_
map reduce_var_widths_p
map reduce_var_widths_s)))
}
+
+ def execute(state: CircuitState): CircuitState = {
+ val circuitName = state.circuit.main
+ val typeMap = new collection.mutable.HashMap[ReferenceTarget, Type]
+
+ def getDeclTypes(modName: String)(stmt: Statement): Unit = {
+ val pairOpt = stmt match {
+ case w: DefWire => Some(w.name -> w.tpe)
+ case r: DefRegister => Some(r.name -> r.tpe)
+ case n: DefNode => Some(n.name -> n.value.tpe)
+ case i: WDefInstance => Some(i.name -> i.tpe)
+ case m: DefMemory => Some(m.name -> MemPortUtils.memType(m))
+ case other => None
+ }
+ pairOpt.foreach { case (ref, tpe) =>
+ typeMap += (ReferenceTarget(circuitName, modName, Nil, ref, Nil) -> tpe)
+ }
+ stmt.foreachStmt(getDeclTypes(modName))
+ }
+
+ if (state.annotations.exists(_.isInstanceOf[WidthGeqConstraintAnnotation])) {
+ state.circuit.modules.foreach { mod =>
+ mod.ports.foreach { port =>
+ typeMap += (ReferenceTarget(circuitName, mod.name, Nil, port.name, Nil) -> port.tpe)
+ }
+ mod.foreachStmt(getDeclTypes(mod.name))
+ }
+ }
+
+ val extraConstraints = state.annotations.flatMap {
+ case anno: WidthGeqConstraintAnnotation if anno.loc.isLocal && anno.exp.isLocal =>
+ val locType :: expType :: Nil = Seq(anno.loc, anno.exp) map { target =>
+ val baseType = typeMap.getOrElse(target.copy(component = Seq.empty),
+ throw new Exception(s"Target below from WidthGeqConstraintAnnotation was not found\n" + target.prettyPrint()))
+ val leafType = target.componentType(baseType)
+ if (leafType.isInstanceOf[AggregateType]) {
+ throw new Exception(s"Target below is an AggregateType, which " +
+ "is not supported by WidthGeqConstraintAnnotation\n" + target.prettyPrint())
+ }
+
+ leafType
+ }
+
+ get_constraints_t(locType, expType)
+ case other => Seq.empty
+ }
+
+ state.copy(circuit = run(state.circuit, extraConstraints))
+ }
}
diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala
index 04bfb19c..7f7f6e40 100644
--- a/src/main/scala/firrtl/passes/Passes.scala
+++ b/src/main/scala/firrtl/passes/Passes.scala
@@ -33,7 +33,7 @@ trait Pass extends Transform {
// Error handling
class PassException(message: String) extends Exception(message)
-class PassExceptions(exceptions: Seq[PassException]) extends Exception("\n" + exceptions.mkString("\n"))
+class PassExceptions(val exceptions: Seq[PassException]) extends Exception("\n" + exceptions.mkString("\n"))
class Errors {
val errors = collection.mutable.ArrayBuffer[PassException]()
def append(pe: PassException) = errors.append(pe)
diff --git a/src/test/scala/firrtlTests/AttachSpec.scala b/src/test/scala/firrtlTests/AttachSpec.scala
index 9bf5fefd..1acb0d8b 100644
--- a/src/test/scala/firrtlTests/AttachSpec.scala
+++ b/src/test/scala/firrtlTests/AttachSpec.scala
@@ -411,7 +411,7 @@ class AttachAnalogSpec extends FirrtlFlatSpec {
ResolveKinds,
InferTypes,
CheckTypes,
- InferWidths,
+ new InferWidths,
CheckWidths)
val input =
"""circuit Unit :
@@ -422,8 +422,8 @@ class AttachAnalogSpec extends FirrtlFlatSpec {
| extmodule A :
| output o: Analog<2> """.stripMargin
intercept[CheckWidths.AttachWidthsNotEqual] {
- passes.foldLeft(parse(input)) {
- (c: Circuit, p: Pass) => p.run(c)
+ passes.foldLeft(CircuitState(parse(input), UnknownForm)) {
+ (c: CircuitState, p: Transform) => p.runTransform(c)
}
}
}
diff --git a/src/test/scala/firrtlTests/CheckInitializationSpec.scala b/src/test/scala/firrtlTests/CheckInitializationSpec.scala
index ef966ca0..9ccff256 100644
--- a/src/test/scala/firrtlTests/CheckInitializationSpec.scala
+++ b/src/test/scala/firrtlTests/CheckInitializationSpec.scala
@@ -5,8 +5,7 @@ package firrtlTests
import java.io._
import org.scalatest._
import org.scalatest.prop._
-import firrtl.Parser
-import firrtl.ir.Circuit
+import firrtl.{Parser, CircuitState, UnknownForm, Transform}
import firrtl.Parser.IgnoreInfo
import firrtl.passes._
@@ -19,7 +18,7 @@ class CheckInitializationSpec extends FirrtlFlatSpec {
CheckTypes,
ResolveGenders,
CheckGenders,
- InferWidths,
+ new InferWidths,
CheckWidths,
PullMuxes,
ExpandConnects,
@@ -35,8 +34,8 @@ class CheckInitializationSpec extends FirrtlFlatSpec {
| when p :
| x <= UInt(1)""".stripMargin
intercept[CheckInitialization.RefNotInitializedException] {
- passes.foldLeft(parse(input)) {
- (c: Circuit, p: Pass) => p.run(c)
+ passes.foldLeft(CircuitState(parse(input), UnknownForm)) {
+ (c: CircuitState, p: Transform) => p.runTransform(c)
}
}
}
@@ -50,8 +49,8 @@ class CheckInitializationSpec extends FirrtlFlatSpec {
| else :
| x <= UInt(1)""".stripMargin
intercept[CheckInitialization.RefNotInitializedException] {
- passes.foldLeft(parse(input)) {
- (c: Circuit, p: Pass) => p.run(c)
+ passes.foldLeft(CircuitState(parse(input), UnknownForm)) {
+ (c: CircuitState, p: Transform) => p.runTransform(c)
}
}
}
@@ -66,8 +65,8 @@ class CheckInitializationSpec extends FirrtlFlatSpec {
| when p :
| c.in <= UInt(1)""".stripMargin
intercept[CheckInitialization.RefNotInitializedException] {
- passes.foldLeft(parse(input)) {
- (c: Circuit, p: Pass) => p.run(c)
+ passes.foldLeft(CircuitState(parse(input), UnknownForm)) {
+ (c: CircuitState, p: Transform) => p.runTransform(c)
}
}
}
diff --git a/src/test/scala/firrtlTests/CheckSpec.scala b/src/test/scala/firrtlTests/CheckSpec.scala
index 767f2392..3e6b19f9 100644
--- a/src/test/scala/firrtlTests/CheckSpec.scala
+++ b/src/test/scala/firrtlTests/CheckSpec.scala
@@ -5,7 +5,7 @@ package firrtlTests
import java.io._
import org.scalatest._
import org.scalatest.prop._
-import firrtl.Parser
+import firrtl.{Parser, CircuitState, UnknownForm, Transform}
import firrtl.ir.Circuit
import firrtl.passes.{Pass,ToWorkingIR,CheckHighForm,ResolveKinds,InferTypes,CheckTypes,PassException,InferWidths,CheckWidths,ResolveGenders,CheckGenders}
@@ -153,7 +153,7 @@ class CheckSpec extends FlatSpec with Matchers {
CheckTypes,
ResolveGenders,
CheckGenders,
- InferWidths,
+ new InferWidths,
CheckWidths)
val input =
"""
@@ -180,8 +180,8 @@ class CheckSpec extends FlatSpec with Matchers {
| sub.io.debug_clk <= io.jtag.TCK
|
|""".stripMargin
- passes.foldLeft(Parser.parse(input.split("\n").toIterator)) {
- (c: Circuit, p: Pass) => p.run(c)
+ passes.foldLeft(CircuitState(Parser.parse(input.split("\n").toIterator), UnknownForm)) {
+ (c: CircuitState, p: Transform) => p.runTransform(c)
}
}
diff --git a/src/test/scala/firrtlTests/ChirrtlSpec.scala b/src/test/scala/firrtlTests/ChirrtlSpec.scala
index 774c352b..9344b861 100644
--- a/src/test/scala/firrtlTests/ChirrtlSpec.scala
+++ b/src/test/scala/firrtlTests/ChirrtlSpec.scala
@@ -5,7 +5,7 @@ package firrtlTests
import java.io._
import org.scalatest._
import org.scalatest.prop._
-import firrtl.Parser
+import firrtl.{Parser, CircuitState, UnknownForm, Transform}
import firrtl.ir.Circuit
import firrtl.passes._
import firrtl._
@@ -23,7 +23,7 @@ class ChirrtlSpec extends FirrtlFlatSpec {
CheckTypes,
ResolveGenders,
CheckGenders,
- InferWidths,
+ new InferWidths,
CheckWidths,
PullMuxes,
ExpandConnects,
diff --git a/src/test/scala/firrtlTests/ClockListTests.scala b/src/test/scala/firrtlTests/ClockListTests.scala
index 48d6dfd3..dde719d5 100644
--- a/src/test/scala/firrtlTests/ClockListTests.scala
+++ b/src/test/scala/firrtlTests/ClockListTests.scala
@@ -25,7 +25,7 @@ class ClockListTests extends FirrtlFlatSpec {
ResolveKinds,
InferTypes,
ResolveGenders,
- InferWidths
+ new InferWidths
)
"Getting clock list" should "work" in {
@@ -75,9 +75,9 @@ class ClockListTests extends FirrtlFlatSpec {
|Good Origin of h$b.clock is h$clkGen.clk2
|Good Origin of h$c.clock is h$clkGen.clk3
|""".stripMargin
- val c = passes.foldLeft(parse(input)) {
- (c: Circuit, p: Pass) => p.run(c)
- }
+ val c = passes.foldLeft(CircuitState(parse(input), UnknownForm)) {
+ (c: CircuitState, p: Transform) => p.runTransform(c)
+ }.circuit
val writer = new StringWriter()
val retC = new ClockList("HTop", writer).run(c)
(writer.toString) should be (check)
@@ -106,9 +106,9 @@ class ClockListTests extends FirrtlFlatSpec {
|Good Origin of b.clock is clkB
|Good Origin of b$c.clock is clock
|""".stripMargin
- val c = passes.foldLeft(parse(input)) {
- (c: Circuit, p: Pass) => p.run(c)
- }
+ val c = passes.foldLeft(CircuitState(parse(input), UnknownForm)) {
+ (c: CircuitState, p: Transform) => p.runTransform(c)
+ }.circuit
val writer = new StringWriter()
val retC = new ClockList("A", writer).run(c)
(writer.toString) should be (check)
@@ -139,9 +139,9 @@ class ClockListTests extends FirrtlFlatSpec {
|Good Origin of clock is clock
|Good Origin of c.clock is clkC
|""".stripMargin
- val c = passes.foldLeft(parse(input)) {
- (c: Circuit, p: Pass) => p.run(c)
- }
+ val c = passes.foldLeft(CircuitState(parse(input), UnknownForm)) {
+ (c: CircuitState, p: Transform) => p.runTransform(c)
+ }.circuit
val writer = new StringWriter()
val retC = new ClockList("B", writer).run(c)
(writer.toString) should be (check)
diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
index 6fb2ab8d..da3e9b41 100644
--- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala
+++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
@@ -14,7 +14,7 @@ class ConstantPropagationSpec extends FirrtlFlatSpec {
ResolveKinds,
InferTypes,
ResolveGenders,
- InferWidths,
+ new InferWidths,
new ConstantPropagation)
protected def exec(input: String) = {
transforms.foldLeft(CircuitState(parse(input), UnknownForm)) {
diff --git a/src/test/scala/firrtlTests/ExpandWhensSpec.scala b/src/test/scala/firrtlTests/ExpandWhensSpec.scala
index 3532ce00..a1ac8a31 100644
--- a/src/test/scala/firrtlTests/ExpandWhensSpec.scala
+++ b/src/test/scala/firrtlTests/ExpandWhensSpec.scala
@@ -22,7 +22,7 @@ class ExpandWhensSpec extends FirrtlFlatSpec {
InferTypes,
ResolveGenders,
CheckGenders,
- InferWidths,
+ new InferWidths,
CheckWidths,
PullMuxes,
ExpandConnects,
diff --git a/src/test/scala/firrtlTests/LowerTypesSpec.scala b/src/test/scala/firrtlTests/LowerTypesSpec.scala
index ab367554..27f2c8a0 100644
--- a/src/test/scala/firrtlTests/LowerTypesSpec.scala
+++ b/src/test/scala/firrtlTests/LowerTypesSpec.scala
@@ -20,7 +20,7 @@ class LowerTypesSpec extends FirrtlFlatSpec {
CheckTypes,
ResolveGenders,
CheckGenders,
- InferWidths,
+ new InferWidths,
CheckWidths,
PullMuxes,
ExpandConnects,
@@ -32,7 +32,7 @@ class LowerTypesSpec extends FirrtlFlatSpec {
ResolveKinds,
InferTypes,
ResolveGenders,
- InferWidths,
+ new InferWidths,
LowerTypes)
private def executeTest(input: String, expected: Seq[String]) = {
diff --git a/src/test/scala/firrtlTests/ReplaceAccessesSpec.scala b/src/test/scala/firrtlTests/ReplaceAccessesSpec.scala
index e507e947..64977c7f 100644
--- a/src/test/scala/firrtlTests/ReplaceAccessesSpec.scala
+++ b/src/test/scala/firrtlTests/ReplaceAccessesSpec.scala
@@ -14,7 +14,7 @@ class ReplaceAccessesSpec extends FirrtlFlatSpec {
ResolveKinds,
InferTypes,
ResolveGenders,
- InferWidths,
+ new InferWidths,
ReplaceAccesses)
protected def exec(input: String) = {
transforms.foldLeft(CircuitState(parse(input), UnknownForm)) {
diff --git a/src/test/scala/firrtlTests/UnitTests.scala b/src/test/scala/firrtlTests/UnitTests.scala
index b3af920c..0ef4f709 100644
--- a/src/test/scala/firrtlTests/UnitTests.scala
+++ b/src/test/scala/firrtlTests/UnitTests.scala
@@ -158,7 +158,7 @@ class UnitTests extends FirrtlFlatSpec {
ResolveKinds,
InferTypes,
ResolveGenders,
- InferWidths,
+ new InferWidths,
SplitExpressions
)
val input =
@@ -182,7 +182,7 @@ class UnitTests extends FirrtlFlatSpec {
ResolveKinds,
InferTypes,
ResolveGenders,
- InferWidths,
+ new InferWidths,
PadWidths
)
val input =
@@ -203,7 +203,7 @@ class UnitTests extends FirrtlFlatSpec {
ResolveKinds,
InferTypes,
ResolveGenders,
- InferWidths,
+ new InferWidths,
PullMuxes,
ExpandConnects,
RemoveAccesses,
@@ -243,15 +243,15 @@ class UnitTests extends FirrtlFlatSpec {
ResolveKinds,
InferTypes,
ResolveGenders,
- InferWidths,
+ new InferWidths,
CheckWidths)
val input =
"""circuit Unit :
| module Unit :
| node x = bits(UInt(1), 100, 0)""".stripMargin
intercept[CheckWidths.BitsWidthException] {
- passes.foldLeft(Parser.parse(input.split("\n").toIterator)) {
- (c: Circuit, p: Pass) => p.run(c)
+ passes.foldLeft(CircuitState(Parser.parse(input.split("\n").toIterator), UnknownForm)) {
+ (c: CircuitState, p: Transform) => p.runTransform(c)
}
}
}
@@ -262,15 +262,15 @@ class UnitTests extends FirrtlFlatSpec {
ResolveKinds,
InferTypes,
ResolveGenders,
- InferWidths,
+ new InferWidths,
CheckWidths)
val input =
"""circuit Unit :
| module Unit :
| node x = head(UInt(1), 100)""".stripMargin
intercept[CheckWidths.HeadWidthException] {
- passes.foldLeft(Parser.parse(input.split("\n").toIterator)) {
- (c: Circuit, p: Pass) => p.run(c)
+ passes.foldLeft(CircuitState(Parser.parse(input.split("\n").toIterator), UnknownForm)) {
+ (c: CircuitState, p: Transform) => p.runTransform(c)
}
}
}
@@ -281,15 +281,15 @@ class UnitTests extends FirrtlFlatSpec {
ResolveKinds,
InferTypes,
ResolveGenders,
- InferWidths,
+ new InferWidths,
CheckWidths)
val input =
"""circuit Unit :
| module Unit :
| node x = tail(UInt(1), 100)""".stripMargin
intercept[CheckWidths.TailWidthException] {
- passes.foldLeft(Parser.parse(input.split("\n").toIterator)) {
- (c: Circuit, p: Pass) => p.run(c)
+ passes.foldLeft(CircuitState(Parser.parse(input.split("\n").toIterator), UnknownForm)) {
+ (c: CircuitState, p: Transform) => p.runTransform(c)
}
}
}
@@ -308,8 +308,8 @@ class UnitTests extends FirrtlFlatSpec {
| bar <- foo
|""".stripMargin
intercept[PassException] {
- passes.foldLeft(Parser.parse(input.split("\n").toIterator)) {
- (c: Circuit, p: Pass) => p.run(c)
+ passes.foldLeft(CircuitState(Parser.parse(input.split("\n").toIterator), UnknownForm)) {
+ (c: CircuitState, p: Transform) => p.runTransform(c)
}
}
}
@@ -389,7 +389,7 @@ class UnitTests extends FirrtlFlatSpec {
ResolveKinds,
InferTypes,
ResolveGenders,
- InferWidths,
+ new InferWidths,
PullMuxes,
ExpandConnects,
RemoveAccesses,
diff --git a/src/test/scala/firrtlTests/WidthSpec.scala b/src/test/scala/firrtlTests/WidthSpec.scala
index 058cc1fa..96bd249c 100644
--- a/src/test/scala/firrtlTests/WidthSpec.scala
+++ b/src/test/scala/firrtlTests/WidthSpec.scala
@@ -11,10 +11,10 @@ import firrtl.passes._
import firrtl.Parser.IgnoreInfo
class WidthSpec extends FirrtlFlatSpec {
- private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = {
- val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) {
- (c: Circuit, p: Pass) => p.run(c)
- }
+ private def executeTest(input: String, expected: Seq[String], passes: Seq[Transform]) = {
+ val c = passes.foldLeft(CircuitState(Parser.parse(input.split("\n").toIterator), UnknownForm)) {
+ (c: CircuitState, p: Transform) => p.runTransform(c)
+ }.circuit
val lines = c.serialize.split("\n") map normalized
expected foreach { e =>
@@ -54,7 +54,7 @@ class WidthSpec extends FirrtlFlatSpec {
InferTypes,
CheckTypes,
ResolveGenders,
- InferWidths,
+ new InferWidths,
CheckWidths)
val input =
"""circuit Unit :
@@ -77,7 +77,7 @@ class WidthSpec extends FirrtlFlatSpec {
InferTypes,
CheckTypes,
ResolveGenders,
- InferWidths,
+ new InferWidths,
CheckWidths)
val input =
s"""circuit Unit :
@@ -96,7 +96,7 @@ class WidthSpec extends FirrtlFlatSpec {
InferTypes,
CheckTypes,
ResolveGenders,
- InferWidths,
+ new InferWidths,
CheckWidths)
val input =
"""circuit Unit :
@@ -121,7 +121,7 @@ class WidthSpec extends FirrtlFlatSpec {
InferTypes,
CheckTypes,
ResolveGenders,
- InferWidths)
+ new InferWidths)
val input =
"""circuit Unit :
| module Unit :
@@ -143,7 +143,7 @@ class WidthSpec extends FirrtlFlatSpec {
InferTypes,
CheckTypes,
ResolveGenders,
- InferWidths)
+ new InferWidths)
val input =
"""circuit Unit :
| module Unit :
@@ -166,7 +166,7 @@ class WidthSpec extends FirrtlFlatSpec {
InferTypes,
CheckTypes,
ResolveGenders,
- InferWidths,
+ new InferWidths,
CheckWidths)
val input =
"""|circuit Foo :
diff --git a/src/test/scala/firrtlTests/WiringTests.scala b/src/test/scala/firrtlTests/WiringTests.scala
index 4f8fd9fe..4fe4a46c 100644
--- a/src/test/scala/firrtlTests/WiringTests.scala
+++ b/src/test/scala/firrtlTests/WiringTests.scala
@@ -14,15 +14,19 @@ import wiring.WiringUtils._
import wiring._
class WiringTests extends FirrtlFlatSpec {
- private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = {
- val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) {
- (c: Circuit, p: Pass) => p.run(c)
- }
- val lines = c.serialize.split("\n") map normalized
+ private def executeTest(input: String,
+ expected: String,
+ passes: Seq[Transform],
+ annos: Seq[Annotation]): Unit = {
+ val c = passes.foldLeft(CircuitState(Parser.parse(input.split("\n").toIterator), UnknownForm, annos)) {
+ (c: CircuitState, p: Transform) => p.runTransform(c)
+ }.circuit
- expected foreach { e =>
- lines should contain(e)
- }
+ (parse(c.serialize).serialize) should be (parse(expected).serialize)
+ }
+
+ private def executeTest(input: String, expected: String, passes: Seq[Transform]): Unit = {
+ executeTest(input, expected, passes, Seq.empty)
}
def passes = Seq(
@@ -30,7 +34,7 @@ class WiringTests extends FirrtlFlatSpec {
ResolveKinds,
InferTypes,
ResolveGenders,
- InferWidths
+ new InferWidths
)
it should "wire from a register source (r) to multiple extmodule sinks (X)" in {
@@ -114,12 +118,9 @@ class WiringTests extends FirrtlFlatSpec {
| input clock: Clock
| input pin: UInt<5>
|""".stripMargin
- val c = passes.foldLeft(parse(input)) {
- (c: Circuit, p: Pass) => p.run(c)
- }
+
val wiringPass = new Wiring(Seq(sas))
- val retC = wiringPass.run(c)
- (parse(retC.serialize).serialize) should be (parse(check).serialize)
+ executeTest(input, check, passes :+ wiringPass)
}
it should "wire from a register source (r) to multiple module sinks (X)" in {
@@ -203,12 +204,9 @@ class WiringTests extends FirrtlFlatSpec {
| input clock: Clock
| input pin: UInt<5>
|""".stripMargin
- val c = passes.foldLeft(parse(input)) {
- (c: Circuit, p: Pass) => p.run(c)
- }
+
val wiringPass = new Wiring(Seq(sas))
- val retC = wiringPass.run(c)
- (parse(retC.serialize).serialize) should be (parse(check).serialize)
+ executeTest(input, check, passes :+ wiringPass)
}
it should "wire from a register sink (r) to a wire source (s) in another module (X)" in {
@@ -295,12 +293,9 @@ class WiringTests extends FirrtlFlatSpec {
| wire s: UInt<5>
| s <= pin
|""".stripMargin
- val c = passes.foldLeft(parse(input)) {
- (c: Circuit, p: Pass) => p.run(c)
- }
+
val wiringPass = new Wiring(Seq(sas))
- val retC = wiringPass.run(c)
- (parse(retC.serialize).serialize) should be (parse(check).serialize)
+ executeTest(input, check, passes :+ wiringPass)
}
it should "wire from a SubField source (r.x) to an extmodule sink (X)" in {
@@ -339,12 +334,9 @@ class WiringTests extends FirrtlFlatSpec {
| input clock: Clock
| input pin: UInt<5>
|""".stripMargin
- val c = passes.foldLeft(parse(input)) {
- (c: Circuit, p: Pass) => p.run(c)
- }
+
val wiringPass = new Wiring(Seq(sas))
- val retC = wiringPass.run(c)
- (parse(retC.serialize).serialize) should be (parse(check).serialize)
+ executeTest(input, check, passes :+ wiringPass)
}
it should "wire properly with a source as a submodule of a sink" in {
@@ -386,12 +378,9 @@ class WiringTests extends FirrtlFlatSpec {
| reg r: UInt<5>, clock
| r_0 <= r
|""".stripMargin
- val c = passes.foldLeft(parse(input)) {
- (c: Circuit, p: Pass) => p.run(c)
- }
+
val wiringPass = new Wiring(Seq(sas))
- val retC = wiringPass.run(c)
- (parse(retC.serialize).serialize) should be (parse(check).serialize)
+ executeTest(input, check, passes :+ wiringPass)
}
it should "wire with source and sink in the same module" in {
@@ -415,12 +404,9 @@ class WiringTests extends FirrtlFlatSpec {
| s <= pin
| pin <= r
|""".stripMargin
- val c = passes.foldLeft(parse(input)) {
- (c: Circuit, p: Pass) => p.run(c)
- }
+
val wiringPass = new Wiring(Seq(sas))
- val retC = wiringPass.run(c)
- (parse(retC.serialize).serialize) should be (parse(check).serialize)
+ executeTest(input, check, passes :+ wiringPass)
}
it should "wire multiple sinks in the same module" in {
@@ -456,12 +442,9 @@ class WiringTests extends FirrtlFlatSpec {
| s <= pin
| pin <= r
|""".stripMargin
- val c = passes.foldLeft(parse(input)) {
- (c: Circuit, p: Pass) => p.run(c)
- }
+
val wiringPass = new Wiring(Seq(sas))
- val retC = wiringPass.run(c)
- (parse(retC.serialize).serialize) should be (parse(check).serialize)
+ executeTest(input, check, passes :+ wiringPass)
}
it should "wire clocks" in {
@@ -498,12 +481,9 @@ class WiringTests extends FirrtlFlatSpec {
| input clock: Clock
| input pin: Clock
|""".stripMargin
- val c = passes.foldLeft(parse(input)) {
- (c: Circuit, p: Pass) => p.run(c)
- }
+
val wiringPass = new Wiring(Seq(sas))
- val retC = wiringPass.run(c)
- (parse(retC.serialize).serialize) should be (parse(check).serialize)
+ executeTest(input, check, passes :+ wiringPass)
}
it should "handle two source instances with clearly defined sinks" in {
@@ -544,12 +524,9 @@ class WiringTests extends FirrtlFlatSpec {
| input clock: Clock
| input pin: Clock
|""".stripMargin
- val c = passes.foldLeft(parse(input)) {
- (c: Circuit, p: Pass) => p.run(c)
- }
+
val wiringPass = new Wiring(Seq(sas))
- val retC = wiringPass.run(c)
- (parse(retC.serialize).serialize) should be (parse(check).serialize)
+ executeTest(input, check, passes :+ wiringPass)
}
it should "wire multiple clocks" in {
@@ -590,12 +567,9 @@ class WiringTests extends FirrtlFlatSpec {
| input clock: Clock
| input pin: Clock
|""".stripMargin
- val c = passes.foldLeft(parse(input)) {
- (c: Circuit, p: Pass) => p.run(c)
- }
+
val wiringPass = new Wiring(Seq(sas))
- val retC = wiringPass.run(c)
- (parse(retC.serialize).serialize) should be (parse(check).serialize)
+ executeTest(input, check, passes :+ wiringPass)
}
it should "error with WiringException for indeterminate ownership" in {
@@ -619,12 +593,10 @@ class WiringTests extends FirrtlFlatSpec {
| extmodule X :
| input clock: Clock
|""".stripMargin
+
intercept[WiringException] {
- val c = passes.foldLeft(parse(input)) {
- (c: Circuit, p: Pass) => p.run(c)
- }
val wiringPass = new Wiring(Seq(sas))
- val retC = wiringPass.run(c)
+ executeTest(input, "", passes :+ wiringPass)
}
}
@@ -666,12 +638,9 @@ class WiringTests extends FirrtlFlatSpec {
| input clock: Clock
| input pin: UInt<2>
|""".stripMargin
- val c = passes.foldLeft(parse(input)) {
- (c: Circuit, p: Pass) => p.run(c)
- }
+
val wiringPass = new Wiring(Seq(sas))
- val retC = wiringPass.run(c)
- (parse(retC.serialize).serialize) should be (parse(check).serialize)
+ executeTest(input, check, passes :+ wiringPass)
}
it should "wire using Annotations with a sink module" in {
@@ -701,12 +670,9 @@ class WiringTests extends FirrtlFlatSpec {
| input clk: Clock
| input pin: UInt<5>
|""".stripMargin
- val c = passes.foldLeft(parse(input)) {
- (c: Circuit, p: Pass) => p.run(c)
- }
+
val wiringXForm = new WiringTransform()
- val retC = wiringXForm.execute(CircuitState(c, MidForm, Seq(source, sink))).circuit
- (parse(retC.serialize).serialize) should be (parse(check).serialize)
+ executeTest(input, check, passes :+ wiringXForm, Seq(source, sink))
}
it should "wire using Annotations with a sink component" in {
@@ -739,12 +705,9 @@ class WiringTests extends FirrtlFlatSpec {
| wire s: UInt<5>
| s <= pin
|""".stripMargin
- val c = passes.foldLeft(parse(input)) {
- (c: Circuit, p: Pass) => p.run(c)
- }
+
val wiringXForm = new WiringTransform()
- val retC = wiringXForm.execute(CircuitState(c, MidForm, Seq(source, sink))).circuit
- (parse(retC.serialize).serialize) should be (parse(check).serialize)
+ executeTest(input, check, passes :+ wiringXForm, Seq(source, sink))
}
it should "wire using annotations with Aggregate source" in {
@@ -785,12 +748,9 @@ class WiringTests extends FirrtlFlatSpec {
| input clock : Clock
| input pin : {x : UInt<1>, y: UInt<1>, z: {zz : UInt<1>} }"""
.stripMargin
- val c = passes.foldLeft(parse(input)) {
- (c: Circuit, p: Pass) => p.run(c)
- }
+
val wiringXForm = new WiringTransform()
- val retC = wiringXForm.execute(CircuitState(c, MidForm, Seq(source, sink))).circuit
- (parse(retC.serialize).serialize) should be (parse(check).serialize)
+ executeTest(input, check, passes :+ wiringXForm, Seq(source, sink))
}
it should "wire one sink to multiple, disjoint extmodules" in {
@@ -845,11 +805,8 @@ class WiringTests extends FirrtlFlatSpec {
| input clock: Clock
| input pin: UInt<5>
|""".stripMargin
- val c = passes.foldLeft(parse(input)) {
- (c: Circuit, p: Pass) => p.run(c)
- }
+
val wiringPass = new Wiring(wiSeq)
- val retC = wiringPass.run(c)
- (parse(retC.serialize).serialize) should be (parse(check).serialize)
+ executeTest(input, check, passes :+ wiringPass)
}
}
diff --git a/src/test/scala/firrtlTests/ZeroWidthTests.scala b/src/test/scala/firrtlTests/ZeroWidthTests.scala
index 6443e131..eb955f29 100644
--- a/src/test/scala/firrtlTests/ZeroWidthTests.scala
+++ b/src/test/scala/firrtlTests/ZeroWidthTests.scala
@@ -16,7 +16,7 @@ class ZeroWidthTests extends FirrtlFlatSpec {
ResolveKinds,
InferTypes,
ResolveGenders,
- InferWidths,
+ new InferWidths,
ZeroWidth)
private def exec (input: String) = {
val circuit = parse(input)
diff --git a/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala b/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala
index a866836f..667db7b0 100644
--- a/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala
+++ b/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala
@@ -10,10 +10,10 @@ import firrtl.passes._
import firrtl.Parser.IgnoreInfo
class FixedTypeInferenceSpec extends FirrtlFlatSpec {
- private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = {
- val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) {
- (c: Circuit, p: Pass) => p.run(c)
- }
+ private def executeTest(input: String, expected: Seq[String], passes: Seq[Transform]) = {
+ val c = passes.foldLeft(CircuitState(Parser.parse(input.split("\n").toIterator), UnknownForm)) {
+ (c: CircuitState, p: Transform) => p.runTransform(c)
+ }.circuit
val lines = c.serialize.split("\n") map normalized
expected foreach { e =>
@@ -30,7 +30,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec {
CheckTypes,
ResolveGenders,
CheckGenders,
- InferWidths,
+ new InferWidths,
CheckWidths)
val input =
"""circuit Unit :
@@ -60,7 +60,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec {
CheckTypes,
ResolveGenders,
CheckGenders,
- InferWidths,
+ new InferWidths,
CheckWidths)
val input =
"""circuit Unit :
@@ -86,7 +86,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec {
CheckTypes,
ResolveGenders,
CheckGenders,
- InferWidths,
+ new InferWidths,
CheckWidths)
val input =
"""circuit Unit :
@@ -112,7 +112,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec {
CheckTypes,
ResolveGenders,
CheckGenders,
- InferWidths,
+ new InferWidths,
CheckWidths)
val input =
"""circuit Unit :
@@ -138,7 +138,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec {
CheckTypes,
ResolveGenders,
CheckGenders,
- InferWidths,
+ new InferWidths,
CheckWidths)
val input =
"""circuit Unit :
@@ -164,7 +164,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec {
CheckTypes,
ResolveGenders,
CheckGenders,
- InferWidths,
+ new InferWidths,
CheckWidths)
val input =
"""circuit Unit :
@@ -190,7 +190,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec {
CheckTypes,
ResolveGenders,
CheckGenders,
- InferWidths,
+ new InferWidths,
CheckWidths)
val input =
"""circuit Unit :
@@ -230,7 +230,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec {
CheckTypes,
ResolveGenders,
CheckGenders,
- InferWidths,
+ new InferWidths,
CheckWidths)
val input =
"""circuit Unit :
@@ -256,7 +256,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec {
CheckTypes,
ResolveGenders,
CheckGenders,
- InferWidths,
+ new InferWidths,
CheckWidths,
ConvertFixedToSInt)
val input =
diff --git a/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala b/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala
index 8645fa62..21a39e83 100644
--- a/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala
+++ b/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala
@@ -9,10 +9,10 @@ import firrtl.passes._
import firrtl.Parser.IgnoreInfo
class RemoveFixedTypeSpec extends FirrtlFlatSpec {
- private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = {
- val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) {
- (c: Circuit, p: Pass) => p.run(c)
- }
+ private def executeTest(input: String, expected: Seq[String], passes: Seq[Transform]) = {
+ val c = passes.foldLeft(CircuitState(Parser.parse(input.split("\n").toIterator), UnknownForm)) {
+ (c: CircuitState, p: Transform) => p.runTransform(c)
+ }.circuit
val lines = c.serialize.split("\n") map normalized
println(c.serialize)
@@ -30,7 +30,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec {
CheckTypes,
ResolveGenders,
CheckGenders,
- InferWidths,
+ new InferWidths,
CheckWidths,
ConvertFixedToSInt)
val input =
@@ -60,7 +60,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec {
CheckTypes,
ResolveGenders,
CheckGenders,
- InferWidths,
+ new InferWidths,
CheckWidths,
ConvertFixedToSInt)
val input =
@@ -91,7 +91,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec {
CheckTypes,
ResolveGenders,
CheckGenders,
- InferWidths,
+ new InferWidths,
CheckWidths,
ConvertFixedToSInt)
val input =
@@ -118,7 +118,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec {
CheckTypes,
ResolveGenders,
CheckGenders,
- InferWidths,
+ new InferWidths,
CheckWidths,
ConvertFixedToSInt)
val input =
@@ -145,7 +145,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec {
CheckTypes,
ResolveGenders,
CheckGenders,
- InferWidths,
+ new InferWidths,
CheckWidths,
ConvertFixedToSInt)
val input =
@@ -197,7 +197,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec {
CheckTypes,
ResolveGenders,
CheckGenders,
- InferWidths,
+ new InferWidths,
CheckWidths,
ConvertFixedToSInt)
val input =
diff --git a/src/test/scala/firrtlTests/transforms/InferWidthsWithAnnosSpec.scala b/src/test/scala/firrtlTests/transforms/InferWidthsWithAnnosSpec.scala
new file mode 100644
index 00000000..54a3df40
--- /dev/null
+++ b/src/test/scala/firrtlTests/transforms/InferWidthsWithAnnosSpec.scala
@@ -0,0 +1,230 @@
+// See LICENSE for license details.
+
+package firrtlTests.transforms
+
+import firrtlTests.FirrtlFlatSpec
+import org.scalatest._
+import org.scalatest.prop._
+import firrtl._
+import firrtl.passes._
+import firrtl.passes.wiring.{WiringTransform, SourceAnnotation, SinkAnnotation}
+import firrtl.ir.Circuit
+import firrtl.annotations._
+import firrtl.annotations.TargetToken.{Field, Index}
+
+
+class InferWidthsWithAnnosSpec extends FirrtlFlatSpec {
+ private def executeTest(input: String,
+ check: String,
+ transforms: Seq[Transform],
+ annotations: Seq[Annotation]) = {
+ val start = CircuitState(parse(input), ChirrtlForm, annotations)
+ val end = transforms.foldLeft(start) {
+ (c: CircuitState, t: Transform) => t.runTransform(c)
+ }
+ val resLines = end.circuit.serialize.split("\n") map normalized
+ val checkLines = parse(check).serialize.split("\n") map normalized
+
+ resLines should be (checkLines)
+ }
+
+ "CheckWidths on wires with unknown widths" should "result in an error" in {
+ val transforms = Seq(
+ ToWorkingIR,
+ CheckHighForm,
+ ResolveKinds,
+ InferTypes,
+ CheckTypes,
+ ResolveGenders,
+ new InferWidths,
+ CheckWidths)
+
+ val input =
+ """circuit Top :
+ | module Top :
+ | inst b of B
+ | inst a of A
+ |
+ | module B :
+ | wire x: UInt<3>
+ |
+ | module A :
+ | wire y: UInt""".stripMargin
+
+ // A.y should have uninferred width
+ intercept[CheckWidths.UninferredWidth] {
+ executeTest(input, "", transforms, Seq.empty)
+ }
+ }
+
+ "InferWidthsWithAnnos" should "infer widths using WidthGeqConstraintAnnotation" in {
+ val transforms = Seq(
+ ToWorkingIR,
+ CheckHighForm,
+ ResolveKinds,
+ InferTypes,
+ CheckTypes,
+ ResolveGenders,
+ new InferWidths,
+ CheckWidths)
+
+ val annos = Seq(WidthGeqConstraintAnnotation(
+ ReferenceTarget("Top", "A", Nil, "y", Nil),
+ ReferenceTarget("Top", "B", Nil, "x", Nil)))
+
+ val input =
+ """circuit Top :
+ | module Top :
+ | inst b of B
+ | inst a of A
+ |
+ | module B :
+ | wire x: UInt<3>
+ |
+ | module A :
+ | wire y: UInt""".stripMargin
+
+ val output =
+ """circuit Top :
+ | module Top :
+ | inst b of B
+ | inst a of A
+ |
+ | module B :
+ | wire x: UInt<3>
+ |
+ | module A :
+ | wire y: UInt<3>""".stripMargin
+
+ // A.y should have same width as B.x
+ executeTest(input, output, transforms, annos)
+ }
+
+ "InferWidthsWithAnnos" should "work with token paths" in {
+ val transforms = Seq(
+ ToWorkingIR,
+ CheckHighForm,
+ ResolveKinds,
+ InferTypes,
+ CheckTypes,
+ ResolveGenders,
+ new InferWidths,
+ CheckWidths)
+
+ val tokenLists = Seq(
+ Seq(Field("x")),
+ Seq(Field("y"), Index(0), Field("yy")),
+ Seq(Field("y"), Index(1), Field("yy"))
+ )
+
+ val annos = tokenLists.map { tokens =>
+ WidthGeqConstraintAnnotation(
+ ReferenceTarget("Top", "A", Nil, "bundle", tokens),
+ ReferenceTarget("Top", "B", Nil, "bundle", tokens))
+ }
+
+ val input =
+ """circuit Top :
+ | module Top :
+ | inst b of B
+ | inst a of A
+ |
+ | module B :
+ | wire bundle : {x : UInt<1>, y: {yy : UInt<3>}[2] }
+ |
+ | module A :
+ | wire bundle : {x : UInt, y: {yy : UInt}[2] }""".stripMargin
+
+ val output =
+ """circuit Top :
+ | module Top :
+ | inst b of B
+ | inst a of A
+ |
+ | module B :
+ | wire bundle : {x : UInt<1>, y: {yy : UInt<3>}[2] }
+ |
+ | module A :
+ | wire bundle : {x : UInt<1>, y: {yy : UInt<3>}[2] }""".stripMargin
+
+ // elements of A.bundle should have same width as B.bundle
+ executeTest(input, output, transforms, annos)
+ }
+
+ "InferWidthsWithAnnos" should "work with WiringTransform" in {
+ def transforms = Seq(
+ ToWorkingIR,
+ ResolveKinds,
+ InferTypes,
+ ResolveGenders,
+ new InferWidths,
+ CheckWidths,
+ new WiringTransform,
+ new ResolveAndCheck
+ )
+ val sourceTarget = ComponentName("bundle", ModuleName("A", CircuitName("Top")))
+ val source = SourceAnnotation(sourceTarget, "pin")
+
+ val sinkTarget = ComponentName("bundle", ModuleName("B", CircuitName("Top")))
+ val sink = SinkAnnotation(sinkTarget, "pin")
+
+ val tokenLists = Seq(
+ Seq(Field("x")),
+ Seq(Field("y"), Index(0), Field("yy")),
+ Seq(Field("y"), Index(1), Field("yy"))
+ )
+
+ val wgeqAnnos = tokenLists.map { tokens =>
+ WidthGeqConstraintAnnotation(
+ ReferenceTarget("Top", "A", Nil, "bundle", tokens),
+ ReferenceTarget("Top", "B", Nil, "bundle", tokens))
+ }
+
+ val failAnnos = Seq(source, sink)
+ val successAnnos = wgeqAnnos ++: failAnnos
+
+ val input =
+ """circuit Top :
+ | module Top :
+ | inst b of B
+ | inst a of A
+ |
+ | module B :
+ | wire bundle : {x : UInt<1>, y: {yy : UInt<3>}[2] }
+ |
+ | module A :
+ | wire bundle : {x : UInt, y: {yy : UInt}[2] }""".stripMargin
+
+ val output =
+ """circuit Top :
+ | module Top :
+ | wire bundle : {x : UInt<1>, y: {yy : UInt<3>}[2] }
+ | inst b of B
+ | inst a of A
+ | b.pin <= bundle
+ | bundle <= a.bundle_0
+ |
+ | module B :
+ | input pin : {x : UInt<1>, y: {yy : UInt<3>}[2] }
+ | wire bundle : {x : UInt<1>, y: {yy : UInt<3>}[2] }
+ | bundle <= pin
+ |
+ | module A :
+ | output bundle_0 : {x : UInt<1>, y: {yy : UInt<3>}[2] }
+ | wire bundle : {x : UInt<1>, y: {yy : UInt<3>}[2] }
+ | bundle_0 <= bundle"""
+ .stripMargin
+
+ // should fail without extra constraint annos due to UninferredWidths
+ val exceptions = intercept[PassExceptions] {
+ executeTest(input, "", transforms, failAnnos)
+ }.exceptions.reverse
+
+ val msg = exceptions.head.toString
+ assert(msg.contains(s"2 errors detected!"))
+ assert(exceptions.tail.forall(_.isInstanceOf[CheckWidths.UninferredWidth]))
+
+ // should pass with extra constraints
+ executeTest(input, output, transforms, successAnnos)
+ }
+}