From d7bd6d64fa9561d2049e565c61ee989eea1a61a5 Mon Sep 17 00:00:00 2001
From: Albert Magyar
Date: Tue, 26 May 2020 13:59:53 -0700
Subject: Use recursive-then-iterative approach for check_width_e

* Avoid excessively deep recursion
* Avoid overhead of DFS for shallow expression trees
* Reduce work: skip expressions that cannot contain error-containing subtrees
* Review feedback: added commentary to explain new check_width_e structure
---
 src/main/scala/firrtl/passes/CheckWidths.scala | 85 ++++++++++++++++----------
 1 file changed, 54 insertions(+), 31 deletions(-)

diff --git a/src/main/scala/firrtl/passes/CheckWidths.scala b/src/main/scala/firrtl/passes/CheckWidths.scala
index 3c7ad0a8..4f1930c1 100644
--- a/src/main/scala/firrtl/passes/CheckWidths.scala
+++ b/src/main/scala/firrtl/passes/CheckWidths.scala
@@ -109,46 +109,69 @@ object CheckWidths extends Pass with PreservesAll[Transform] {
   def check_width_f(info: Info, target: Target)(f: Field): Unit =
     check_width_t(info, target.modify(tokens = target.tokens :+ TargetToken.Field(f.name)))(f.tpe)
 
-  def check_width_e(info: Info, target: Target)(e: Expression): Unit = {
-    e match {
-      case e: UIntLiteral => e.width match {
-        case w: IntWidth if math.max(1, e.value.bitLength) > w.width =>
-          errors.append(new WidthTooSmall(info, target.serialize, e.value))
-        case _ =>
-      }
-      case e: SIntLiteral => e.width match {
-        case w: IntWidth if e.value.bitLength + 1 > w.width =>
-          errors.append(new WidthTooSmall(info, target.serialize, e.value))
-        case _ =>
-      }
-      case sqz@DoPrim(Squeeze, Seq(a, b), _, IntervalType(Closed(min), Closed(max), _)) =>
-        (a.tpe, b.tpe) match {
-          case (IntervalType(Closed(la), Closed(ua), _), IntervalType(Closed(lb), Closed(ub), _)) if (ua < lb) || (ub < la) =>
-            errors.append(new DisjointSqueeze(info, target.serialize, sqz))
-          case other =>
+  def check_width_e_leaf(info: Info, target: Target, expr: Expression): Unit = {
+    // This is a leaf check of the "local" width-correctness of one expression node, so no recursion.
+    expr match {
+      case e @ UIntLiteral(v, w: IntWidth) if math.max(1, v.bitLength) > w.width =>
+        errors.append(new WidthTooSmall(info, target.serialize, v))
+      case e @ SIntLiteral(v, w: IntWidth) if v.bitLength + 1 > w.width =>
+        errors.append(new WidthTooSmall(info, target.serialize, v))
+      case e @ DoPrim(op, Seq(a, b), _, tpe) =>
+        (op, a.tpe, b.tpe) match {
+          case (Squeeze, IntervalType(Closed(la), Closed(ua), _), IntervalType(Closed(lb), Closed(ub), _)) if (ua < lb) || (ub < la) =>
+            errors.append(new DisjointSqueeze(info, target.serialize, e))
+          case (Dshl, at, bt) if (hasWidth(at) && bitWidth(bt) >= DshlMaxWidth) =>
+            errors.append(new DshlTooBig(info, target.serialize))
+          case _ =>
+        }
+      case e @ DoPrim(op, Seq(a), consts, _) =>
+        (op, consts) match {
+          case (Bits, Seq(hi, lo)) if (hasWidth(a.tpe) && bitWidth(a.tpe) <= hi) =>
+            errors.append(new BitsWidthException(info, target.serialize, hi, bitWidth(a.tpe), e.serialize))
+          case (Head, Seq(n)) if (hasWidth(a.tpe) && bitWidth(a.tpe) < n) =>
+            errors.append(new HeadWidthException(info, target.serialize, n, bitWidth(a.tpe)))
+          case (Tail, Seq(n)) if (hasWidth(a.tpe) && bitWidth(a.tpe) < n) =>
+            errors.append(new TailWidthException(info, target.serialize, n, bitWidth(a.tpe)))
+          case (AsClock, _) if (bitWidth(a.tpe) != 1) =>
+            errors.append(new MultiBitAsClock(info, target.serialize))
+          case (AsAsyncReset, _) if (bitWidth(a.tpe) != 1) =>
+            errors.append(new MultiBitAsAsyncReset(info, target.serialize))
+          case _ =>
         }
-      case DoPrim(Bits, Seq(a), Seq(hi, lo), _) if (hasWidth(a.tpe) && bitWidth(a.tpe) <= hi) =>
-        errors.append(new BitsWidthException(info, target.serialize, hi, bitWidth(a.tpe), e.serialize))
-      case DoPrim(Head, Seq(a), Seq(n), _) if (hasWidth(a.tpe) && bitWidth(a.tpe) < n) =>
-        errors.append(new HeadWidthException(info, target.serialize, n, bitWidth(a.tpe)))
-      case DoPrim(Tail, Seq(a), Seq(n), _) if (hasWidth(a.tpe) && bitWidth(a.tpe) < n) =>
-        errors.append(new TailWidthException(info, target.serialize, n, bitWidth(a.tpe)))
-      case DoPrim(Dshl, Seq(a, b), _, _) if (hasWidth(a.tpe) && bitWidth(b.tpe) >= DshlMaxWidth) =>
-        errors.append(new DshlTooBig(info, target.serialize))
-      case DoPrim(AsClock, Seq(a), _, _) if (bitWidth(a.tpe) != 1) =>
-        errors.append(new MultiBitAsClock(info, target.serialize))
-      case DoPrim(AsAsyncReset, Seq(a), _, _) if (bitWidth(a.tpe) != 1) =>
-        errors.append(new MultiBitAsAsyncReset(info, target.serialize))
       case _ =>
     }
-    e foreach check_width_e(info, target)
   }
+
+  def check_width_e(info: Info, target: Target, recDepth: Int)(expr: Expression): Unit = {
+    check_width_e_leaf(info, target, expr)
+    expr match {
+      case _: Mux | _: ValidIf | _: DoPrim =>
+        // Width errors only occur on literals / Mux / DoPrim; only Mux, ValidIf, and DoPrim can have these as children
+        if (recDepth > 0)
+          // Use recursion for trees up to a nominal depth to avoid overhead
+          expr.foreach(check_width_e(info, target, recDepth - 1))
+        else
+          // Beyond that, switch to an explicit, iterative DFS to avoid stack overflow
+          check_width_e_dfs(info, target, expr)
+      case _ => // Anything else can only contain references / sub{fields, indices, accesses}
+    }
+  }
+
+  def check_width_e_dfs(info: Info, target: Target, expr: Expression): Unit = {
+    val stack = collection.mutable.ArrayStack(expr)
+    def push(e: Expression): Unit = stack.push(e)
+    while (stack.nonEmpty) {
+      val current = stack.pop()
+      check_width_e_leaf(info, target, current)
+      current.foreach(push)
+    }
+  }
 
   def check_width_s(minfo: Info, target: ModuleTarget)(s: Statement): Unit = {
     val info = get_info(s) match { case NoInfo => minfo case x => x }
     val subRef = s match { case sx: HasName => target.ref(sx.name) case _ => target }
-    s foreach check_width_e(info, target)
+    s foreach check_width_e(info, target, 4)
     s foreach check_width_s(info, target)
     s foreach check_width_t(info, subRef)
     s match {
--
cgit v1.2.3
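The approach the commit takes can be seen in isolation below. This is a minimal, self-contained sketch rather than firrtl code: Node, Branch, Leaf, visit, and visitIterative are hypothetical stand-ins for the Expression tree, check_width_e, and check_width_e_dfs, and the depth budget of 4 mirrors the value passed from check_width_s. Plain recursion services the common shallow case with no auxiliary allocation; once the budget is exhausted, traversal switches to an explicit heap-allocated stack, so a pathologically deep expression tree cannot overflow the JVM call stack.

object HybridTraversalSketch {
  // A toy tree standing in for an expression tree.
  sealed trait Node
  final case class Branch(left: Node, right: Node) extends Node
  final case class Leaf(value: Int) extends Node

  private def children(n: Node): List[Node] = n match {
    case Branch(l, r) => List(l, r)
    case _: Leaf      => Nil
  }

  // Recurse while the depth budget lasts: cheap for shallow trees.
  def visit(n: Node, recDepth: Int)(f: Node => Unit): Unit = {
    f(n)
    if (recDepth > 0) children(n).foreach(visit(_, recDepth - 1)(f))
    else children(n).foreach(visitIterative(_)(f))
  }

  // Explicit-stack DFS: bounded JVM stack use for arbitrarily deep trees.
  def visitIterative(n: Node)(f: Node => Unit): Unit = {
    var stack: List[Node] = n :: Nil
    while (stack.nonEmpty) {
      val current = stack.head
      stack = children(current) ::: stack.tail
      f(current)
    }
  }

  def main(args: Array[String]): Unit = {
    // A 100000-deep left spine would overflow naive recursion; here the
    // traversal recurses 4 levels, then falls back to the explicit stack.
    val deep: Node = (1 to 100000).foldLeft(Leaf(0): Node)((t, i) => Branch(t, Leaf(i)))
    var visited = 0
    visit(deep, recDepth = 4)(_ => visited += 1)
    println(s"visited $visited nodes") // 200001: every node exactly once
  }
}

The trade-off is the one the commit message names: an all-iterative DFS pays stack-bookkeeping overhead on every expression, while unbounded recursion risks a StackOverflowError on machine-generated input; a small fixed recursion budget takes the cheap path for typical trees and the safe path for deep ones.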