aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorKevin Laeufer2020-07-29 13:09:15 -0700
committerGitHub2020-07-29 20:09:15 +0000
commit734e3e462ce74178147d5d6b0b6bdc5557f41103 (patch)
tree02e700a0d6e18dd81a64f9a96b6602e09fc7ca39 /src
parent3c561d4125767406f2b069915ba927190b38e8cd (diff)
InferTypes: fix bugs with unknown widths on ports and memories (#1769)
* InferTypesFlowsAndKindsSpec: test the results of InferTypes, ResolveKinds and ResolveFlows * Don't use passes sub-package in tests This changes two test files using the "passes" sub-package to "firrtl.passes". This allows a new "firrtlTests.passes" package to be freely created and used without a name collision. Signed-off-by: Schuyler Eldridge <schuyler.eldridge@ibm.com> * ResolveFlows: only depends on types and working ir The types are needed to know the orientation of a bundle field of a SubField node. * InferTypes: fix bugs with unknown widths on ports and memories * LoweringCompileSpec: Uniquify pass moved Co-authored-by: Schuyler Eldridge <schuyler.eldridge@ibm.com> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/InferTypes.scala15
-rw-r--r--src/main/scala/firrtl/passes/ResolveFlows.scala5
-rw-r--r--src/test/scala/firrtlTests/LoweringCompilersSpec.scala118
-rw-r--r--src/test/scala/firrtlTests/RemoveWiresSpec.scala6
-rw-r--r--src/test/scala/firrtlTests/passes/InferTypesFlowsAndKindsSpec.scala206
5 files changed, 282 insertions, 68 deletions
diff --git a/src/main/scala/firrtl/passes/InferTypes.scala b/src/main/scala/firrtl/passes/InferTypes.scala
index 5524e0ea..6cc9f2b9 100644
--- a/src/main/scala/firrtl/passes/InferTypes.scala
+++ b/src/main/scala/firrtl/passes/InferTypes.scala
@@ -20,7 +20,6 @@ object InferTypes extends Pass {
def run(c: Circuit): Circuit = {
val namespace = Namespace()
- val mtypes = (c.modules map (m => m.name -> module_type(m))).toMap
def remove_unknowns_b(b: Bound): Bound = b match {
case UnknownBound => VarBound(namespace.newName("b"))
@@ -40,6 +39,11 @@ object InferTypes extends Pass {
}
}
+ // we first need to remove the unknown widths and bounds from all ports,
+ // as their type will determine the module types
+ val portsKnown = c.modules.map(_.map{ p: Port => p.copy(tpe = remove_unknowns(p.tpe)) })
+ val mtypes = portsKnown.map(m => m.name -> module_type(m)).toMap
+
def infer_types_e(types: TypeLookup)(e: Expression): Expression =
e map infer_types_e(types) match {
case e: WRef => e copy (tpe = types(e.name))
@@ -71,9 +75,10 @@ object InferTypes extends Pass {
types(sx.name) = t
sx copy (tpe = t) map infer_types_e(types)
case sx: DefMemory =>
- val t = remove_unknowns(MemPortUtils.memType(sx))
- types(sx.name) = t
- sx copy (dataType = remove_unknowns(sx.dataType))
+ // we need to remove the unknowns from the data type so that all ports get the same VarWidth
+ val knownDataType = sx.copy(dataType = remove_unknowns(sx.dataType))
+ types(sx.name) = MemPortUtils.memType(knownDataType)
+ knownDataType
case sx => sx map infer_types_s(types) map infer_types_e(types)
}
@@ -88,7 +93,7 @@ object InferTypes extends Pass {
m map infer_types_p(types) map infer_types_s(types)
}
- c copy (modules = c.modules map infer_types)
+ c.copy(modules = portsKnown.map(infer_types))
}
}
diff --git a/src/main/scala/firrtl/passes/ResolveFlows.scala b/src/main/scala/firrtl/passes/ResolveFlows.scala
index c3455327..85a0a26f 100644
--- a/src/main/scala/firrtl/passes/ResolveFlows.scala
+++ b/src/main/scala/firrtl/passes/ResolveFlows.scala
@@ -9,10 +9,7 @@ import firrtl.options.Dependency
object ResolveFlows extends Pass {
- override def prerequisites =
- Seq( Dependency(passes.ResolveKinds),
- Dependency(passes.InferTypes),
- Dependency(passes.Uniquify) ) ++ firrtl.stage.Forms.WorkingIR
+ override def prerequisites = Seq(Dependency(passes.InferTypes)) ++ firrtl.stage.Forms.WorkingIR
override def invalidates(a: Transform) = false
diff --git a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala
index 854763f1..ae546f7b 100644
--- a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala
+++ b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala
@@ -5,7 +5,6 @@ package firrtlTests
import org.scalatest.{FlatSpec, Matchers}
import firrtl._
-import firrtl.passes
import firrtl.options.Dependency
import firrtl.stage.{Forms, TransformManager}
@@ -36,75 +35,75 @@ class LoweringCompilersSpec extends FlatSpec with Matchers {
def legacyTransforms(a: CoreTransform): Seq[Transform] = a match {
case _: ChirrtlToHighFirrtl => Seq(
- passes.CheckChirrtl,
- passes.CInferTypes,
- passes.CInferMDir,
- passes.RemoveCHIRRTL)
- case _: IRToWorkingIR => Seq(passes.ToWorkingIR)
+ firrtl.passes.CheckChirrtl,
+ firrtl.passes.CInferTypes,
+ firrtl.passes.CInferMDir,
+ firrtl.passes.RemoveCHIRRTL)
+ case _: IRToWorkingIR => Seq(firrtl.passes.ToWorkingIR)
case _: ResolveAndCheck => Seq(
- passes.CheckHighForm,
- passes.ResolveKinds,
- passes.InferTypes,
- passes.CheckTypes,
- passes.Uniquify,
- passes.ResolveKinds,
- passes.InferTypes,
- passes.ResolveFlows,
- passes.CheckFlows,
- new passes.InferBinaryPoints,
- new passes.TrimIntervals,
- new passes.InferWidths,
- passes.CheckWidths,
+ firrtl.passes.CheckHighForm,
+ firrtl.passes.ResolveKinds,
+ firrtl.passes.InferTypes,
+ firrtl.passes.CheckTypes,
+ firrtl.passes.Uniquify,
+ firrtl.passes.ResolveKinds,
+ firrtl.passes.InferTypes,
+ firrtl.passes.ResolveFlows,
+ firrtl.passes.CheckFlows,
+ new firrtl.passes.InferBinaryPoints,
+ new firrtl.passes.TrimIntervals,
+ new firrtl.passes.InferWidths,
+ firrtl.passes.CheckWidths,
new firrtl.transforms.InferResets)
case _: HighFirrtlToMiddleFirrtl => Seq(
- passes.PullMuxes,
- passes.ReplaceAccesses,
- passes.ExpandConnects,
- passes.ZeroLengthVecs,
- passes.RemoveAccesses,
- passes.Uniquify,
- passes.ExpandWhens,
- passes.CheckInitialization,
- passes.ResolveKinds,
- passes.InferTypes,
- passes.CheckTypes,
- passes.ResolveFlows,
- new passes.InferWidths,
- passes.CheckWidths,
- new passes.RemoveIntervals,
- passes.ConvertFixedToSInt,
- passes.ZeroWidth,
- passes.InferTypes)
+ firrtl.passes.PullMuxes,
+ firrtl.passes.ReplaceAccesses,
+ firrtl.passes.ExpandConnects,
+ firrtl.passes.ZeroLengthVecs,
+ firrtl.passes.RemoveAccesses,
+ firrtl.passes.Uniquify,
+ firrtl.passes.ExpandWhens,
+ firrtl.passes.CheckInitialization,
+ firrtl.passes.ResolveKinds,
+ firrtl.passes.InferTypes,
+ firrtl.passes.CheckTypes,
+ firrtl.passes.ResolveFlows,
+ new firrtl.passes.InferWidths,
+ firrtl.passes.CheckWidths,
+ new firrtl.passes.RemoveIntervals,
+ firrtl.passes.ConvertFixedToSInt,
+ firrtl.passes.ZeroWidth,
+ firrtl.passes.InferTypes)
case _: MiddleFirrtlToLowFirrtl => Seq(
- passes.LowerTypes,
- passes.ResolveKinds,
- passes.InferTypes,
- passes.ResolveFlows,
- new passes.InferWidths,
- passes.Legalize,
+ firrtl.passes.LowerTypes,
+ firrtl.passes.ResolveKinds,
+ firrtl.passes.InferTypes,
+ firrtl.passes.ResolveFlows,
+ new firrtl.passes.InferWidths,
+ firrtl.passes.Legalize,
firrtl.transforms.RemoveReset,
- passes.ResolveFlows,
+ firrtl.passes.ResolveFlows,
new firrtl.transforms.CheckCombLoops,
new checks.CheckResets,
new firrtl.transforms.RemoveWires)
case _: LowFirrtlOptimization => Seq(
- passes.RemoveValidIf,
+ firrtl.passes.RemoveValidIf,
new firrtl.transforms.ConstantPropagation,
- passes.PadWidths,
+ firrtl.passes.PadWidths,
new firrtl.transforms.ConstantPropagation,
- passes.Legalize,
- passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter
+ firrtl.passes.Legalize,
+ firrtl.passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter
new firrtl.transforms.ConstantPropagation,
- passes.SplitExpressions,
+ firrtl.passes.SplitExpressions,
new firrtl.transforms.CombineCats,
- passes.CommonSubexpressionElimination,
+ firrtl.passes.CommonSubexpressionElimination,
new firrtl.transforms.DeadCodeElimination)
case _: MinimumLowFirrtlOptimization => Seq(
- passes.RemoveValidIf,
- passes.PadWidths,
- passes.Legalize,
- passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter
- passes.SplitExpressions)
+ firrtl.passes.RemoveValidIf,
+ firrtl.passes.PadWidths,
+ firrtl.passes.Legalize,
+ firrtl.passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter
+ firrtl.passes.SplitExpressions)
}
def compare(a: Seq[Transform], b: TransformManager, patches: Seq[PatchAction] = Seq.empty): Unit = {
@@ -147,6 +146,12 @@ class LoweringCompilersSpec extends FlatSpec with Matchers {
it should "replicate the old order" in {
val tm = new TransformManager(Forms.Resolved, Forms.WorkingIR)
val patches = Seq(
+ // ResolveFlows no longer depends in Uniquify (ResolveKinds and InferTypes are fixup passes that get moved as well)
+ Del(5), Del(6), Del(7),
+ // Uniquify now is run before InferBinary Points which claims to need Uniquify
+ Add(9, Seq(Dependency(firrtl.passes.Uniquify),
+ Dependency(firrtl.passes.ResolveKinds),
+ Dependency(firrtl.passes.InferTypes))),
Add(14, Seq(Dependency.fromTransform(firrtl.passes.CheckTypes)))
)
compare(legacyTransforms(new ResolveAndCheck), tm, patches)
@@ -164,7 +169,8 @@ class LoweringCompilersSpec extends FlatSpec with Matchers {
Dependency(firrtl.passes.ResolveFlows))),
Del(7),
Del(8),
- Add(7, Seq(Dependency[firrtl.passes.ExpandWhensAndCheck])),
+ Add(7, Seq(Dependency(firrtl.passes.ResolveKinds),
+ Dependency[firrtl.passes.ExpandWhensAndCheck])),
Del(11),
Del(12),
Del(13),
diff --git a/src/test/scala/firrtlTests/RemoveWiresSpec.scala b/src/test/scala/firrtlTests/RemoveWiresSpec.scala
index dd3155d0..e6b60059 100644
--- a/src/test/scala/firrtlTests/RemoveWiresSpec.scala
+++ b/src/test/scala/firrtlTests/RemoveWiresSpec.scala
@@ -163,7 +163,7 @@ class RemoveWiresSpec extends FirrtlFlatSpec {
|c <= n""".stripMargin
)
// Check declaration before use is maintained
- passes.CheckHighForm.execute(result)
+ firrtl.passes.CheckHighForm.execute(result)
}
it should "order registers with async reset correctly" in {
@@ -180,7 +180,7 @@ class RemoveWiresSpec extends FirrtlFlatSpec {
|""".stripMargin
)
// Check declaration before use is maintained
- passes.CheckHighForm.execute(result)
+ firrtl.passes.CheckHighForm.execute(result)
}
it should "order registers respecting initializations" in {
@@ -195,7 +195,7 @@ class RemoveWiresSpec extends FirrtlFlatSpec {
|bar <= y
|""".stripMargin)
// Check declaration before use is maintained
- passes.CheckHighForm.execute(result)
+ firrtl.passes.CheckHighForm.execute(result)
}
}
diff --git a/src/test/scala/firrtlTests/passes/InferTypesFlowsAndKindsSpec.scala b/src/test/scala/firrtlTests/passes/InferTypesFlowsAndKindsSpec.scala
new file mode 100644
index 00000000..de638374
--- /dev/null
+++ b/src/test/scala/firrtlTests/passes/InferTypesFlowsAndKindsSpec.scala
@@ -0,0 +1,206 @@
+// See LICENSE for license details.
+
+package firrtlTests.passes
+
+import firrtl.ir.SubField
+import firrtl.options.Dependency
+import firrtl.stage.TransformManager
+import firrtl.{InstanceKind, MemKind, NodeKind, PortKind, RegKind, WireKind}
+import firrtl.{CircuitState, SinkFlow, SourceFlow, ir, passes}
+import org.scalatest._
+
+/** Tests the combined results of ResolveKinds, InferTypes and ResolveFlows */
+class InferTypesFlowsAndKindsSpec extends FlatSpec {
+ private val deps = Seq(
+ Dependency(passes.ResolveKinds),
+ Dependency(passes.InferTypes),
+ Dependency(passes.ResolveFlows))
+ private val manager = new TransformManager(deps)
+ private def infer(src: String): ir.Circuit =
+ manager.execute(CircuitState(firrtl.Parser.parse(src), Seq())).circuit
+ private def getNodes(s: ir.Statement): Seq[(String, ir.Expression)] = s match {
+ case ir.DefNode(_, name, value) => Seq((name, value))
+ case ir.Block(stmts) => stmts.flatMap(getNodes)
+ case ir.Conditionally(_, _, a, b) => Seq(a,b).flatMap(getNodes)
+ case _ => Seq()
+ }
+ private def getConnects(s: ir.Statement): Seq[ir.Connect] = s match {
+ case c : ir.Connect => Seq(c)
+ case ir.Block(stmts) => stmts.flatMap(getConnects)
+ case ir.Conditionally(_, _, a, b) => Seq(a,b).flatMap(getConnects)
+ case _ => Seq()
+ }
+ private def getModule(c: ir.Circuit, name: String): ir.Module =
+ c.modules.find(_.name == name).get.asInstanceOf[ir.Module]
+
+ it should "infer references to ports, wires, nodes and registers" in {
+ val node = getNodes(getModule(infer(
+ """circuit m:
+ | module m:
+ | input clk: Clock
+ | input a: UInt<4>
+ | wire b : SInt<5>
+ | reg c: UInt<5>, clk
+ | node na = a
+ | node nb = b
+ | node nc = c
+ | node nna = na
+ | node na2 = a
+ | node a_plus_c = add(a, c)
+ |""".stripMargin), "m").body).toMap
+
+ assert(node("na").tpe == ir.UIntType(ir.IntWidth(4)))
+ assert(node("na").asInstanceOf[ir.Reference].flow == SourceFlow)
+ assert(node("na").asInstanceOf[ir.Reference].kind == PortKind)
+
+ assert(node("nb").tpe == ir.SIntType(ir.IntWidth(5)))
+ assert(node("nb").asInstanceOf[ir.Reference].flow == SourceFlow)
+ assert(node("nb").asInstanceOf[ir.Reference].kind == WireKind)
+
+ assert(node("nc").tpe == ir.UIntType(ir.IntWidth(5)))
+ assert(node("nc").asInstanceOf[ir.Reference].flow == SourceFlow)
+ assert(node("nc").asInstanceOf[ir.Reference].kind == RegKind)
+
+ assert(node("nna").tpe == ir.UIntType(ir.IntWidth(4)))
+ assert(node("nna").asInstanceOf[ir.Reference].flow == SourceFlow)
+ assert(node("nna").asInstanceOf[ir.Reference].kind == NodeKind)
+
+ assert(node("na2").tpe == ir.UIntType(ir.IntWidth(4)))
+ assert(node("na2").asInstanceOf[ir.Reference].flow == SourceFlow)
+ assert(node("na2").asInstanceOf[ir.Reference].kind == PortKind)
+
+ // according to the spec, the result of add is max(we1, we2 ) + 1
+ assert(node("a_plus_c").tpe == ir.UIntType(ir.IntWidth(6)))
+ }
+
+ it should "infer types for references to instances" in {
+ val m = getModule(infer(
+ """circuit m:
+ | module other:
+ | output x: { y: UInt, flip z: UInt<1> }
+ | module m:
+ | inst i of other
+ | node i_x = i.x
+ | node i_x_y = i.x.y
+ | node i_x_y_2 = i_x.y
+ | node a = UInt<1>(1)
+ | i.x.z <= a
+ |""".stripMargin), "m")
+ val node = getNodes(m.body).toMap
+ val con = getConnects(m.body)
+
+
+ // node i_x_y = i.x.y
+ assert(node("i_x_y").tpe.isInstanceOf[ir.UIntType])
+ // the type inference replaces all unknown widths with a variable
+ assert(node("i_x_y").tpe.asInstanceOf[ir.UIntType].width.isInstanceOf[ir.VarWidth])
+ assert(node("i_x_y").asInstanceOf[ir.SubField].flow == SourceFlow)
+
+
+ // node i_x = i.x
+ val x = node("i_x").asInstanceOf[ir.SubField]
+ assert(x.tpe.isInstanceOf[ir.BundleType])
+ assert(x.tpe.asInstanceOf[ir.BundleType].fields.head.name == "y")
+ assert(x.tpe.asInstanceOf[ir.BundleType].fields.head.tpe == node("i_x_y").tpe)
+ assert(x.tpe.asInstanceOf[ir.BundleType].fields.head.flip == ir.Default)
+ assert(x.tpe.asInstanceOf[ir.BundleType].fields.last.flip == ir.Flip)
+ assert(x.flow == SourceFlow)
+
+ val i = x.expr.asInstanceOf[ir.Reference]
+ assert(i.kind == InstanceKind)
+ assert(i.flow == SourceFlow)
+
+
+ // node i_x_y_2 = i_x.y
+ assert(node("i_x_y").tpe == node("i_x_y_2").tpe)
+ assert(node("i_x_y").asInstanceOf[ir.SubField].flow == node("i_x_y_2").asInstanceOf[ir.SubField].flow)
+
+
+ // i.x.z <= a
+ val (left, right) = (con.head.loc.asInstanceOf[ir.SubField], con.head.expr.asInstanceOf[ir.Reference])
+
+ // flow propagates z -> x -> i
+ assert(left.flow == SinkFlow)
+ val left_x = left.expr.asInstanceOf[SubField]
+ assert(left_x.flow == SourceFlow) // flip z
+ val left_i = left_x.expr.asInstanceOf[ir.Reference]
+ assert(left_i.flow == SourceFlow)
+
+ assert(left_i.kind == InstanceKind)
+ assert(left_x.tpe == x.tpe)
+ }
+
+ it should "infer types for references to memories" in {
+ val c = infer(
+ """circuit m:
+ | module m:
+ | mem m:
+ | data-type => UInt
+ | depth => 30
+ | reader => r
+ | writer => w
+ | read-latency => 1
+ | write-latency => 1
+ | read-under-write => undefined
+ |
+ | node m_r_addr = m.r.addr
+ | node m_r_data = m.r.data
+ | node m_w_addr = m.w.addr
+ | node m_w_data = m.w.data
+ |""".stripMargin)
+ val m = getModule(c, "m")
+ val node = getNodes(m.body).toMap
+ // this might be a little flaky...
+ val memory = m.body.asInstanceOf[ir.Block].stmts.head.asInstanceOf[ir.DefMemory]
+
+
+ // after InferTypes, all expressions referring to the `data` should have this type:
+ val dataTpe = memory.dataType.asInstanceOf[ir.UIntType]
+ val addrTpe = ir.UIntType(ir.IntWidth(5))
+
+ assert(node("m_r_addr").tpe == addrTpe)
+ assert(node("m_r_data").tpe == dataTpe)
+ assert(node("m_w_addr").tpe == addrTpe)
+ assert(node("m_w_data").tpe == dataTpe)
+
+ val memory_ref = node("m_r_addr").asInstanceOf[ir.SubField].expr
+ .asInstanceOf[ir.SubField].expr.asInstanceOf[ir.Reference]
+ assert(memory_ref.kind == MemKind)
+ val mem_ref_tpe = memory_ref.tpe.asInstanceOf[ir.BundleType]
+ val r_tpe = mem_ref_tpe.fields.find(_.name == "r").get.tpe.asInstanceOf[ir.BundleType]
+ val w_tpe = mem_ref_tpe.fields.find(_.name == "w").get.tpe.asInstanceOf[ir.BundleType]
+ assert(r_tpe.fields.find(_.name == "addr").get.tpe == addrTpe)
+ assert(r_tpe.fields.find(_.name == "data").get.tpe == dataTpe)
+ assert(w_tpe.fields.find(_.name == "addr").get.tpe == addrTpe)
+ assert(w_tpe.fields.find(_.name == "data").get.tpe == dataTpe)
+ }
+
+ it should "infer different instances of the same module to have the same width variable" in {
+ val c = infer(
+ """circuit m:
+ | module other:
+ | input x: UInt
+ | module x:
+ | inst i of other
+ | i.x <= UInt<16>(3)
+ | module m:
+ | inst x of x
+ | inst i of other
+ | i.x <= UInt<1>(1)
+ |""".stripMargin)
+ val m_con = getConnects(getModule(c, "m").body).head
+ val x_con = getConnects(getModule(c, "x").body).head
+ val other = getModule(c, "other")
+
+ // this is the type of the other.x port
+ val tpe = m_con.loc.tpe.asInstanceOf[ir.UIntType]
+ assert(tpe.width.isInstanceOf[ir.VarWidth])
+ // since it is the only unknown width, it should just be replaced with a "w"
+ assert(tpe.width.asInstanceOf[ir.VarWidth].name == "w")
+
+ assert(m_con.loc.tpe == tpe)
+ assert(x_con.loc.tpe == tpe)
+ assert(other.ports.head.tpe == tpe)
+ }
+
+}