diff options
Diffstat (limited to 'src/main/scala/firrtl/backends/experimental/smt/SMTExprMap.scala')
| -rw-r--r-- | src/main/scala/firrtl/backends/experimental/smt/SMTExprMap.scala | 86 |
1 files changed, 86 insertions, 0 deletions
diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTExprMap.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTExprMap.scala new file mode 100644 index 00000000..c991941f --- /dev/null +++ b/src/main/scala/firrtl/backends/experimental/smt/SMTExprMap.scala @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: Apache-2.0 +// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> +package firrtl.backends.experimental.smt + +private object SMTExprMap { + def mapExpr(expr: SMTExpr, f: SMTExpr => SMTExpr): SMTExpr = { + val bv = (b: BVExpr) => f(b).asInstanceOf[BVExpr] + val ar = (a: ArrayExpr) => f(a).asInstanceOf[ArrayExpr] + expr match { + case b: BVExpr => mapExpr(b, bv, ar) + case a: ArrayExpr => mapExpr(a, bv, ar) + } + } + + /** maps bv/ar over subexpressions of expr and returns expr with the results replaced */ + def mapExpr(expr: BVExpr, bv: BVExpr => BVExpr, ar: ArrayExpr => ArrayExpr): BVExpr = expr match { + // nullary + case old: BVLiteral => old + case old: BVSymbol => old + // unary + case old @ BVExtend(e, by, signed) => val n = bv(e); if (n.eq(e)) old else BVExtend(n, by, signed) + case old @ BVSlice(e, hi, lo) => val n = bv(e); if (n.eq(e)) old else BVSlice(n, hi, lo) + case old @ BVNot(e) => val n = bv(e); if (n.eq(e)) old else BVNot(n) + case old @ BVNegate(e) => val n = bv(e); if (n.eq(e)) old else BVNegate(n) + case old @ BVForall(variables, e) => val n = bv(e); if (n.eq(e)) old else BVForall(variables, n) + case old @ BVReduceAnd(e) => val n = bv(e); if (n.eq(e)) old else BVReduceAnd(n) + case old @ BVReduceOr(e) => val n = bv(e); if (n.eq(e)) old else BVReduceOr(n) + case old @ BVReduceXor(e) => val n = bv(e); if (n.eq(e)) old else BVReduceXor(n) + // binary + case old @ BVEqual(a, b) => + val (nA, nB) = (bv(a), bv(b)); if (nA.eq(a) && nB.eq(b)) old else BVEqual(nA, nB) + case old @ ArrayEqual(a, b) => + val (nA, nB) = (ar(a), ar(b)); if (nA.eq(a) && nB.eq(b)) old else ArrayEqual(nA, nB) + case old @ BVComparison(op, a, b, signed) => + val (nA, nB) = (bv(a), bv(b)); if (nA.eq(a) && nB.eq(b)) old else BVComparison(op, nA, nB, signed) + case old @ BVOp(op, a, b) => + val (nA, nB) = (bv(a), bv(b)); if (nA.eq(a) && nB.eq(b)) old else BVOp(op, nA, nB) + case old @ BVConcat(a, b) => + val (nA, nB) = (bv(a), bv(b)); if (nA.eq(a) && nB.eq(b)) old else BVConcat(nA, nB) + case old @ ArrayRead(a, b) => + val (nA, nB) = (ar(a), bv(b)); if (nA.eq(a) && nB.eq(b)) old else ArrayRead(nA, nB) + case old @ BVImplies(a, b) => + val (nA, nB) = (bv(a), bv(b)); if (nA.eq(a) && nB.eq(b)) old else BVImplies(nA, nB) + // ternary + case old @ BVIte(a, b, c) => + val (nA, nB, nC) = (bv(a), bv(b), bv(c)) + if (nA.eq(a) && nB.eq(b) && nC.eq(c)) old else BVIte(nA, nB, nC) + // n-ary + case old @ BVFunctionCall(name, args, width) => + val nArgs = args.map { + case b: BVExpr => bv(b) + case a: ArrayExpr => ar(a) + case u: UTSymbol => u + } + val anyNew = nArgs.zip(args).exists { case (n, o) => !n.eq(o) } + if (anyNew) BVFunctionCall(name, nArgs, width) else old + case old @ BVAnd(terms) => + val nTerms = terms.map(bv) + val anyNew = nTerms.zip(terms).exists { case (n, o) => !n.eq(o) } + if (anyNew) BVAnd(nTerms) else old + case old @ BVOr(terms) => + val nTerms = terms.map(bv) + val anyNew = nTerms.zip(terms).exists { case (n, o) => !n.eq(o) } + if (anyNew) BVOr(nTerms) else old + } + + /** maps bv/ar over subexpressions of expr and returns expr with the results replaced */ + def mapExpr(expr: ArrayExpr, bv: BVExpr => BVExpr, ar: ArrayExpr => ArrayExpr): ArrayExpr = expr match { + case old: ArraySymbol => old + case old @ ArrayConstant(e, indexWidth) => val n = bv(e); if (n.eq(e)) old else ArrayConstant(n, indexWidth) + case old @ ArrayStore(a, b, c) => + val (nA, nB, nC) = (ar(a), bv(b), bv(c)) + if (nA.eq(a) && nB.eq(b) && nC.eq(c)) old else ArrayStore(nA, nB, nC) + case old @ ArrayIte(a, b, c) => + val (nA, nB, nC) = (bv(a), ar(b), ar(c)) + if (nA.eq(a) && nB.eq(b) && nC.eq(c)) old else ArrayIte(nA, nB, nC) + case old @ ArrayFunctionCall(name, args, indexWidth, dataWidth) => + val nArgs = args.map { + case b: BVExpr => bv(b) + case a: ArrayExpr => ar(a) + case u: UTSymbol => u + } + val anyNew = nArgs.zip(args).exists { case (n, o) => !n.eq(o) } + if (anyNew) ArrayFunctionCall(name, nArgs, indexWidth, dataWidth) else old + } +} |
