summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/ast_util.ml113
-rw-r--r--src/ast_util.mli3
-rw-r--r--src/interpreter.ml114
-rw-r--r--src/rewrites.ml4
4 files changed, 119 insertions, 115 deletions
diff --git a/src/ast_util.ml b/src/ast_util.ml
index f99ebbe5..ff86970a 100644
--- a/src/ast_util.ml
+++ b/src/ast_util.ml
@@ -935,3 +935,116 @@ let append_ast (Defs ast1) (Defs ast2) = Defs (ast1 @ ast2)
let type_union_id (Tu_aux (aux, _)) = match aux with
| Tu_id id -> id
| Tu_ty_id (_, id) -> id
+
+
+let rec pat_ids (P_aux (pat_aux, _)) =
+ match pat_aux with
+ | P_lit _ | P_wild -> IdSet.empty
+ | P_id id -> IdSet.singleton id
+ | P_as (pat, id) -> IdSet.add id (pat_ids pat)
+ | P_var (pat, _) | P_typ (_, pat) -> pat_ids pat
+ | P_app (_, pats) | P_tup pats | P_vector pats | P_vector_concat pats | P_list pats ->
+ List.fold_right IdSet.union (List.map pat_ids pats) IdSet.empty
+ | P_cons (pat1, pat2) ->
+ IdSet.union (pat_ids pat1) (pat_ids pat2)
+ | P_record (fpats, _) ->
+ List.fold_right IdSet.union (List.map fpat_ids fpats) IdSet.empty
+and fpat_ids (FP_aux (FP_Fpat (_, pat), _)) = pat_ids pat
+
+let rec subst id value (E_aux (e_aux, annot) as exp) =
+ let wrap e_aux = E_aux (e_aux, annot) in
+ let e_aux = match e_aux with
+ | E_block exps -> E_block (List.map (subst id value) exps)
+ | E_nondet exps -> E_nondet (List.map (subst id value) exps)
+ | E_id id' -> if Id.compare id id' = 0 then unaux_exp value else E_id id'
+ | E_lit lit -> E_lit lit
+ | E_cast (typ, exp) -> E_cast (typ, subst id value exp)
+
+ | E_app (fn, exps) -> E_app (fn, List.map (subst id value) exps)
+ | E_app_infix (exp1, op, exp2) -> E_app_infix (subst id value exp1, op, subst id value exp2)
+
+ | E_tuple exps -> E_tuple (List.map (subst id value) exps)
+
+ | E_if (cond, then_exp, else_exp) ->
+ E_if (subst id value cond, subst id value then_exp, subst id value else_exp)
+
+ | E_loop (loop, cond, body) ->
+ E_loop (loop, subst id value cond, subst id value body)
+ | E_for (id', exp1, exp2, exp3, order, body) when Id.compare id id' = 0 ->
+ E_for (id', exp1, exp2, exp3, order, body)
+ | E_for (id', exp1, exp2, exp3, order, body) ->
+ E_for (id', subst id value exp1, subst id value exp2, subst id value exp3, order, subst id value body)
+
+ | E_vector exps -> E_vector (List.map (subst id value) exps)
+ | E_vector_access (exp1, exp2) -> E_vector_access (subst id value exp1, subst id value exp2)
+ | E_vector_subrange (exp1, exp2, exp3) -> E_vector_subrange (subst id value exp1, subst id value exp2, subst id value exp3)
+ | E_vector_update (exp1, exp2, exp3) -> E_vector_update (subst id value exp1, subst id value exp2, subst id value exp3)
+ | E_vector_update_subrange (exp1, exp2, exp3, exp4) ->
+ E_vector_update_subrange (subst id value exp1, subst id value exp2, subst id value exp3, subst id value exp4)
+ | E_vector_append (exp1, exp2) -> E_vector_append (subst id value exp1, subst id value exp2)
+
+ | E_list exps -> E_list (List.map (subst id value) exps)
+ | E_cons (exp1, exp2) -> E_cons (subst id value exp1, subst id value exp2)
+
+ | E_record fexps -> E_record (subst_fexps id value fexps)
+ | E_record_update (exp, fexps) -> E_record_update (subst id value exp, subst_fexps id value fexps)
+ | E_field (exp, id') -> E_field (subst id value exp, id')
+
+ | E_case (exp, pexps) ->
+ E_case (subst id value exp, List.map (subst_pexp id value) pexps)
+
+ | E_let (LB_aux (LB_val (pat, bind), lb_annot), body) ->
+ E_let (LB_aux (LB_val (pat, subst id value bind), lb_annot),
+ if IdSet.mem id (pat_ids pat) then body else subst id value body)
+
+ | E_assign (lexp, exp) -> E_assign (subst_lexp id value lexp, subst id value exp) (* Shadowing... *)
+
+ (* Should be re-written *)
+ | E_sizeof nexp -> E_sizeof nexp
+ | E_constraint nc -> E_constraint nc
+
+ | E_return exp -> E_return (subst id value exp)
+ | E_exit exp -> E_exit (subst id value exp)
+ (* Not sure about this, but id should always be immutable while id' must be mutable so should be ok. *)
+ | E_ref id' -> E_ref id'
+ | E_throw exp -> E_throw (subst id value exp)
+
+ | E_try (exp, pexps) ->
+ E_try (subst id value exp, List.map (subst_pexp id value) pexps)
+
+ | E_assert (exp1, exp2) -> E_assert (subst id value exp1, subst id value exp2)
+
+ | E_internal_value v -> E_internal_value v
+ | _ -> failwith ("subst " ^ string_of_exp exp)
+ in
+ wrap e_aux
+
+and subst_pexp id value (Pat_aux (pexp_aux, annot)) =
+ let pexp_aux = match pexp_aux with
+ | Pat_exp (pat, exp) when IdSet.mem id (pat_ids pat) -> Pat_exp (pat, exp)
+ | Pat_exp (pat, exp) -> Pat_exp (pat, subst id value exp)
+ | Pat_when (pat, guard, exp) when IdSet.mem id (pat_ids pat) -> Pat_when (pat, guard, exp)
+ | Pat_when (pat, guard, exp) -> Pat_when (pat, subst id value guard, subst id value exp)
+ in
+ Pat_aux (pexp_aux, annot)
+
+
+and subst_fexps id value (FES_aux (FES_Fexps (fexps, flag), annot)) =
+ FES_aux (FES_Fexps (List.map (subst_fexp id value) fexps, flag), annot)
+
+and subst_fexp id value (FE_aux (FE_Fexp (id', exp), annot)) =
+ FE_aux (FE_Fexp (id', subst id value exp), annot)
+
+and subst_lexp id value (LEXP_aux (lexp_aux, annot) as lexp) =
+ let wrap lexp_aux = LEXP_aux (lexp_aux, annot) in
+ let lexp_aux = match lexp_aux with
+ | LEXP_deref exp -> LEXP_deref (subst id value exp)
+ | LEXP_id id' -> LEXP_id id'
+ | LEXP_memory (f, exps) -> LEXP_memory (f, List.map (subst id value) exps)
+ | LEXP_cast (typ, id') -> LEXP_cast (typ, id')
+ | LEXP_tup lexps -> LEXP_tup (List.map (subst_lexp id value) lexps)
+ | LEXP_vector (lexp, exp) -> LEXP_vector (subst_lexp id value lexp, subst id value exp)
+ | LEXP_vector_range (lexp, exp1, exp2) -> LEXP_vector_range (subst_lexp id value lexp, subst id value exp1, subst id value exp2)
+ | LEXP_field (lexp, id') -> LEXP_field (subst_lexp id value lexp, id')
+ in
+ wrap lexp_aux
diff --git a/src/ast_util.mli b/src/ast_util.mli
index 63f7658f..a6665332 100644
--- a/src/ast_util.mli
+++ b/src/ast_util.mli
@@ -303,3 +303,6 @@ val split_defs : ('a def -> bool) -> 'a defs -> ('a defs * 'a def * 'a defs) opt
val append_ast : 'a defs -> 'a defs -> 'a defs
val type_union_id : type_union -> id
+
+val pat_ids : 'a pat -> IdSet.t
+val subst : id -> 'a exp -> 'a exp -> 'a exp
diff --git a/src/interpreter.ml b/src/interpreter.ml
index a95ab8c3..e39c90f0 100644
--- a/src/interpreter.ml
+++ b/src/interpreter.ml
@@ -196,107 +196,11 @@ let assertion_failed msg = Yield (Assertion_failed msg)
let liftM2 f m1 m2 = m1 >>= fun x -> m2 >>= fun y -> return (f x y)
-let rec pat_ids (P_aux (pat_aux, _)) =
- match pat_aux with
- | P_lit _ | P_wild -> IdSet.empty
- | P_id id -> IdSet.singleton id
- | P_as (pat, id) -> IdSet.add id (pat_ids pat)
- | P_var (pat, _) | P_typ (_, pat) -> pat_ids pat
- | P_app (_, pats) | P_tup pats | P_vector pats | P_vector_concat pats | P_list pats ->
- List.fold_right IdSet.union (List.map pat_ids pats) IdSet.empty
- | P_cons (pat1, pat2) ->
- IdSet.union (pat_ids pat1) (pat_ids pat2)
- | P_record (fpats, _) ->
- List.fold_right IdSet.union (List.map fpat_ids fpats) IdSet.empty
-and fpat_ids (FP_aux (FP_Fpat (_, pat), _)) = pat_ids pat
-
let letbind_pat_ids (LB_aux (LB_val (pat, _), _)) = pat_ids pat
-let rec subst id value (E_aux (e_aux, annot) as exp) =
- let wrap e_aux = E_aux (e_aux, annot) in
- let e_aux = match e_aux with
- | E_block exps -> E_block (List.map (subst id value) exps)
- | E_nondet exps -> E_nondet (List.map (subst id value) exps)
- | E_id id' -> if Id.compare id id' = 0 then unaux_exp (exp_of_value value) else E_id id'
- | E_lit lit -> E_lit lit
- | E_cast (typ, exp) -> E_cast (typ, subst id value exp)
-
- | E_app (fn, exps) -> E_app (fn, List.map (subst id value) exps)
- | E_app_infix (exp1, op, exp2) -> E_app_infix (subst id value exp1, op, subst id value exp2)
-
- | E_tuple exps -> E_tuple (List.map (subst id value) exps)
-
- | E_if (cond, then_exp, else_exp) ->
- E_if (subst id value cond, subst id value then_exp, subst id value else_exp)
-
- | E_loop (loop, cond, body) ->
- E_loop (loop, subst id value cond, subst id value body)
- | E_for (id', exp1, exp2, exp3, order, body) when Id.compare id id' = 0 ->
- E_for (id', exp1, exp2, exp3, order, body)
- | E_for (id', exp1, exp2, exp3, order, body) ->
- E_for (id', subst id value exp1, subst id value exp2, subst id value exp3, order, subst id value body)
-
- | E_vector exps -> E_vector (List.map (subst id value) exps)
- | E_vector_access (exp1, exp2) -> E_vector_access (subst id value exp1, subst id value exp2)
- | E_vector_subrange (exp1, exp2, exp3) -> E_vector_subrange (subst id value exp1, subst id value exp2, subst id value exp3)
- | E_vector_update (exp1, exp2, exp3) -> E_vector_update (subst id value exp1, subst id value exp2, subst id value exp3)
- | E_vector_update_subrange (exp1, exp2, exp3, exp4) ->
- E_vector_update_subrange (subst id value exp1, subst id value exp2, subst id value exp3, subst id value exp4)
- | E_vector_append (exp1, exp2) -> E_vector_append (subst id value exp1, subst id value exp2)
-
- | E_list exps -> E_list (List.map (subst id value) exps)
- | E_cons (exp1, exp2) -> E_cons (subst id value exp1, subst id value exp2)
-
- | E_record fexps -> E_record (subst_fexps id value fexps)
- | E_record_update (exp, fexps) -> E_record_update (subst id value exp, subst_fexps id value fexps)
- | E_field (exp, id') -> E_field (subst id value exp, id')
-
- | E_case (exp, pexps) ->
- E_case (subst id value exp, List.map (subst_pexp id value) pexps)
-
- | E_let (LB_aux (LB_val (pat, bind), lb_annot), body) ->
- E_let (LB_aux (LB_val (pat, subst id value bind), lb_annot),
- if IdSet.mem id (pat_ids pat) then body else subst id value body)
-
- | E_assign (lexp, exp) -> E_assign (subst_lexp id value lexp, subst id value exp) (* Shadowing... *)
+let subst id value exp = Ast_util.subst id (exp_of_value value) exp
- (* Should be re-written *)
- | E_sizeof nexp -> E_sizeof nexp
- | E_constraint nc -> E_constraint nc
-
- | E_return exp -> E_return (subst id value exp)
- | E_exit exp -> E_exit (subst id value exp)
- (* Not sure about this, but id should always be immutable while id' must be mutable so should be ok. *)
- | E_ref id' -> E_ref id'
- | E_throw exp -> E_throw (subst id value exp)
-
- | E_try (exp, pexps) ->
- E_try (subst id value exp, List.map (subst_pexp id value) pexps)
-
- | E_assert (exp1, exp2) -> E_assert (subst id value exp1, subst id value exp2)
-
- | E_internal_value v -> E_internal_value v
- | _ -> failwith ("subst " ^ string_of_exp exp)
- in
- wrap e_aux
-
-and subst_pexp id value (Pat_aux (pexp_aux, annot)) =
- let pexp_aux = match pexp_aux with
- | Pat_exp (pat, exp) when IdSet.mem id (pat_ids pat) -> Pat_exp (pat, exp)
- | Pat_exp (pat, exp) -> Pat_exp (pat, subst id value exp)
- | Pat_when (pat, guard, exp) when IdSet.mem id (pat_ids pat) -> Pat_when (pat, guard, exp)
- | Pat_when (pat, guard, exp) -> Pat_when (pat, subst id value guard, subst id value exp)
- in
- Pat_aux (pexp_aux, annot)
-
-
-and subst_fexps id value (FES_aux (FES_Fexps (fexps, flag), annot)) =
- FES_aux (FES_Fexps (List.map (subst_fexp id value) fexps, flag), annot)
-
-and subst_fexp id value (FE_aux (FE_Fexp (id', exp), annot)) =
- FE_aux (FE_Fexp (id', subst id value exp), annot)
-
-and local_variable id lstate gstate =
+let local_variable id lstate gstate =
try
match Bindings.find id lstate.locals with
| Var_value v -> exp_of_value v
@@ -306,20 +210,6 @@ and local_variable id lstate gstate =
with
| Not_found -> failwith ("Could not find local variable " ^ string_of_id id)
-and subst_lexp id value (LEXP_aux (lexp_aux, annot) as lexp) =
- let wrap lexp_aux = LEXP_aux (lexp_aux, annot) in
- let lexp_aux = match lexp_aux with
- | LEXP_deref exp -> LEXP_deref (subst id value exp)
- | LEXP_id id' -> LEXP_id id'
- | LEXP_memory (f, exps) -> LEXP_memory (f, List.map (subst id value) exps)
- | LEXP_cast (typ, id') -> LEXP_cast (typ, id')
- | LEXP_tup lexps -> LEXP_tup (List.map (subst_lexp id value) lexps)
- | LEXP_vector (lexp, exp) -> LEXP_vector (subst_lexp id value lexp, subst id value exp)
- | LEXP_vector_range (lexp, exp1, exp2) -> LEXP_vector_range (subst_lexp id value lexp, subst id value exp1, subst id value exp2)
- | LEXP_field (lexp, id') -> LEXP_field (subst_lexp id value lexp, id')
- in
- wrap lexp_aux
-
(**************************************************************************)
(* 2. Expression Evaluation *)
(**************************************************************************)
diff --git a/src/rewrites.ml b/src/rewrites.ml
index 56352a22..016bd7d9 100644
--- a/src/rewrites.ml
+++ b/src/rewrites.ml
@@ -995,9 +995,7 @@ let subst_id_pat pat (id1,id2) =
fold_pat {id_pat_alg with p_id = p_id} pat
let subst_id_exp exp (id1,id2) =
- (* TODO Don't substitute bound occurrences inside let expressions etc *)
- let e_id (Id_aux (id,l)) = (if id = id1 then E_id (Id_aux (id2,l)) else E_id (Id_aux (id,l))) in
- fold_exp {id_exp_alg with e_id = e_id} exp
+ Ast_util.subst (Id_aux (id1,Parse_ast.Unknown)) (E_aux (E_id (Id_aux (id2,Parse_ast.Unknown)),(Parse_ast.Unknown,None))) exp
let rec pat_to_exp (P_aux (pat,(l,annot))) =
let rewrap e = E_aux (e,(l,annot)) in