diff options
| author | Christopher Pulte | 2015-10-20 14:28:24 +0100 |
|---|---|---|
| committer | Christopher Pulte | 2015-10-20 14:28:24 +0100 |
| commit | 117e58ac3da5d79dab16988b693cdd0908c0bb48 (patch) | |
| tree | e75c66f2d6ede16924a02f555e34159f6a197f4a /src | |
| parent | 602adb432b158efa403959454328bc58bddca61b (diff) | |
fix a-normalisation bug
Diffstat (limited to 'src')
| -rw-r--r-- | src/pretty_print.ml | 3 | ||||
| -rw-r--r-- | src/rewriter.ml | 287 |
2 files changed, 137 insertions, 153 deletions
diff --git a/src/pretty_print.ml b/src/pretty_print.ml index 3ffc67b7..9b5173c3 100644 --- a/src/pretty_print.ml +++ b/src/pretty_print.ml @@ -1340,9 +1340,10 @@ let doc_exp_ocaml, doc_let_ocaml = string "if" ^^ space ^^ string "to_bool" ^^ parens (exp c) ^/^ string "then" ^^ space ^^ (exp t) | E_if(c,t,e) -> + parens ( string "if" ^^ space ^^ string "to_bool" ^^ parens (exp c) ^/^ string "then" ^^ space ^^ group (exp t) ^/^ - string "else" ^^ space ^^ group (exp e) + string "else" ^^ space ^^ group (exp e)) | E_for(id,exp1,exp2,exp3,(Ord_aux(order,_)),exp4) -> let var= doc_id_ocaml id in let (compare,next) = if order = Ord_inc then string "<=",string "+" else string ">=",string "-" in diff --git a/src/rewriter.ml b/src/rewriter.ml index 9aa5063d..d7bb68f4 100644 --- a/src/rewriter.ml +++ b/src/rewriter.ml @@ -913,37 +913,30 @@ let remove_blocks_exp_alg = { id_exp_alg with e_aux = e_aux } -type ('r,'a) cont = ('a -> 'r) -> 'r - -let return : 'a -> ('r,'a) cont = - fun a -> (fun f -> f a) - -let bind : ('r,'a) cont -> ('a -> ('r,'b) cont) -> ('r,'b) cont = - fun m f -> (fun k -> m (fun a -> (f a) k)) - -let (>>=) = bind - let a_normalise_counter = ref 0 let letbind (v : 'a exp) (body : 'a exp -> 'a exp) : 'a exp = + (* body is a function : E_id variable -> actual body *) + let fresh_id () = let current = !a_normalise_counter in let () = a_normalise_counter := (current + 1) in Id_aux (Id ("__w" ^ string_of_int current), Parse_ast.Unknown) in - (* body is a function : E_id variable -> actual body *) let freshid = fresh_id () in let annot_var = (Parse_ast.Unknown,simple_annot (gettype v)) in let eid = E_aux (E_id freshid, annot_var) in + + let body = body eid in let annot_pat = (Parse_ast.Unknown,simple_annot_efr (gettype v) (geteffs v)) in let annot_lb = annot_pat in - let annot_let = (Parse_ast.Unknown,simple_annot_efr {t = Tid "unit"} (eff_union v (body eid))) in + let annot_let = (Parse_ast.Unknown,simple_annot_efr {t = Tid "unit"} (eff_union v body)) in let pat = P_aux (P_id freshid,annot_pat) in if effectful v - then E_aux (E_internal_plet (pat,v,body eid),annot_let) - else E_aux (E_let (LB_aux (LB_val_implicit (pat,v),annot_lb),body eid),annot_let) + then E_aux (E_internal_plet (pat,v,body),annot_let) + else E_aux (E_let (LB_aux (LB_val_implicit (pat,v),annot_lb),body),annot_let) let rec value ((E_aux (exp_aux,_)) as exp) = not (effectful exp) && @@ -958,87 +951,76 @@ let rec value ((E_aux (exp_aux,_)) as exp) = let only_local_eff (l,(Base ((t_params,t),tag,nexps,eff,effsum,bounds))) = (l,Base ((t_params,t),tag,nexps,eff,eff,bounds)) -let rec norm_list : ('b -> ('annot exp,'b) cont) -> ('b list -> ('a exp,'b list) cont) = - fun normf l -> +let rec mapCont (f : 'b -> ('b -> 'a exp) -> 'a exp) (l : 'b list) (k : 'b list -> 'a exp) : 'a exp = match l with - | [] -> return [] - | e :: es -> - normf e >>= fun e -> - norm_list normf es >>= fun es -> - return (e :: es) - -let rec norm_exp_to_name exp : ('a exp,'a exp) cont = - norm_exp exp >>= fun exp -> - if value exp then return exp else letbind exp + | [] -> k [] + | exp :: exps -> f exp (fun exp -> mapCont f exps (fun exps -> k (exp :: exps))) + +let rec n_exp_name (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp = + n_exp exp (fun exp -> if value exp then k exp else letbind exp k) -and norm_exp_to_nameL (exps : 'a exp list) : ('a exp,'a exp list) cont = - norm_list norm_exp_to_name exps +and n_exp_nameL (exps : 'a exp list) (k : 'a exp list -> 'a exp) : 'a exp = + mapCont n_exp_name exps k -and norm_fexp (fexp : 'a fexp) : ('a exp,'a fexp) cont = +and n_fexp (fexp : 'a fexp) (k : 'a fexp -> 'a exp) : 'a exp = let (FE_aux (FE_Fexp (id,exp),annot)) = fexp in - norm_exp_to_name exp >>= fun exp -> - return (FE_aux (FE_Fexp (id,exp),annot)) + n_exp_name exp (fun exp -> k (FE_aux (FE_Fexp (id,exp),annot))) -and norm_fexpL (fexps : 'a fexp list) : ('a exp,'a fexp list) cont = - norm_list norm_fexp fexps - -and norm_pexpL (pexps : 'a pexp list) : ('a exp, 'a pexp list) cont = - norm_list norm_pexp pexps +and n_fexpL (fexps : 'a fexp list) (k : 'a fexp list -> 'a exp) : 'a exp = + mapCont n_fexp fexps k + +and n_pexp (new_return : bool) (pexp : 'a pexp) (k : 'a pexp -> 'a exp) : 'a exp = + let (Pat_aux (Pat_exp (pat,exp),annot)) = pexp in + k (Pat_aux (Pat_exp (pat,n_exp_term new_return exp), annot)) -and norm_exp_to_term (exp : 'a exp) : 'a exp = - let (E_aux (_,annot)) = exp in - let exp = - if effectful exp then E_aux (E_internal_return exp,annot) else exp in - norm_exp exp (fun exp -> exp) +and n_pexpL (pexps : 'a pexp list) (k : 'a pexp list -> 'a exp) : 'a exp = + let geteffs (Pat_aux (_,(_,Base (_,_,_,_,{effect = Eset effs},_)))) = effs in + let effs = {effect = Eset (List.flatten (List.map geteffs pexps))} in + mapCont (n_pexp (effectful_effs effs)) pexps k -and norm_fexps (fexps :'a fexps) : ('a exp,'a fexps) cont = +and n_fexps (fexps : 'a fexps) (k : 'a fexps -> 'a exp) : 'a exp = let (FES_aux (FES_Fexps (fexps_aux,b),annot)) = fexps in - norm_fexpL fexps_aux >>= fun fexps_aux -> - return (FES_aux (FES_Fexps (fexps_aux,b),only_local_eff annot)) - -and norm_pexp (pexp : 'a pexp) : ('a exp,'a pexp) cont = - let (Pat_aux (Pat_exp (pat,exp),annot)) = pexp in - return (Pat_aux (Pat_exp (pat,norm_exp_to_term exp), annot)) + n_fexpL fexps_aux (fun fexps_aux -> k (FES_aux (FES_Fexps (fexps_aux,b),only_local_eff annot))) -and norm_opt_default (opt_default : 'a opt_default) : ('a exp,'a opt_default) cont = +and n_opt_default (opt_default : 'a opt_default) (k : 'a opt_default -> 'a exp) : 'a exp = let (Def_val_aux (opt_default,annot)) = opt_default in match opt_default with - | Def_val_empty -> return (Def_val_aux (Def_val_empty,annot)) - | Def_val_dec exp' -> - norm_exp_to_name exp' >>= fun exp' -> - return (Def_val_aux (Def_val_dec exp',only_local_eff annot)) + | Def_val_empty -> k (Def_val_aux (Def_val_empty,annot)) + | Def_val_dec exp -> + n_exp_name exp (fun exp -> k (Def_val_aux (Def_val_dec exp,only_local_eff annot))) -and norm_lb (lb : 'a letbind) : ('a exp,'a letbind) cont = +and n_lb (lb : 'a letbind) (k : 'a letbind -> 'a exp) : 'a exp = let (LB_aux (lb,annot)) = lb in match lb with | LB_val_explicit (typ,pat,exp1) -> - norm_exp exp1 >>= fun exp1 -> - return (LB_aux (LB_val_explicit (typ,pat,exp1),only_local_eff annot)) + n_exp exp1 (fun exp1 -> k (LB_aux (LB_val_explicit (typ,pat,exp1),only_local_eff annot))) | LB_val_implicit (pat,exp1) -> - norm_exp exp1 >>= fun exp1 -> - return (LB_aux (LB_val_implicit (pat,exp1),only_local_eff annot)) + n_exp exp1 (fun exp1 -> k (LB_aux (LB_val_implicit (pat,exp1),only_local_eff annot))) -and norm_lexp (lexp : 'a lexp) : ('a exp,'a lexp) cont = +and n_lexp (lexp : 'a lexp) (k : 'a lexp -> 'a exp) : 'a exp = let (LEXP_aux (lexp_aux,annot)) = lexp in match lexp_aux with - | LEXP_id _ -> return lexp + | LEXP_id _ -> k lexp | LEXP_memory (id,es) -> - norm_exp_to_nameL es >>= fun es -> - return (LEXP_aux (LEXP_memory (id,es),only_local_eff annot)) - | LEXP_cast (typ,id) -> return (LEXP_aux (LEXP_cast (typ,id),only_local_eff annot)) + n_exp_nameL es (fun es -> k (LEXP_aux (LEXP_memory (id,es),only_local_eff annot))) + | LEXP_cast (typ,id) -> k (LEXP_aux (LEXP_cast (typ,id),only_local_eff annot)) | LEXP_vector (lexp,id) -> - norm_lexp lexp >>= fun lexp -> - return (LEXP_aux (LEXP_vector (lexp,id),only_local_eff annot)) + n_lexp lexp (fun lexp -> k (LEXP_aux (LEXP_vector (lexp,id),only_local_eff annot))) | LEXP_vector_range (lexp,exp1,exp2) -> - norm_lexp lexp >>= fun lexp -> - norm_exp_to_name exp1 >>= fun exp1 -> - norm_exp_to_name exp2 >>= fun exp2 -> - return (LEXP_aux (LEXP_vector_range (lexp,exp1,exp2),only_local_eff annot)) + n_lexp lexp (fun lexp -> + n_exp_name exp1 (fun exp1 -> + n_exp_name exp2 (fun exp2 -> + k (LEXP_aux (LEXP_vector_range (lexp,exp1,exp2),only_local_eff annot))))) | LEXP_field (lexp,id) -> - norm_lexp lexp >>= fun lexp -> - return (LEXP_aux (LEXP_field (lexp,id),only_local_eff annot)) - -and norm_exp (exp : 'a exp) : ('a exp,'a exp) cont = + n_lexp lexp (fun lexp -> + k (LEXP_aux (LEXP_field (lexp,id),only_local_eff annot))) + +and n_exp_term (new_return : bool) (exp : 'a exp) : 'a exp = + let (E_aux (_,annot)) = exp in + let exp = if new_return then E_aux (E_internal_return exp,annot) else exp in + n_exp exp (fun exp -> exp) + +and n_exp (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp = let (E_aux (exp_aux,annot)) = exp in @@ -1052,124 +1034,125 @@ and norm_exp (exp : 'a exp) : ('a exp,'a exp) cont = match exp_aux with | E_block _ -> failwith "E_block should have been removed till now" | E_nondet _ -> failwith "E_nondet not supported" - | E_id id -> if value exp then return exp else letbind exp - | E_lit _ -> return exp + | E_id id -> k exp (* if value exp then return exp else letbind exp *) + | E_lit _ -> k exp | E_cast (typ,exp') -> - norm_exp_to_name exp' >>= fun exp' -> - return (rewrap_localeff (E_cast (typ,exp'))) + n_exp_name exp' (fun exp' -> + k (rewrap_localeff (E_cast (typ,exp')))) | E_app (id,exps) -> - norm_exp_to_nameL exps >>= fun exps -> - return (rewrap_localeff (E_app (id,exps))) + n_exp_nameL exps (fun exps -> + k (rewrap_localeff (E_app (id,exps)))) | E_app_infix (exp1,id,exp2) -> - norm_exp_to_name exp1 >>= fun exp1 -> - norm_exp_to_name exp2 >>= fun exp2 -> - return (rewrap_localeff (E_app_infix (exp1,id,exp2))) + n_exp_name exp1 (fun exp1 -> + n_exp_name exp2 (fun exp2 -> + k (rewrap_localeff (E_app_infix (exp1,id,exp2))))) | E_tuple exps -> - norm_exp_to_nameL exps >>= fun exps -> - return (rewrap_localeff (E_tuple exps)) + n_exp_nameL exps (fun exps -> + k (rewrap_localeff (E_tuple exps))) | E_if (exp1,exp2,exp3) -> - norm_exp_to_name exp1 >>= fun exp1 -> - let exp2 = norm_exp_to_term exp2 in - let exp3 = norm_exp_to_term exp3 in - return (rewrap_effs (eff_union exp2 exp3) (E_if (exp1,exp2,exp3))) + n_exp_name exp1 (fun exp1 -> + let (E_aux (_,annot2)) = exp2 in + let (E_aux (_,annot3)) = exp3 in + let new_return = effectful exp2 || effectful exp3 in + let exp2 = n_exp_term new_return exp2 in + let exp3 = n_exp_term new_return exp3 in + k (rewrap_effs (eff_union exp2 exp3) (E_if (exp1,exp2,exp3)))) | E_for (id,start,stop,by,dir,body) -> - norm_exp_to_name start >>= fun start -> - norm_exp_to_name stop >>= fun stop -> - norm_exp_to_name by >>= fun by -> - let body = norm_exp_to_term body in - return (rewrap_effs (geteffs body) (E_for (id,start,stop,by,dir,body))) + n_exp_name start (fun start -> + n_exp_name stop (fun stop -> + n_exp_name by (fun by -> + let body = n_exp_term (effectful body) body in + k (rewrap_effs (geteffs body) (E_for (id,start,stop,by,dir,body)))))) | E_vector exps -> - norm_exp_to_nameL exps >>= fun exps -> - return (rewrap_localeff (E_vector exps)) + n_exp_nameL exps (fun exps -> + k (rewrap_localeff (E_vector exps))) | E_vector_indexed (exps,opt_default) -> let (is,exps) = List.split exps in - norm_exp_to_nameL exps >>= fun exps -> - norm_opt_default opt_default >>= fun opt_default -> + n_exp_nameL exps (fun exps -> + n_opt_default opt_default (fun opt_default -> let exps = List.combine is exps in - return (rewrap_localeff (E_vector_indexed (exps,opt_default))) + k (rewrap_localeff (E_vector_indexed (exps,opt_default))))) | E_vector_access (exp1,exp2) -> - norm_exp_to_name exp1 >>= fun exp1 -> - norm_exp_to_name exp2 >>= fun exp2 -> - return (rewrap_localeff (E_vector_access (exp1,exp2))) + n_exp_name exp1 (fun exp1 -> + n_exp_name exp2 (fun exp2 -> + k (rewrap_localeff (E_vector_access (exp1,exp2))))) | E_vector_subrange (exp1,exp2,exp3) -> - norm_exp_to_name exp1 >>= fun exp1 -> - norm_exp_to_name exp2 >>= fun exp2 -> - norm_exp_to_name exp3 >>= fun exp3 -> - return (rewrap_localeff (E_vector_subrange (exp1,exp2,exp3))) + n_exp_name exp1 (fun exp1 -> + n_exp_name exp2 (fun exp2 -> + n_exp_name exp3 (fun exp3 -> + k (rewrap_localeff (E_vector_subrange (exp1,exp2,exp3)))))) | E_vector_update (exp1,exp2,exp3) -> - norm_exp_to_name exp1 >>= fun exp1 -> - norm_exp_to_name exp2 >>= fun exp2 -> - norm_exp_to_name exp3 >>= fun exp3 -> - return (rewrap_localeff (E_vector_update (exp1,exp2,exp3))) + n_exp_name exp1 (fun exp1 -> + n_exp_name exp2 (fun exp2 -> + n_exp_name exp3 (fun exp3 -> + k (rewrap_localeff (E_vector_update (exp1,exp2,exp3)))))) | E_vector_update_subrange (exp1,exp2,exp3,exp4) -> - norm_exp_to_name exp1 >>= fun exp1 -> - norm_exp_to_name exp2 >>= fun exp2 -> - norm_exp_to_name exp3 >>= fun exp3 -> - norm_exp_to_name exp4 >>= fun exp4 -> - return (rewrap_localeff (E_vector_update_subrange (exp1,exp2,exp3,exp4))) + n_exp_name exp1 (fun exp1 -> + n_exp_name exp2 (fun exp2 -> + n_exp_name exp3 (fun exp3 -> + n_exp_name exp4 (fun exp4 -> + k (rewrap_localeff (E_vector_update_subrange (exp1,exp2,exp3,exp4))))))) | E_vector_append (exp1,exp2) -> - norm_exp_to_name exp1 >>= fun exp1 -> - norm_exp_to_name exp2 >>= fun exp2 -> - return (rewrap_localeff (E_vector_append (exp1,exp2))) + n_exp_name exp1 (fun exp1 -> + n_exp_name exp2 (fun exp2 -> + k (rewrap_localeff (E_vector_append (exp1,exp2))))) | E_list exps -> - norm_exp_to_nameL exps >>= fun exps -> - return (rewrap_localeff (E_list exps)) + n_exp_nameL exps (fun exps -> + k (rewrap_localeff (E_list exps))) | E_cons (exp1,exp2) -> - norm_exp_to_name exp1 >>= fun exp1 -> - norm_exp_to_name exp2 >>= fun exp2 -> - return (rewrap_localeff (E_cons (exp1,exp2))) + n_exp_name exp1 (fun exp1 -> + n_exp_name exp2 (fun exp2 -> + k (rewrap_localeff (E_cons (exp1,exp2))))) | E_record fexps -> - norm_fexps fexps >>= fun fexps -> - return (rewrap_localeff (E_record fexps)) + n_fexps fexps (fun fexps -> + k (rewrap_localeff (E_record fexps))) | E_record_update (exp1,fexps) -> - norm_exp_to_name exp1 >>= fun exp1 -> - norm_fexps fexps >>= fun fexps -> - return (rewrap_localeff (E_record fexps)) + n_exp_name exp1 (fun exp1 -> + n_fexps fexps (fun fexps -> + k (rewrap_localeff (E_record fexps)))) | E_field (exp1,id) -> - norm_exp_to_name exp1 >>= fun exp1 -> - return (rewrap_localeff (E_field (exp1,id))) + n_exp_name exp1 (fun exp1 -> + k (rewrap_localeff (E_field (exp1,id)))) | E_case (exp1,pexps) -> - norm_exp_to_name exp1 >>= fun exp1 -> - norm_pexpL pexps >>= fun pexps -> + n_exp_name exp1 (fun exp1 -> + n_pexpL pexps (fun pexps -> let geteffs (Pat_aux (_,(_,Base (_,_,_,_,{effect = Eset effs},_)))) = effs in let effsum = {effect = Eset (dedup (List.flatten (List.map geteffs pexps)))} in - return (rewrap_effs effsum (E_case (exp1,pexps))) + k (rewrap_effs effsum (E_case (exp1,pexps))))) | E_let (lb,body) -> - norm_lb lb >>= fun lb -> + n_lb lb (fun lb -> (match lb with | LB_aux (LB_val_explicit (_,pat,exp'),annot') | LB_aux (LB_val_implicit (pat,exp'),annot') -> - let body = norm_exp_to_term body in if effectful exp' - then return (rewrap_effs (eff_union exp' body) (E_internal_plet (pat,exp',body))) - else return (rewrap_effs (geteffs body) (E_let (lb,body))) - ) + then (rewrap_effs (eff_union exp' body) (E_internal_plet (pat,exp',n_exp body k))) + else (rewrap_effs (geteffs body) (E_let (lb,n_exp body k))) + )) | E_assign (lexp,exp1) -> - norm_lexp lexp >>= fun lexp -> - norm_exp_to_name exp1 >>= fun exp1 -> - return (rewrap_localeff (E_assign (lexp,exp1))) - | E_exit exp' -> return (E_aux (E_exit (norm_exp_to_term exp'),annot)) + n_lexp lexp (fun lexp -> + n_exp_name exp1 (fun exp1 -> + k (rewrap_localeff (E_assign (lexp,exp1))))) + | E_exit exp' -> k (E_aux (E_exit (n_exp_term (effectful exp') exp'),annot)) | E_internal_cast (annot',exp') -> - norm_exp_to_name exp' >>= fun exp' -> - return (rewrap_localeff (E_internal_cast (annot',exp'))) - | E_internal_exp _ -> return exp - | E_internal_exp_user _ -> return exp + n_exp_name exp' (fun exp' -> + k (rewrap_localeff (E_internal_cast (annot',exp')))) + | E_internal_exp _ -> k exp + | E_internal_exp_user _ -> k exp | E_internal_let (lexp,exp1,exp2) -> (if effectful exp1 - then norm_exp_to_name exp1 - else norm_exp exp1) >>= fun exp1 -> - let exp2 = norm_exp_to_term exp2 in - return (rewrap_effs (geteffs exp2) (E_internal_let (lexp,exp1,exp2))) + then n_exp_name exp1 + else n_exp exp1) (fun exp1 -> + n_exp exp2 (fun exp2 -> + k (rewrap_effs (geteffs exp2) (E_internal_let (lexp,exp1,exp2))))) | E_internal_return exp1 -> - norm_exp_to_name exp1 >>= fun exp1 -> - return (rewrap_localeff (E_internal_return exp1)) + n_exp_name exp1 (fun exp1 -> + k (rewrap_localeff (E_internal_return exp1))) | E_internal_plet _ -> failwith "E_internal_plet should not be here yet" let rewrite_defs_a_normalise = let rewrite_exp _ _ e = - if effectful e then norm_exp_to_term (fold_exp remove_blocks_exp_alg e) - else e in + n_exp_term (effectful e) (fold_exp remove_blocks_exp_alg e) in rewrite_defs_base {rewrite_exp = rewrite_exp ; rewrite_pat = rewrite_pat |
