1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
|
// See LICENSE for license details.
// Author: Kevin Laeufer <laeufer@cs.berkeley.edu>
package firrtl.backends.experimental.smt
/** Bottom-up rewriting of SMT expressions, implemented outside the expression case classes.
  *
  * Plays the role that `mapExpr` / `foreachExpr` play for firrtl IR nodes. For every node the
  * children are rewritten first (recursively, left to right), and only then is the caller's
  * callback applied to the node itself. If every rewritten child is the very same instance as
  * before (reference equality via `eq`), the callback receives the original node instead of a
  * freshly constructed copy, so untouched subtrees keep their identity.
  */
private object SMTExprVisitor {
  type ArrayFun = ArrayExpr => ArrayExpr
  type BVFun = BVExpr => BVExpr

  /** Rewrites `e` bottom-up: `bv` is applied to every bit-vector node and `ar` to every array
    * node. The rewritten expression is returned as `T` via `asInstanceOf`.
    */
  def map[T <: SMTExpr](bv: BVFun, ar: ArrayFun)(e: T): T = {
    val rewritten: SMTExpr = e match {
      case bvNode: BVExpr     => map(bvNode, bv, ar)
      case arrNode: ArrayExpr => map(arrNode, bv, ar)
    }
    rewritten.asInstanceOf[T]
  }

  /** Rewrites `e` bottom-up with one callback `f` shared by all node kinds.
    *
    * `f` is expected to return a bit-vector node for a bit-vector input and an array node for an
    * array input; any other result fails the `asInstanceOf` cast with a `ClassCastException`.
    */
  def map[T <: SMTExpr](f: SMTExpr => SMTExpr)(e: T): T = {
    val onBV: BVFun = node => f(node).asInstanceOf[BVExpr]
    val onArray: ArrayFun = node => f(node).asInstanceOf[ArrayExpr]
    map(onBV, onArray)(e)
  }

  /** Bit-vector nodes: rewrite the children, rebuild the node only if some child changed,
    * then hand the resulting node to `bv`.
    */
  private def map(e: BVExpr, bv: BVFun, ar: ArrayFun): BVExpr = {
    val rebuilt: BVExpr = e match {
      // leaves: nothing below them to rewrite
      case leaf @ (_: BVLiteral | _: BVSymbol | _: BVRawExpr) => leaf

      // one child
      case node @ BVExtend(arg, by, signed) =>
        val arg2 = map(arg, bv, ar)
        if (arg2 eq arg) node else BVExtend(arg2, by, signed)
      case node @ BVSlice(arg, hi, lo) =>
        val arg2 = map(arg, bv, ar)
        if (arg2 eq arg) node else BVSlice(arg2, hi, lo)
      case node @ BVNot(arg) =>
        val arg2 = map(arg, bv, ar)
        if (arg2 eq arg) node else BVNot(arg2)
      case node @ BVNegate(arg) =>
        val arg2 = map(arg, bv, ar)
        if (arg2 eq arg) node else BVNegate(arg2)
      case node @ BVReduceAnd(arg) =>
        val arg2 = map(arg, bv, ar)
        if (arg2 eq arg) node else BVReduceAnd(arg2)
      case node @ BVReduceOr(arg) =>
        val arg2 = map(arg, bv, ar)
        if (arg2 eq arg) node else BVReduceOr(arg2)
      case node @ BVReduceXor(arg) =>
        val arg2 = map(arg, bv, ar)
        if (arg2 eq arg) node else BVReduceXor(arg2)

      // two children, rewritten left before right
      case node @ BVImplies(lhs, rhs) =>
        val lhs2 = map(lhs, bv, ar)
        val rhs2 = map(rhs, bv, ar)
        if ((lhs2 eq lhs) && (rhs2 eq rhs)) node else BVImplies(lhs2, rhs2)
      case node @ BVEqual(lhs, rhs) =>
        val lhs2 = map(lhs, bv, ar)
        val rhs2 = map(rhs, bv, ar)
        if ((lhs2 eq lhs) && (rhs2 eq rhs)) node else BVEqual(lhs2, rhs2)
      case node @ ArrayEqual(lhs, rhs) =>
        val lhs2 = map(lhs, bv, ar)
        val rhs2 = map(rhs, bv, ar)
        if ((lhs2 eq lhs) && (rhs2 eq rhs)) node else ArrayEqual(lhs2, rhs2)
      case node @ BVComparison(op, lhs, rhs, signed) =>
        val lhs2 = map(lhs, bv, ar)
        val rhs2 = map(rhs, bv, ar)
        if ((lhs2 eq lhs) && (rhs2 eq rhs)) node else BVComparison(op, lhs2, rhs2, signed)
      case node @ BVOp(op, lhs, rhs) =>
        val lhs2 = map(lhs, bv, ar)
        val rhs2 = map(rhs, bv, ar)
        if ((lhs2 eq lhs) && (rhs2 eq rhs)) node else BVOp(op, lhs2, rhs2)
      case node @ BVConcat(lhs, rhs) =>
        val lhs2 = map(lhs, bv, ar)
        val rhs2 = map(rhs, bv, ar)
        if ((lhs2 eq lhs) && (rhs2 eq rhs)) node else BVConcat(lhs2, rhs2)
      case node @ ArrayRead(array, index) =>
        val array2 = map(array, bv, ar)
        val index2 = map(index, bv, ar)
        if ((array2 eq array) && (index2 eq index)) node else ArrayRead(array2, index2)

      // three children, rewritten in order
      case node @ BVIte(cond, whenTrue, whenFalse) =>
        val cond2 = map(cond, bv, ar)
        val whenTrue2 = map(whenTrue, bv, ar)
        val whenFalse2 = map(whenFalse, bv, ar)
        if ((cond2 eq cond) && (whenTrue2 eq whenTrue) && (whenFalse2 eq whenFalse)) node
        else BVIte(cond2, whenTrue2, whenFalse2)
    }
    bv(rebuilt)
  }

  /** Array nodes: rewrite the children, rebuild the node only if some child changed,
    * then hand the resulting node to `ar`.
    */
  private def map(e: ArrayExpr, bv: BVFun, ar: ArrayFun): ArrayExpr = {
    val rebuilt: ArrayExpr = e match {
      // leaves: nothing below them to rewrite
      case leaf @ (_: ArrayRawExpr | _: ArraySymbol) => leaf

      case node @ ArrayConstant(value, indexWidth) =>
        val value2 = map(value, bv, ar)
        if (value2 eq value) node else ArrayConstant(value2, indexWidth)
      case node @ ArrayStore(array, index, data) =>
        val array2 = map(array, bv, ar)
        val index2 = map(index, bv, ar)
        val data2 = map(data, bv, ar)
        if ((array2 eq array) && (index2 eq index) && (data2 eq data)) node
        else ArrayStore(array2, index2, data2)
      case node @ ArrayIte(cond, whenTrue, whenFalse) =>
        val cond2 = map(cond, bv, ar)
        val whenTrue2 = map(whenTrue, bv, ar)
        val whenFalse2 = map(whenFalse, bv, ar)
        if ((cond2 eq cond) && (whenTrue2 eq whenTrue) && (whenFalse2 eq whenFalse)) node
        else ArrayIte(cond2, whenTrue2, whenFalse2)
    }
    ar(rebuilt)
  }
}
|