diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/ast_util.ml | 113 | ||||
| -rw-r--r-- | src/ast_util.mli | 3 | ||||
| -rw-r--r-- | src/interpreter.ml | 114 | ||||
| -rw-r--r-- | src/rewrites.ml | 4 |
4 files changed, 119 insertions, 115 deletions
diff --git a/src/ast_util.ml b/src/ast_util.ml index f99ebbe5..ff86970a 100644 --- a/src/ast_util.ml +++ b/src/ast_util.ml @@ -935,3 +935,116 @@ let append_ast (Defs ast1) (Defs ast2) = Defs (ast1 @ ast2) let type_union_id (Tu_aux (aux, _)) = match aux with | Tu_id id -> id | Tu_ty_id (_, id) -> id + + +let rec pat_ids (P_aux (pat_aux, _)) = + match pat_aux with + | P_lit _ | P_wild -> IdSet.empty + | P_id id -> IdSet.singleton id + | P_as (pat, id) -> IdSet.add id (pat_ids pat) + | P_var (pat, _) | P_typ (_, pat) -> pat_ids pat + | P_app (_, pats) | P_tup pats | P_vector pats | P_vector_concat pats | P_list pats -> + List.fold_right IdSet.union (List.map pat_ids pats) IdSet.empty + | P_cons (pat1, pat2) -> + IdSet.union (pat_ids pat1) (pat_ids pat2) + | P_record (fpats, _) -> + List.fold_right IdSet.union (List.map fpat_ids fpats) IdSet.empty +and fpat_ids (FP_aux (FP_Fpat (_, pat), _)) = pat_ids pat + +let rec subst id value (E_aux (e_aux, annot) as exp) = + let wrap e_aux = E_aux (e_aux, annot) in + let e_aux = match e_aux with + | E_block exps -> E_block (List.map (subst id value) exps) + | E_nondet exps -> E_nondet (List.map (subst id value) exps) + | E_id id' -> if Id.compare id id' = 0 then unaux_exp value else E_id id' + | E_lit lit -> E_lit lit + | E_cast (typ, exp) -> E_cast (typ, subst id value exp) + + | E_app (fn, exps) -> E_app (fn, List.map (subst id value) exps) + | E_app_infix (exp1, op, exp2) -> E_app_infix (subst id value exp1, op, subst id value exp2) + + | E_tuple exps -> E_tuple (List.map (subst id value) exps) + + | E_if (cond, then_exp, else_exp) -> + E_if (subst id value cond, subst id value then_exp, subst id value else_exp) + + | E_loop (loop, cond, body) -> + E_loop (loop, subst id value cond, subst id value body) + | E_for (id', exp1, exp2, exp3, order, body) when Id.compare id id' = 0 -> + E_for (id', exp1, exp2, exp3, order, body) + | E_for (id', exp1, exp2, exp3, order, body) -> + E_for (id', subst id value exp1, subst id value exp2, subst id value exp3, order, subst id value body) + + | E_vector exps -> E_vector (List.map (subst id value) exps) + | E_vector_access (exp1, exp2) -> E_vector_access (subst id value exp1, subst id value exp2) + | E_vector_subrange (exp1, exp2, exp3) -> E_vector_subrange (subst id value exp1, subst id value exp2, subst id value exp3) + | E_vector_update (exp1, exp2, exp3) -> E_vector_update (subst id value exp1, subst id value exp2, subst id value exp3) + | E_vector_update_subrange (exp1, exp2, exp3, exp4) -> + E_vector_update_subrange (subst id value exp1, subst id value exp2, subst id value exp3, subst id value exp4) + | E_vector_append (exp1, exp2) -> E_vector_append (subst id value exp1, subst id value exp2) + + | E_list exps -> E_list (List.map (subst id value) exps) + | E_cons (exp1, exp2) -> E_cons (subst id value exp1, subst id value exp2) + + | E_record fexps -> E_record (subst_fexps id value fexps) + | E_record_update (exp, fexps) -> E_record_update (subst id value exp, subst_fexps id value fexps) + | E_field (exp, id') -> E_field (subst id value exp, id') + + | E_case (exp, pexps) -> + E_case (subst id value exp, List.map (subst_pexp id value) pexps) + + | E_let (LB_aux (LB_val (pat, bind), lb_annot), body) -> + E_let (LB_aux (LB_val (pat, subst id value bind), lb_annot), + if IdSet.mem id (pat_ids pat) then body else subst id value body) + + | E_assign (lexp, exp) -> E_assign (subst_lexp id value lexp, subst id value exp) (* Shadowing... *) + + (* Should be re-written *) + | E_sizeof nexp -> E_sizeof nexp + | E_constraint nc -> E_constraint nc + + | E_return exp -> E_return (subst id value exp) + | E_exit exp -> E_exit (subst id value exp) + (* Not sure about this, but id should always be immutable while id' must be mutable so should be ok. *) + | E_ref id' -> E_ref id' + | E_throw exp -> E_throw (subst id value exp) + + | E_try (exp, pexps) -> + E_try (subst id value exp, List.map (subst_pexp id value) pexps) + + | E_assert (exp1, exp2) -> E_assert (subst id value exp1, subst id value exp2) + + | E_internal_value v -> E_internal_value v + | _ -> failwith ("subst " ^ string_of_exp exp) + in + wrap e_aux + +and subst_pexp id value (Pat_aux (pexp_aux, annot)) = + let pexp_aux = match pexp_aux with + | Pat_exp (pat, exp) when IdSet.mem id (pat_ids pat) -> Pat_exp (pat, exp) + | Pat_exp (pat, exp) -> Pat_exp (pat, subst id value exp) + | Pat_when (pat, guard, exp) when IdSet.mem id (pat_ids pat) -> Pat_when (pat, guard, exp) + | Pat_when (pat, guard, exp) -> Pat_when (pat, subst id value guard, subst id value exp) + in + Pat_aux (pexp_aux, annot) + + +and subst_fexps id value (FES_aux (FES_Fexps (fexps, flag), annot)) = + FES_aux (FES_Fexps (List.map (subst_fexp id value) fexps, flag), annot) + +and subst_fexp id value (FE_aux (FE_Fexp (id', exp), annot)) = + FE_aux (FE_Fexp (id', subst id value exp), annot) + +and subst_lexp id value (LEXP_aux (lexp_aux, annot) as lexp) = + let wrap lexp_aux = LEXP_aux (lexp_aux, annot) in + let lexp_aux = match lexp_aux with + | LEXP_deref exp -> LEXP_deref (subst id value exp) + | LEXP_id id' -> LEXP_id id' + | LEXP_memory (f, exps) -> LEXP_memory (f, List.map (subst id value) exps) + | LEXP_cast (typ, id') -> LEXP_cast (typ, id') + | LEXP_tup lexps -> LEXP_tup (List.map (subst_lexp id value) lexps) + | LEXP_vector (lexp, exp) -> LEXP_vector (subst_lexp id value lexp, subst id value exp) + | LEXP_vector_range (lexp, exp1, exp2) -> LEXP_vector_range (subst_lexp id value lexp, subst id value exp1, subst id value exp2) + | LEXP_field (lexp, id') -> LEXP_field (subst_lexp id value lexp, id') + in + wrap lexp_aux diff --git a/src/ast_util.mli b/src/ast_util.mli index 63f7658f..a6665332 100644 --- a/src/ast_util.mli +++ b/src/ast_util.mli @@ -303,3 +303,6 @@ val split_defs : ('a def -> bool) -> 'a defs -> ('a defs * 'a def * 'a defs) opt val append_ast : 'a defs -> 'a defs -> 'a defs val type_union_id : type_union -> id + +val pat_ids : 'a pat -> IdSet.t +val subst : id -> 'a exp -> 'a exp -> 'a exp diff --git a/src/interpreter.ml b/src/interpreter.ml index a95ab8c3..e39c90f0 100644 --- a/src/interpreter.ml +++ b/src/interpreter.ml @@ -196,107 +196,11 @@ let assertion_failed msg = Yield (Assertion_failed msg) let liftM2 f m1 m2 = m1 >>= fun x -> m2 >>= fun y -> return (f x y) -let rec pat_ids (P_aux (pat_aux, _)) = - match pat_aux with - | P_lit _ | P_wild -> IdSet.empty - | P_id id -> IdSet.singleton id - | P_as (pat, id) -> IdSet.add id (pat_ids pat) - | P_var (pat, _) | P_typ (_, pat) -> pat_ids pat - | P_app (_, pats) | P_tup pats | P_vector pats | P_vector_concat pats | P_list pats -> - List.fold_right IdSet.union (List.map pat_ids pats) IdSet.empty - | P_cons (pat1, pat2) -> - IdSet.union (pat_ids pat1) (pat_ids pat2) - | P_record (fpats, _) -> - List.fold_right IdSet.union (List.map fpat_ids fpats) IdSet.empty -and fpat_ids (FP_aux (FP_Fpat (_, pat), _)) = pat_ids pat - let letbind_pat_ids (LB_aux (LB_val (pat, _), _)) = pat_ids pat -let rec subst id value (E_aux (e_aux, annot) as exp) = - let wrap e_aux = E_aux (e_aux, annot) in - let e_aux = match e_aux with - | E_block exps -> E_block (List.map (subst id value) exps) - | E_nondet exps -> E_nondet (List.map (subst id value) exps) - | E_id id' -> if Id.compare id id' = 0 then unaux_exp (exp_of_value value) else E_id id' - | E_lit lit -> E_lit lit - | E_cast (typ, exp) -> E_cast (typ, subst id value exp) - - | E_app (fn, exps) -> E_app (fn, List.map (subst id value) exps) - | E_app_infix (exp1, op, exp2) -> E_app_infix (subst id value exp1, op, subst id value exp2) - - | E_tuple exps -> E_tuple (List.map (subst id value) exps) - - | E_if (cond, then_exp, else_exp) -> - E_if (subst id value cond, subst id value then_exp, subst id value else_exp) - - | E_loop (loop, cond, body) -> - E_loop (loop, subst id value cond, subst id value body) - | E_for (id', exp1, exp2, exp3, order, body) when Id.compare id id' = 0 -> - E_for (id', exp1, exp2, exp3, order, body) - | E_for (id', exp1, exp2, exp3, order, body) -> - E_for (id', subst id value exp1, subst id value exp2, subst id value exp3, order, subst id value body) - - | E_vector exps -> E_vector (List.map (subst id value) exps) - | E_vector_access (exp1, exp2) -> E_vector_access (subst id value exp1, subst id value exp2) - | E_vector_subrange (exp1, exp2, exp3) -> E_vector_subrange (subst id value exp1, subst id value exp2, subst id value exp3) - | E_vector_update (exp1, exp2, exp3) -> E_vector_update (subst id value exp1, subst id value exp2, subst id value exp3) - | E_vector_update_subrange (exp1, exp2, exp3, exp4) -> - E_vector_update_subrange (subst id value exp1, subst id value exp2, subst id value exp3, subst id value exp4) - | E_vector_append (exp1, exp2) -> E_vector_append (subst id value exp1, subst id value exp2) - - | E_list exps -> E_list (List.map (subst id value) exps) - | E_cons (exp1, exp2) -> E_cons (subst id value exp1, subst id value exp2) - - | E_record fexps -> E_record (subst_fexps id value fexps) - | E_record_update (exp, fexps) -> E_record_update (subst id value exp, subst_fexps id value fexps) - | E_field (exp, id') -> E_field (subst id value exp, id') - - | E_case (exp, pexps) -> - E_case (subst id value exp, List.map (subst_pexp id value) pexps) - - | E_let (LB_aux (LB_val (pat, bind), lb_annot), body) -> - E_let (LB_aux (LB_val (pat, subst id value bind), lb_annot), - if IdSet.mem id (pat_ids pat) then body else subst id value body) - - | E_assign (lexp, exp) -> E_assign (subst_lexp id value lexp, subst id value exp) (* Shadowing... *) +let subst id value exp = Ast_util.subst id (exp_of_value value) exp - (* Should be re-written *) - | E_sizeof nexp -> E_sizeof nexp - | E_constraint nc -> E_constraint nc - - | E_return exp -> E_return (subst id value exp) - | E_exit exp -> E_exit (subst id value exp) - (* Not sure about this, but id should always be immutable while id' must be mutable so should be ok. *) - | E_ref id' -> E_ref id' - | E_throw exp -> E_throw (subst id value exp) - - | E_try (exp, pexps) -> - E_try (subst id value exp, List.map (subst_pexp id value) pexps) - - | E_assert (exp1, exp2) -> E_assert (subst id value exp1, subst id value exp2) - - | E_internal_value v -> E_internal_value v - | _ -> failwith ("subst " ^ string_of_exp exp) - in - wrap e_aux - -and subst_pexp id value (Pat_aux (pexp_aux, annot)) = - let pexp_aux = match pexp_aux with - | Pat_exp (pat, exp) when IdSet.mem id (pat_ids pat) -> Pat_exp (pat, exp) - | Pat_exp (pat, exp) -> Pat_exp (pat, subst id value exp) - | Pat_when (pat, guard, exp) when IdSet.mem id (pat_ids pat) -> Pat_when (pat, guard, exp) - | Pat_when (pat, guard, exp) -> Pat_when (pat, subst id value guard, subst id value exp) - in - Pat_aux (pexp_aux, annot) - - -and subst_fexps id value (FES_aux (FES_Fexps (fexps, flag), annot)) = - FES_aux (FES_Fexps (List.map (subst_fexp id value) fexps, flag), annot) - -and subst_fexp id value (FE_aux (FE_Fexp (id', exp), annot)) = - FE_aux (FE_Fexp (id', subst id value exp), annot) - -and local_variable id lstate gstate = +let local_variable id lstate gstate = try match Bindings.find id lstate.locals with | Var_value v -> exp_of_value v @@ -306,20 +210,6 @@ and local_variable id lstate gstate = with | Not_found -> failwith ("Could not find local variable " ^ string_of_id id) -and subst_lexp id value (LEXP_aux (lexp_aux, annot) as lexp) = - let wrap lexp_aux = LEXP_aux (lexp_aux, annot) in - let lexp_aux = match lexp_aux with - | LEXP_deref exp -> LEXP_deref (subst id value exp) - | LEXP_id id' -> LEXP_id id' - | LEXP_memory (f, exps) -> LEXP_memory (f, List.map (subst id value) exps) - | LEXP_cast (typ, id') -> LEXP_cast (typ, id') - | LEXP_tup lexps -> LEXP_tup (List.map (subst_lexp id value) lexps) - | LEXP_vector (lexp, exp) -> LEXP_vector (subst_lexp id value lexp, subst id value exp) - | LEXP_vector_range (lexp, exp1, exp2) -> LEXP_vector_range (subst_lexp id value lexp, subst id value exp1, subst id value exp2) - | LEXP_field (lexp, id') -> LEXP_field (subst_lexp id value lexp, id') - in - wrap lexp_aux - (**************************************************************************) (* 2. Expression Evaluation *) (**************************************************************************) diff --git a/src/rewrites.ml b/src/rewrites.ml index 56352a22..016bd7d9 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -995,9 +995,7 @@ let subst_id_pat pat (id1,id2) = fold_pat {id_pat_alg with p_id = p_id} pat let subst_id_exp exp (id1,id2) = - (* TODO Don't substitute bound occurrences inside let expressions etc *) - let e_id (Id_aux (id,l)) = (if id = id1 then E_id (Id_aux (id2,l)) else E_id (Id_aux (id,l))) in - fold_exp {id_exp_alg with e_id = e_id} exp + Ast_util.subst (Id_aux (id1,Parse_ast.Unknown)) (E_aux (E_id (Id_aux (id2,Parse_ast.Unknown)),(Parse_ast.Unknown,None))) exp let rec pat_to_exp (P_aux (pat,(l,annot))) = let rewrap e = E_aux (e,(l,annot)) in |
