diff options
| author | Kevin Laeufer | 2020-07-29 13:09:15 -0700 |
|---|---|---|
| committer | GitHub | 2020-07-29 20:09:15 +0000 |
| commit | 734e3e462ce74178147d5d6b0b6bdc5557f41103 (patch) | |
| tree | 02e700a0d6e18dd81a64f9a96b6602e09fc7ca39 /src | |
| parent | 3c561d4125767406f2b069915ba927190b38e8cd (diff) | |
InferTypes: fix bugs with unknown widths on ports and memories (#1769)
* InferTypesFlowsAndKindsSpec: test the results of InferTypes, ResolveKinds and ResolveFlows
* Don't use passes sub-package in tests
This changes two test files using the "passes" sub-package to
"firrtl.passes". This allows a new "firrtlTests.passes" package to be
freely created and used without a name collision.
Signed-off-by: Schuyler Eldridge <schuyler.eldridge@ibm.com>
* ResolveFlows: only depends on types and working ir
The types are needed to know the orientation of
a bundle field of a SubField node.
* InferTypes: fix bugs with unknown widths on ports and memories
* LoweringCompileSpec: Uniquify pass moved
Co-authored-by: Schuyler Eldridge <schuyler.eldridge@ibm.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Diffstat (limited to 'src')
5 files changed, 282 insertions, 68 deletions
diff --git a/src/main/scala/firrtl/passes/InferTypes.scala b/src/main/scala/firrtl/passes/InferTypes.scala index 5524e0ea..6cc9f2b9 100644 --- a/src/main/scala/firrtl/passes/InferTypes.scala +++ b/src/main/scala/firrtl/passes/InferTypes.scala @@ -20,7 +20,6 @@ object InferTypes extends Pass { def run(c: Circuit): Circuit = { val namespace = Namespace() - val mtypes = (c.modules map (m => m.name -> module_type(m))).toMap def remove_unknowns_b(b: Bound): Bound = b match { case UnknownBound => VarBound(namespace.newName("b")) @@ -40,6 +39,11 @@ object InferTypes extends Pass { } } + // we first need to remove the unknown widths and bounds from all ports, + // as their type will determine the module types + val portsKnown = c.modules.map(_.map{ p: Port => p.copy(tpe = remove_unknowns(p.tpe)) }) + val mtypes = portsKnown.map(m => m.name -> module_type(m)).toMap + def infer_types_e(types: TypeLookup)(e: Expression): Expression = e map infer_types_e(types) match { case e: WRef => e copy (tpe = types(e.name)) @@ -71,9 +75,10 @@ object InferTypes extends Pass { types(sx.name) = t sx copy (tpe = t) map infer_types_e(types) case sx: DefMemory => - val t = remove_unknowns(MemPortUtils.memType(sx)) - types(sx.name) = t - sx copy (dataType = remove_unknowns(sx.dataType)) + // we need to remove the unknowns from the data type so that all ports get the same VarWidth + val knownDataType = sx.copy(dataType = remove_unknowns(sx.dataType)) + types(sx.name) = MemPortUtils.memType(knownDataType) + knownDataType case sx => sx map infer_types_s(types) map infer_types_e(types) } @@ -88,7 +93,7 @@ object InferTypes extends Pass { m map infer_types_p(types) map infer_types_s(types) } - c copy (modules = c.modules map infer_types) + c.copy(modules = portsKnown.map(infer_types)) } } diff --git a/src/main/scala/firrtl/passes/ResolveFlows.scala b/src/main/scala/firrtl/passes/ResolveFlows.scala index c3455327..85a0a26f 100644 --- a/src/main/scala/firrtl/passes/ResolveFlows.scala +++ b/src/main/scala/firrtl/passes/ResolveFlows.scala @@ -9,10 +9,7 @@ import firrtl.options.Dependency object ResolveFlows extends Pass { - override def prerequisites = - Seq( Dependency(passes.ResolveKinds), - Dependency(passes.InferTypes), - Dependency(passes.Uniquify) ) ++ firrtl.stage.Forms.WorkingIR + override def prerequisites = Seq(Dependency(passes.InferTypes)) ++ firrtl.stage.Forms.WorkingIR override def invalidates(a: Transform) = false diff --git a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala index 854763f1..ae546f7b 100644 --- a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala +++ b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala @@ -5,7 +5,6 @@ package firrtlTests import org.scalatest.{FlatSpec, Matchers} import firrtl._ -import firrtl.passes import firrtl.options.Dependency import firrtl.stage.{Forms, TransformManager} @@ -36,75 +35,75 @@ class LoweringCompilersSpec extends FlatSpec with Matchers { def legacyTransforms(a: CoreTransform): Seq[Transform] = a match { case _: ChirrtlToHighFirrtl => Seq( - passes.CheckChirrtl, - passes.CInferTypes, - passes.CInferMDir, - passes.RemoveCHIRRTL) - case _: IRToWorkingIR => Seq(passes.ToWorkingIR) + firrtl.passes.CheckChirrtl, + firrtl.passes.CInferTypes, + firrtl.passes.CInferMDir, + firrtl.passes.RemoveCHIRRTL) + case _: IRToWorkingIR => Seq(firrtl.passes.ToWorkingIR) case _: ResolveAndCheck => Seq( - passes.CheckHighForm, - passes.ResolveKinds, - passes.InferTypes, - passes.CheckTypes, - passes.Uniquify, - passes.ResolveKinds, - passes.InferTypes, - passes.ResolveFlows, - passes.CheckFlows, - new passes.InferBinaryPoints, - new passes.TrimIntervals, - new passes.InferWidths, - passes.CheckWidths, + firrtl.passes.CheckHighForm, + firrtl.passes.ResolveKinds, + firrtl.passes.InferTypes, + firrtl.passes.CheckTypes, + firrtl.passes.Uniquify, + firrtl.passes.ResolveKinds, + firrtl.passes.InferTypes, + firrtl.passes.ResolveFlows, + firrtl.passes.CheckFlows, + new firrtl.passes.InferBinaryPoints, + new firrtl.passes.TrimIntervals, + new firrtl.passes.InferWidths, + firrtl.passes.CheckWidths, new firrtl.transforms.InferResets) case _: HighFirrtlToMiddleFirrtl => Seq( - passes.PullMuxes, - passes.ReplaceAccesses, - passes.ExpandConnects, - passes.ZeroLengthVecs, - passes.RemoveAccesses, - passes.Uniquify, - passes.ExpandWhens, - passes.CheckInitialization, - passes.ResolveKinds, - passes.InferTypes, - passes.CheckTypes, - passes.ResolveFlows, - new passes.InferWidths, - passes.CheckWidths, - new passes.RemoveIntervals, - passes.ConvertFixedToSInt, - passes.ZeroWidth, - passes.InferTypes) + firrtl.passes.PullMuxes, + firrtl.passes.ReplaceAccesses, + firrtl.passes.ExpandConnects, + firrtl.passes.ZeroLengthVecs, + firrtl.passes.RemoveAccesses, + firrtl.passes.Uniquify, + firrtl.passes.ExpandWhens, + firrtl.passes.CheckInitialization, + firrtl.passes.ResolveKinds, + firrtl.passes.InferTypes, + firrtl.passes.CheckTypes, + firrtl.passes.ResolveFlows, + new firrtl.passes.InferWidths, + firrtl.passes.CheckWidths, + new firrtl.passes.RemoveIntervals, + firrtl.passes.ConvertFixedToSInt, + firrtl.passes.ZeroWidth, + firrtl.passes.InferTypes) case _: MiddleFirrtlToLowFirrtl => Seq( - passes.LowerTypes, - passes.ResolveKinds, - passes.InferTypes, - passes.ResolveFlows, - new passes.InferWidths, - passes.Legalize, + firrtl.passes.LowerTypes, + firrtl.passes.ResolveKinds, + firrtl.passes.InferTypes, + firrtl.passes.ResolveFlows, + new firrtl.passes.InferWidths, + firrtl.passes.Legalize, firrtl.transforms.RemoveReset, - passes.ResolveFlows, + firrtl.passes.ResolveFlows, new firrtl.transforms.CheckCombLoops, new checks.CheckResets, new firrtl.transforms.RemoveWires) case _: LowFirrtlOptimization => Seq( - passes.RemoveValidIf, + firrtl.passes.RemoveValidIf, new firrtl.transforms.ConstantPropagation, - passes.PadWidths, + firrtl.passes.PadWidths, new firrtl.transforms.ConstantPropagation, - passes.Legalize, - passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter + firrtl.passes.Legalize, + firrtl.passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter new firrtl.transforms.ConstantPropagation, - passes.SplitExpressions, + firrtl.passes.SplitExpressions, new firrtl.transforms.CombineCats, - passes.CommonSubexpressionElimination, + firrtl.passes.CommonSubexpressionElimination, new firrtl.transforms.DeadCodeElimination) case _: MinimumLowFirrtlOptimization => Seq( - passes.RemoveValidIf, - passes.PadWidths, - passes.Legalize, - passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter - passes.SplitExpressions) + firrtl.passes.RemoveValidIf, + firrtl.passes.PadWidths, + firrtl.passes.Legalize, + firrtl.passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter + firrtl.passes.SplitExpressions) } def compare(a: Seq[Transform], b: TransformManager, patches: Seq[PatchAction] = Seq.empty): Unit = { @@ -147,6 +146,12 @@ class LoweringCompilersSpec extends FlatSpec with Matchers { it should "replicate the old order" in { val tm = new TransformManager(Forms.Resolved, Forms.WorkingIR) val patches = Seq( + // ResolveFlows no longer depends in Uniquify (ResolveKinds and InferTypes are fixup passes that get moved as well) + Del(5), Del(6), Del(7), + // Uniquify now is run before InferBinary Points which claims to need Uniquify + Add(9, Seq(Dependency(firrtl.passes.Uniquify), + Dependency(firrtl.passes.ResolveKinds), + Dependency(firrtl.passes.InferTypes))), Add(14, Seq(Dependency.fromTransform(firrtl.passes.CheckTypes))) ) compare(legacyTransforms(new ResolveAndCheck), tm, patches) @@ -164,7 +169,8 @@ class LoweringCompilersSpec extends FlatSpec with Matchers { Dependency(firrtl.passes.ResolveFlows))), Del(7), Del(8), - Add(7, Seq(Dependency[firrtl.passes.ExpandWhensAndCheck])), + Add(7, Seq(Dependency(firrtl.passes.ResolveKinds), + Dependency[firrtl.passes.ExpandWhensAndCheck])), Del(11), Del(12), Del(13), diff --git a/src/test/scala/firrtlTests/RemoveWiresSpec.scala b/src/test/scala/firrtlTests/RemoveWiresSpec.scala index dd3155d0..e6b60059 100644 --- a/src/test/scala/firrtlTests/RemoveWiresSpec.scala +++ b/src/test/scala/firrtlTests/RemoveWiresSpec.scala @@ -163,7 +163,7 @@ class RemoveWiresSpec extends FirrtlFlatSpec { |c <= n""".stripMargin ) // Check declaration before use is maintained - passes.CheckHighForm.execute(result) + firrtl.passes.CheckHighForm.execute(result) } it should "order registers with async reset correctly" in { @@ -180,7 +180,7 @@ class RemoveWiresSpec extends FirrtlFlatSpec { |""".stripMargin ) // Check declaration before use is maintained - passes.CheckHighForm.execute(result) + firrtl.passes.CheckHighForm.execute(result) } it should "order registers respecting initializations" in { @@ -195,7 +195,7 @@ class RemoveWiresSpec extends FirrtlFlatSpec { |bar <= y |""".stripMargin) // Check declaration before use is maintained - passes.CheckHighForm.execute(result) + firrtl.passes.CheckHighForm.execute(result) } } diff --git a/src/test/scala/firrtlTests/passes/InferTypesFlowsAndKindsSpec.scala b/src/test/scala/firrtlTests/passes/InferTypesFlowsAndKindsSpec.scala new file mode 100644 index 00000000..de638374 --- /dev/null +++ b/src/test/scala/firrtlTests/passes/InferTypesFlowsAndKindsSpec.scala @@ -0,0 +1,206 @@ +// See LICENSE for license details. + +package firrtlTests.passes + +import firrtl.ir.SubField +import firrtl.options.Dependency +import firrtl.stage.TransformManager +import firrtl.{InstanceKind, MemKind, NodeKind, PortKind, RegKind, WireKind} +import firrtl.{CircuitState, SinkFlow, SourceFlow, ir, passes} +import org.scalatest._ + +/** Tests the combined results of ResolveKinds, InferTypes and ResolveFlows */ +class InferTypesFlowsAndKindsSpec extends FlatSpec { + private val deps = Seq( + Dependency(passes.ResolveKinds), + Dependency(passes.InferTypes), + Dependency(passes.ResolveFlows)) + private val manager = new TransformManager(deps) + private def infer(src: String): ir.Circuit = + manager.execute(CircuitState(firrtl.Parser.parse(src), Seq())).circuit + private def getNodes(s: ir.Statement): Seq[(String, ir.Expression)] = s match { + case ir.DefNode(_, name, value) => Seq((name, value)) + case ir.Block(stmts) => stmts.flatMap(getNodes) + case ir.Conditionally(_, _, a, b) => Seq(a,b).flatMap(getNodes) + case _ => Seq() + } + private def getConnects(s: ir.Statement): Seq[ir.Connect] = s match { + case c : ir.Connect => Seq(c) + case ir.Block(stmts) => stmts.flatMap(getConnects) + case ir.Conditionally(_, _, a, b) => Seq(a,b).flatMap(getConnects) + case _ => Seq() + } + private def getModule(c: ir.Circuit, name: String): ir.Module = + c.modules.find(_.name == name).get.asInstanceOf[ir.Module] + + it should "infer references to ports, wires, nodes and registers" in { + val node = getNodes(getModule(infer( + """circuit m: + | module m: + | input clk: Clock + | input a: UInt<4> + | wire b : SInt<5> + | reg c: UInt<5>, clk + | node na = a + | node nb = b + | node nc = c + | node nna = na + | node na2 = a + | node a_plus_c = add(a, c) + |""".stripMargin), "m").body).toMap + + assert(node("na").tpe == ir.UIntType(ir.IntWidth(4))) + assert(node("na").asInstanceOf[ir.Reference].flow == SourceFlow) + assert(node("na").asInstanceOf[ir.Reference].kind == PortKind) + + assert(node("nb").tpe == ir.SIntType(ir.IntWidth(5))) + assert(node("nb").asInstanceOf[ir.Reference].flow == SourceFlow) + assert(node("nb").asInstanceOf[ir.Reference].kind == WireKind) + + assert(node("nc").tpe == ir.UIntType(ir.IntWidth(5))) + assert(node("nc").asInstanceOf[ir.Reference].flow == SourceFlow) + assert(node("nc").asInstanceOf[ir.Reference].kind == RegKind) + + assert(node("nna").tpe == ir.UIntType(ir.IntWidth(4))) + assert(node("nna").asInstanceOf[ir.Reference].flow == SourceFlow) + assert(node("nna").asInstanceOf[ir.Reference].kind == NodeKind) + + assert(node("na2").tpe == ir.UIntType(ir.IntWidth(4))) + assert(node("na2").asInstanceOf[ir.Reference].flow == SourceFlow) + assert(node("na2").asInstanceOf[ir.Reference].kind == PortKind) + + // according to the spec, the result of add is max(we1, we2 ) + 1 + assert(node("a_plus_c").tpe == ir.UIntType(ir.IntWidth(6))) + } + + it should "infer types for references to instances" in { + val m = getModule(infer( + """circuit m: + | module other: + | output x: { y: UInt, flip z: UInt<1> } + | module m: + | inst i of other + | node i_x = i.x + | node i_x_y = i.x.y + | node i_x_y_2 = i_x.y + | node a = UInt<1>(1) + | i.x.z <= a + |""".stripMargin), "m") + val node = getNodes(m.body).toMap + val con = getConnects(m.body) + + + // node i_x_y = i.x.y + assert(node("i_x_y").tpe.isInstanceOf[ir.UIntType]) + // the type inference replaces all unknown widths with a variable + assert(node("i_x_y").tpe.asInstanceOf[ir.UIntType].width.isInstanceOf[ir.VarWidth]) + assert(node("i_x_y").asInstanceOf[ir.SubField].flow == SourceFlow) + + + // node i_x = i.x + val x = node("i_x").asInstanceOf[ir.SubField] + assert(x.tpe.isInstanceOf[ir.BundleType]) + assert(x.tpe.asInstanceOf[ir.BundleType].fields.head.name == "y") + assert(x.tpe.asInstanceOf[ir.BundleType].fields.head.tpe == node("i_x_y").tpe) + assert(x.tpe.asInstanceOf[ir.BundleType].fields.head.flip == ir.Default) + assert(x.tpe.asInstanceOf[ir.BundleType].fields.last.flip == ir.Flip) + assert(x.flow == SourceFlow) + + val i = x.expr.asInstanceOf[ir.Reference] + assert(i.kind == InstanceKind) + assert(i.flow == SourceFlow) + + + // node i_x_y_2 = i_x.y + assert(node("i_x_y").tpe == node("i_x_y_2").tpe) + assert(node("i_x_y").asInstanceOf[ir.SubField].flow == node("i_x_y_2").asInstanceOf[ir.SubField].flow) + + + // i.x.z <= a + val (left, right) = (con.head.loc.asInstanceOf[ir.SubField], con.head.expr.asInstanceOf[ir.Reference]) + + // flow propagates z -> x -> i + assert(left.flow == SinkFlow) + val left_x = left.expr.asInstanceOf[SubField] + assert(left_x.flow == SourceFlow) // flip z + val left_i = left_x.expr.asInstanceOf[ir.Reference] + assert(left_i.flow == SourceFlow) + + assert(left_i.kind == InstanceKind) + assert(left_x.tpe == x.tpe) + } + + it should "infer types for references to memories" in { + val c = infer( + """circuit m: + | module m: + | mem m: + | data-type => UInt + | depth => 30 + | reader => r + | writer => w + | read-latency => 1 + | write-latency => 1 + | read-under-write => undefined + | + | node m_r_addr = m.r.addr + | node m_r_data = m.r.data + | node m_w_addr = m.w.addr + | node m_w_data = m.w.data + |""".stripMargin) + val m = getModule(c, "m") + val node = getNodes(m.body).toMap + // this might be a little flaky... + val memory = m.body.asInstanceOf[ir.Block].stmts.head.asInstanceOf[ir.DefMemory] + + + // after InferTypes, all expressions referring to the `data` should have this type: + val dataTpe = memory.dataType.asInstanceOf[ir.UIntType] + val addrTpe = ir.UIntType(ir.IntWidth(5)) + + assert(node("m_r_addr").tpe == addrTpe) + assert(node("m_r_data").tpe == dataTpe) + assert(node("m_w_addr").tpe == addrTpe) + assert(node("m_w_data").tpe == dataTpe) + + val memory_ref = node("m_r_addr").asInstanceOf[ir.SubField].expr + .asInstanceOf[ir.SubField].expr.asInstanceOf[ir.Reference] + assert(memory_ref.kind == MemKind) + val mem_ref_tpe = memory_ref.tpe.asInstanceOf[ir.BundleType] + val r_tpe = mem_ref_tpe.fields.find(_.name == "r").get.tpe.asInstanceOf[ir.BundleType] + val w_tpe = mem_ref_tpe.fields.find(_.name == "w").get.tpe.asInstanceOf[ir.BundleType] + assert(r_tpe.fields.find(_.name == "addr").get.tpe == addrTpe) + assert(r_tpe.fields.find(_.name == "data").get.tpe == dataTpe) + assert(w_tpe.fields.find(_.name == "addr").get.tpe == addrTpe) + assert(w_tpe.fields.find(_.name == "data").get.tpe == dataTpe) + } + + it should "infer different instances of the same module to have the same width variable" in { + val c = infer( + """circuit m: + | module other: + | input x: UInt + | module x: + | inst i of other + | i.x <= UInt<16>(3) + | module m: + | inst x of x + | inst i of other + | i.x <= UInt<1>(1) + |""".stripMargin) + val m_con = getConnects(getModule(c, "m").body).head + val x_con = getConnects(getModule(c, "x").body).head + val other = getModule(c, "other") + + // this is the type of the other.x port + val tpe = m_con.loc.tpe.asInstanceOf[ir.UIntType] + assert(tpe.width.isInstanceOf[ir.VarWidth]) + // since it is the only unknown width, it should just be replaced with a "w" + assert(tpe.width.asInstanceOf[ir.VarWidth].name == "w") + + assert(m_con.loc.tpe == tpe) + assert(x_con.loc.tpe == tpe) + assert(other.ports.head.tpe == tpe) + } + +} |
