summaryrefslogtreecommitdiff
path: root/src/ast_util.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/ast_util.ml')
-rw-r--r--src/ast_util.ml113
1 files changed, 113 insertions, 0 deletions
diff --git a/src/ast_util.ml b/src/ast_util.ml
index f99ebbe5..ff86970a 100644
--- a/src/ast_util.ml
+++ b/src/ast_util.ml
@@ -935,3 +935,116 @@ let append_ast (Defs ast1) (Defs ast2) = Defs (ast1 @ ast2)
let type_union_id (Tu_aux (aux, _)) = match aux with
| Tu_id id -> id
| Tu_ty_id (_, id) -> id
+
+
+let rec pat_ids (P_aux (pat_aux, _)) =
+ match pat_aux with
+ | P_lit _ | P_wild -> IdSet.empty
+ | P_id id -> IdSet.singleton id
+ | P_as (pat, id) -> IdSet.add id (pat_ids pat)
+ | P_var (pat, _) | P_typ (_, pat) -> pat_ids pat
+ | P_app (_, pats) | P_tup pats | P_vector pats | P_vector_concat pats | P_list pats ->
+ List.fold_right IdSet.union (List.map pat_ids pats) IdSet.empty
+ | P_cons (pat1, pat2) ->
+ IdSet.union (pat_ids pat1) (pat_ids pat2)
+ | P_record (fpats, _) ->
+ List.fold_right IdSet.union (List.map fpat_ids fpats) IdSet.empty
+and fpat_ids (FP_aux (FP_Fpat (_, pat), _)) = pat_ids pat
+
+let rec subst id value (E_aux (e_aux, annot) as exp) =
+ let wrap e_aux = E_aux (e_aux, annot) in
+ let e_aux = match e_aux with
+ | E_block exps -> E_block (List.map (subst id value) exps)
+ | E_nondet exps -> E_nondet (List.map (subst id value) exps)
+ | E_id id' -> if Id.compare id id' = 0 then unaux_exp value else E_id id'
+ | E_lit lit -> E_lit lit
+ | E_cast (typ, exp) -> E_cast (typ, subst id value exp)
+
+ | E_app (fn, exps) -> E_app (fn, List.map (subst id value) exps)
+ | E_app_infix (exp1, op, exp2) -> E_app_infix (subst id value exp1, op, subst id value exp2)
+
+ | E_tuple exps -> E_tuple (List.map (subst id value) exps)
+
+ | E_if (cond, then_exp, else_exp) ->
+ E_if (subst id value cond, subst id value then_exp, subst id value else_exp)
+
+ | E_loop (loop, cond, body) ->
+ E_loop (loop, subst id value cond, subst id value body)
+ | E_for (id', exp1, exp2, exp3, order, body) when Id.compare id id' = 0 ->
+ E_for (id', exp1, exp2, exp3, order, body)
+ | E_for (id', exp1, exp2, exp3, order, body) ->
+ E_for (id', subst id value exp1, subst id value exp2, subst id value exp3, order, subst id value body)
+
+ | E_vector exps -> E_vector (List.map (subst id value) exps)
+ | E_vector_access (exp1, exp2) -> E_vector_access (subst id value exp1, subst id value exp2)
+ | E_vector_subrange (exp1, exp2, exp3) -> E_vector_subrange (subst id value exp1, subst id value exp2, subst id value exp3)
+ | E_vector_update (exp1, exp2, exp3) -> E_vector_update (subst id value exp1, subst id value exp2, subst id value exp3)
+ | E_vector_update_subrange (exp1, exp2, exp3, exp4) ->
+ E_vector_update_subrange (subst id value exp1, subst id value exp2, subst id value exp3, subst id value exp4)
+ | E_vector_append (exp1, exp2) -> E_vector_append (subst id value exp1, subst id value exp2)
+
+ | E_list exps -> E_list (List.map (subst id value) exps)
+ | E_cons (exp1, exp2) -> E_cons (subst id value exp1, subst id value exp2)
+
+ | E_record fexps -> E_record (subst_fexps id value fexps)
+ | E_record_update (exp, fexps) -> E_record_update (subst id value exp, subst_fexps id value fexps)
+ | E_field (exp, id') -> E_field (subst id value exp, id')
+
+ | E_case (exp, pexps) ->
+ E_case (subst id value exp, List.map (subst_pexp id value) pexps)
+
+ | E_let (LB_aux (LB_val (pat, bind), lb_annot), body) ->
+ E_let (LB_aux (LB_val (pat, subst id value bind), lb_annot),
+ if IdSet.mem id (pat_ids pat) then body else subst id value body)
+
+ | E_assign (lexp, exp) -> E_assign (subst_lexp id value lexp, subst id value exp) (* Shadowing... *)
+
+ (* Should be re-written *)
+ | E_sizeof nexp -> E_sizeof nexp
+ | E_constraint nc -> E_constraint nc
+
+ | E_return exp -> E_return (subst id value exp)
+ | E_exit exp -> E_exit (subst id value exp)
+ (* Not sure about this, but id should always be immutable while id' must be mutable so should be ok. *)
+ | E_ref id' -> E_ref id'
+ | E_throw exp -> E_throw (subst id value exp)
+
+ | E_try (exp, pexps) ->
+ E_try (subst id value exp, List.map (subst_pexp id value) pexps)
+
+ | E_assert (exp1, exp2) -> E_assert (subst id value exp1, subst id value exp2)
+
+ | E_internal_value v -> E_internal_value v
+ | _ -> failwith ("subst " ^ string_of_exp exp)
+ in
+ wrap e_aux
+
+and subst_pexp id value (Pat_aux (pexp_aux, annot)) =
+ let pexp_aux = match pexp_aux with
+ | Pat_exp (pat, exp) when IdSet.mem id (pat_ids pat) -> Pat_exp (pat, exp)
+ | Pat_exp (pat, exp) -> Pat_exp (pat, subst id value exp)
+ | Pat_when (pat, guard, exp) when IdSet.mem id (pat_ids pat) -> Pat_when (pat, guard, exp)
+ | Pat_when (pat, guard, exp) -> Pat_when (pat, subst id value guard, subst id value exp)
+ in
+ Pat_aux (pexp_aux, annot)
+
+
+and subst_fexps id value (FES_aux (FES_Fexps (fexps, flag), annot)) =
+ FES_aux (FES_Fexps (List.map (subst_fexp id value) fexps, flag), annot)
+
+and subst_fexp id value (FE_aux (FE_Fexp (id', exp), annot)) =
+ FE_aux (FE_Fexp (id', subst id value exp), annot)
+
+and subst_lexp id value (LEXP_aux (lexp_aux, annot) as lexp) =
+ let wrap lexp_aux = LEXP_aux (lexp_aux, annot) in
+ let lexp_aux = match lexp_aux with
+ | LEXP_deref exp -> LEXP_deref (subst id value exp)
+ | LEXP_id id' -> LEXP_id id'
+ | LEXP_memory (f, exps) -> LEXP_memory (f, List.map (subst id value) exps)
+ | LEXP_cast (typ, id') -> LEXP_cast (typ, id')
+ | LEXP_tup lexps -> LEXP_tup (List.map (subst_lexp id value) lexps)
+ | LEXP_vector (lexp, exp) -> LEXP_vector (subst_lexp id value lexp, subst id value exp)
+ | LEXP_vector_range (lexp, exp1, exp2) -> LEXP_vector_range (subst_lexp id value lexp, subst id value exp1, subst id value exp2)
+ | LEXP_field (lexp, id') -> LEXP_field (subst_lexp id value lexp, id')
+ in
+ wrap lexp_aux