1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
|
// SPDX-License-Identifier: Apache-2.0
// Author: Kevin Laeufer <laeufer@cs.berkeley.edu>
package firrtl.backends.experimental.smt
object SMTExprMap {

  /** Applies `f` to each direct subexpression of `expr` and rebuilds `expr` from the results.
    *
    * `f` is not applied to `expr` itself and this function does not recurse on its own.
    * The result of `f` is cast back to the kind of the child it replaces, so `f` has to
    * return a [[BVExpr]] for a [[BVExpr]] child and an [[ArrayExpr]] for an [[ArrayExpr]] child.
    */
  def mapExpr(expr: SMTExpr, f: SMTExpr => SMTExpr): SMTExpr = {
    val onBv:    BVExpr => BVExpr = child => f(child).asInstanceOf[BVExpr]
    val onArray: ArrayExpr => ArrayExpr = child => f(child).asInstanceOf[ArrayExpr]
    expr match {
      case bvExpr:    BVExpr    => mapExpr(bvExpr, onBv, onArray)
      case arrayExpr: ArrayExpr => mapExpr(arrayExpr, onBv, onArray)
    }
  }

  /** Applies `bv` / `ar` to each direct subexpression of the bit-vector expression `expr`.
    *
    * If every mapped child is the very same object (reference equality) as the child it was
    * computed from, the original `expr` instance is returned; otherwise a new node of the same
    * shape is built from the mapped children.
    */
  def mapExpr(expr: BVExpr, bv: BVExpr => BVExpr, ar: ArrayExpr => ArrayExpr): BVExpr = expr match {
    // leaves: there is no child to map
    case literal: BVLiteral => literal
    case symbol:  BVSymbol  => symbol
    // nodes with a single mapped child
    case orig @ BVExtend(child, by, signed) =>
      val mapped = bv(child)
      if (mapped ne child) BVExtend(mapped, by, signed) else orig
    case orig @ BVSlice(child, hi, lo) =>
      val mapped = bv(child)
      if (mapped ne child) BVSlice(mapped, hi, lo) else orig
    case orig @ BVNot(child) =>
      val mapped = bv(child)
      if (mapped ne child) BVNot(mapped) else orig
    case orig @ BVNegate(child) =>
      val mapped = bv(child)
      if (mapped ne child) BVNegate(mapped) else orig
    case orig @ BVForall(variables, body) =>
      // only the body is mapped, the bound variables are passed through untouched
      val mapped = bv(body)
      if (mapped ne body) BVForall(variables, mapped) else orig
    case orig @ BVReduceAnd(child) =>
      val mapped = bv(child)
      if (mapped ne child) BVReduceAnd(mapped) else orig
    case orig @ BVReduceOr(child) =>
      val mapped = bv(child)
      if (mapped ne child) BVReduceOr(mapped) else orig
    case orig @ BVReduceXor(child) =>
      val mapped = bv(child)
      if (mapped ne child) BVReduceXor(mapped) else orig
    // nodes with two mapped children (always mapped left to right)
    case orig @ BVEqual(lhs, rhs) =>
      val newLhs = bv(lhs)
      val newRhs = bv(rhs)
      if ((newLhs ne lhs) || (newRhs ne rhs)) BVEqual(newLhs, newRhs) else orig
    case orig @ ArrayEqual(lhs, rhs) =>
      val newLhs = ar(lhs)
      val newRhs = ar(rhs)
      if ((newLhs ne lhs) || (newRhs ne rhs)) ArrayEqual(newLhs, newRhs) else orig
    case orig @ BVComparison(op, lhs, rhs, signed) =>
      val newLhs = bv(lhs)
      val newRhs = bv(rhs)
      if ((newLhs ne lhs) || (newRhs ne rhs)) BVComparison(op, newLhs, newRhs, signed) else orig
    case orig @ BVOp(op, lhs, rhs) =>
      val newLhs = bv(lhs)
      val newRhs = bv(rhs)
      if ((newLhs ne lhs) || (newRhs ne rhs)) BVOp(op, newLhs, newRhs) else orig
    case orig @ BVConcat(lhs, rhs) =>
      val newLhs = bv(lhs)
      val newRhs = bv(rhs)
      if ((newLhs ne lhs) || (newRhs ne rhs)) BVConcat(newLhs, newRhs) else orig
    case orig @ ArrayRead(array, index) =>
      val newArray = ar(array)
      val newIndex = bv(index)
      if ((newArray ne array) || (newIndex ne index)) ArrayRead(newArray, newIndex) else orig
    case orig @ BVImplies(lhs, rhs) =>
      val newLhs = bv(lhs)
      val newRhs = bv(rhs)
      if ((newLhs ne lhs) || (newRhs ne rhs)) BVImplies(newLhs, newRhs) else orig
    // nodes with three mapped children
    case orig @ BVIte(cond, tru, fals) =>
      val newCond = bv(cond)
      val newTru = bv(tru)
      val newFals = bv(fals)
      if ((newCond ne cond) || (newTru ne tru) || (newFals ne fals)) BVIte(newCond, newTru, newFals)
      else orig
    // nodes with a list of children
    case orig @ BVFunctionCall(name, args, width) =>
      val mappedArgs = args.map {
        case bvArg:    BVExpr    => bv(bvArg)
        case arrayArg: ArrayExpr => ar(arrayArg)
        case sym:      UTSymbol  => sym // passed through without mapping
      }
      val unchanged = mappedArgs.zip(args).forall { case (after, before) => after.eq(before) }
      if (unchanged) orig else BVFunctionCall(name, mappedArgs, width)
    case orig @ BVAnd(terms) =>
      val mappedTerms = terms.map(term => bv(term))
      val unchanged = mappedTerms.zip(terms).forall { case (after, before) => after.eq(before) }
      if (unchanged) orig else BVAnd(mappedTerms)
    case orig @ BVOr(terms) =>
      val mappedTerms = terms.map(term => bv(term))
      val unchanged = mappedTerms.zip(terms).forall { case (after, before) => after.eq(before) }
      if (unchanged) orig else BVOr(mappedTerms)
  }

  /** Applies `bv` / `ar` to each direct subexpression of the array expression `expr`.
    *
    * If every mapped child is the very same object (reference equality) as the child it was
    * computed from, the original `expr` instance is returned; otherwise a new node of the same
    * shape is built from the mapped children.
    */
  def mapExpr(expr: ArrayExpr, bv: BVExpr => BVExpr, ar: ArrayExpr => ArrayExpr): ArrayExpr = expr match {
    case symbol: ArraySymbol => symbol
    case orig @ ArrayConstant(value, indexWidth) =>
      val mapped = bv(value)
      if (mapped ne value) ArrayConstant(mapped, indexWidth) else orig
    case orig @ ArrayStore(array, index, data) =>
      val newArray = ar(array)
      val newIndex = bv(index)
      val newData = bv(data)
      if ((newArray ne array) || (newIndex ne index) || (newData ne data)) {
        ArrayStore(newArray, newIndex, newData)
      } else {
        orig
      }
    case orig @ ArrayIte(cond, tru, fals) =>
      val newCond = bv(cond)
      val newTru = ar(tru)
      val newFals = ar(fals)
      if ((newCond ne cond) || (newTru ne tru) || (newFals ne fals)) ArrayIte(newCond, newTru, newFals)
      else orig
    case orig @ ArrayFunctionCall(name, args, indexWidth, dataWidth) =>
      val mappedArgs = args.map {
        case bvArg:    BVExpr    => bv(bvArg)
        case arrayArg: ArrayExpr => ar(arrayArg)
        case sym:      UTSymbol  => sym // passed through without mapping
      }
      val unchanged = mappedArgs.zip(args).forall { case (after, before) => after.eq(before) }
      if (unchanged) orig else ArrayFunctionCall(name, mappedArgs, indexWidth, dataWidth)
  }
}
|