summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorChristopher Pulte2015-10-20 14:28:24 +0100
committerChristopher Pulte2015-10-20 14:28:24 +0100
commit117e58ac3da5d79dab16988b693cdd0908c0bb48 (patch)
treee75c66f2d6ede16924a02f555e34159f6a197f4a /src
parent602adb432b158efa403959454328bc58bddca61b (diff)
fix a-normalisation bug
Diffstat (limited to 'src')
-rw-r--r--src/pretty_print.ml3
-rw-r--r--src/rewriter.ml287
2 files changed, 137 insertions, 153 deletions
diff --git a/src/pretty_print.ml b/src/pretty_print.ml
index 3ffc67b7..9b5173c3 100644
--- a/src/pretty_print.ml
+++ b/src/pretty_print.ml
@@ -1340,9 +1340,10 @@ let doc_exp_ocaml, doc_let_ocaml =
string "if" ^^ space ^^ string "to_bool" ^^ parens (exp c) ^/^
string "then" ^^ space ^^ (exp t)
| E_if(c,t,e) ->
+ parens (
string "if" ^^ space ^^ string "to_bool" ^^ parens (exp c) ^/^
string "then" ^^ space ^^ group (exp t) ^/^
- string "else" ^^ space ^^ group (exp e)
+ string "else" ^^ space ^^ group (exp e))
| E_for(id,exp1,exp2,exp3,(Ord_aux(order,_)),exp4) ->
let var= doc_id_ocaml id in
let (compare,next) = if order = Ord_inc then string "<=",string "+" else string ">=",string "-" in
diff --git a/src/rewriter.ml b/src/rewriter.ml
index 9aa5063d..d7bb68f4 100644
--- a/src/rewriter.ml
+++ b/src/rewriter.ml
@@ -913,37 +913,30 @@ let remove_blocks_exp_alg =
{ id_exp_alg with e_aux = e_aux }
-type ('r,'a) cont = ('a -> 'r) -> 'r
-
-let return : 'a -> ('r,'a) cont =
- fun a -> (fun f -> f a)
-
-let bind : ('r,'a) cont -> ('a -> ('r,'b) cont) -> ('r,'b) cont =
- fun m f -> (fun k -> m (fun a -> (f a) k))
-
-let (>>=) = bind
-
let a_normalise_counter = ref 0
let letbind (v : 'a exp) (body : 'a exp -> 'a exp) : 'a exp =
+ (* body is a function : E_id variable -> actual body *)
+
let fresh_id () =
let current = !a_normalise_counter in
let () = a_normalise_counter := (current + 1) in
Id_aux (Id ("__w" ^ string_of_int current), Parse_ast.Unknown) in
- (* body is a function : E_id variable -> actual body *)
let freshid = fresh_id () in
let annot_var = (Parse_ast.Unknown,simple_annot (gettype v)) in
let eid = E_aux (E_id freshid, annot_var) in
+
+ let body = body eid in
let annot_pat = (Parse_ast.Unknown,simple_annot_efr (gettype v) (geteffs v)) in
let annot_lb = annot_pat in
- let annot_let = (Parse_ast.Unknown,simple_annot_efr {t = Tid "unit"} (eff_union v (body eid))) in
+ let annot_let = (Parse_ast.Unknown,simple_annot_efr {t = Tid "unit"} (eff_union v body)) in
let pat = P_aux (P_id freshid,annot_pat) in
if effectful v
- then E_aux (E_internal_plet (pat,v,body eid),annot_let)
- else E_aux (E_let (LB_aux (LB_val_implicit (pat,v),annot_lb),body eid),annot_let)
+ then E_aux (E_internal_plet (pat,v,body),annot_let)
+ else E_aux (E_let (LB_aux (LB_val_implicit (pat,v),annot_lb),body),annot_let)
let rec value ((E_aux (exp_aux,_)) as exp) =
not (effectful exp) &&
@@ -958,87 +951,76 @@ let rec value ((E_aux (exp_aux,_)) as exp) =
let only_local_eff (l,(Base ((t_params,t),tag,nexps,eff,effsum,bounds))) =
(l,Base ((t_params,t),tag,nexps,eff,eff,bounds))
-let rec norm_list : ('b -> ('annot exp,'b) cont) -> ('b list -> ('a exp,'b list) cont) =
- fun normf l ->
+let rec mapCont (f : 'b -> ('b -> 'a exp) -> 'a exp) (l : 'b list) (k : 'b list -> 'a exp) : 'a exp =
match l with
- | [] -> return []
- | e :: es ->
- normf e >>= fun e ->
- norm_list normf es >>= fun es ->
- return (e :: es)
-
-let rec norm_exp_to_name exp : ('a exp,'a exp) cont =
- norm_exp exp >>= fun exp ->
- if value exp then return exp else letbind exp
+ | [] -> k []
+ | exp :: exps -> f exp (fun exp -> mapCont f exps (fun exps -> k (exp :: exps)))
+
+let rec n_exp_name (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp =
+ n_exp exp (fun exp -> if value exp then k exp else letbind exp k)
-and norm_exp_to_nameL (exps : 'a exp list) : ('a exp,'a exp list) cont =
- norm_list norm_exp_to_name exps
+and n_exp_nameL (exps : 'a exp list) (k : 'a exp list -> 'a exp) : 'a exp =
+ mapCont n_exp_name exps k
-and norm_fexp (fexp : 'a fexp) : ('a exp,'a fexp) cont =
+and n_fexp (fexp : 'a fexp) (k : 'a fexp -> 'a exp) : 'a exp =
let (FE_aux (FE_Fexp (id,exp),annot)) = fexp in
- norm_exp_to_name exp >>= fun exp ->
- return (FE_aux (FE_Fexp (id,exp),annot))
+ n_exp_name exp (fun exp -> k (FE_aux (FE_Fexp (id,exp),annot)))
-and norm_fexpL (fexps : 'a fexp list) : ('a exp,'a fexp list) cont =
- norm_list norm_fexp fexps
-
-and norm_pexpL (pexps : 'a pexp list) : ('a exp, 'a pexp list) cont =
- norm_list norm_pexp pexps
+and n_fexpL (fexps : 'a fexp list) (k : 'a fexp list -> 'a exp) : 'a exp =
+ mapCont n_fexp fexps k
+
+and n_pexp (new_return : bool) (pexp : 'a pexp) (k : 'a pexp -> 'a exp) : 'a exp =
+ let (Pat_aux (Pat_exp (pat,exp),annot)) = pexp in
+ k (Pat_aux (Pat_exp (pat,n_exp_term new_return exp), annot))
-and norm_exp_to_term (exp : 'a exp) : 'a exp =
- let (E_aux (_,annot)) = exp in
- let exp =
- if effectful exp then E_aux (E_internal_return exp,annot) else exp in
- norm_exp exp (fun exp -> exp)
+and n_pexpL (pexps : 'a pexp list) (k : 'a pexp list -> 'a exp) : 'a exp =
+ let geteffs (Pat_aux (_,(_,Base (_,_,_,_,{effect = Eset effs},_)))) = effs in
+ let effs = {effect = Eset (List.flatten (List.map geteffs pexps))} in
+ mapCont (n_pexp (effectful_effs effs)) pexps k
-and norm_fexps (fexps :'a fexps) : ('a exp,'a fexps) cont =
+and n_fexps (fexps : 'a fexps) (k : 'a fexps -> 'a exp) : 'a exp =
let (FES_aux (FES_Fexps (fexps_aux,b),annot)) = fexps in
- norm_fexpL fexps_aux >>= fun fexps_aux ->
- return (FES_aux (FES_Fexps (fexps_aux,b),only_local_eff annot))
-
-and norm_pexp (pexp : 'a pexp) : ('a exp,'a pexp) cont =
- let (Pat_aux (Pat_exp (pat,exp),annot)) = pexp in
- return (Pat_aux (Pat_exp (pat,norm_exp_to_term exp), annot))
+ n_fexpL fexps_aux (fun fexps_aux -> k (FES_aux (FES_Fexps (fexps_aux,b),only_local_eff annot)))
-and norm_opt_default (opt_default : 'a opt_default) : ('a exp,'a opt_default) cont =
+and n_opt_default (opt_default : 'a opt_default) (k : 'a opt_default -> 'a exp) : 'a exp =
let (Def_val_aux (opt_default,annot)) = opt_default in
match opt_default with
- | Def_val_empty -> return (Def_val_aux (Def_val_empty,annot))
- | Def_val_dec exp' ->
- norm_exp_to_name exp' >>= fun exp' ->
- return (Def_val_aux (Def_val_dec exp',only_local_eff annot))
+ | Def_val_empty -> k (Def_val_aux (Def_val_empty,annot))
+ | Def_val_dec exp ->
+ n_exp_name exp (fun exp -> k (Def_val_aux (Def_val_dec exp,only_local_eff annot)))
-and norm_lb (lb : 'a letbind) : ('a exp,'a letbind) cont =
+and n_lb (lb : 'a letbind) (k : 'a letbind -> 'a exp) : 'a exp =
let (LB_aux (lb,annot)) = lb in
match lb with
| LB_val_explicit (typ,pat,exp1) ->
- norm_exp exp1 >>= fun exp1 ->
- return (LB_aux (LB_val_explicit (typ,pat,exp1),only_local_eff annot))
+ n_exp exp1 (fun exp1 -> k (LB_aux (LB_val_explicit (typ,pat,exp1),only_local_eff annot)))
| LB_val_implicit (pat,exp1) ->
- norm_exp exp1 >>= fun exp1 ->
- return (LB_aux (LB_val_implicit (pat,exp1),only_local_eff annot))
+ n_exp exp1 (fun exp1 -> k (LB_aux (LB_val_implicit (pat,exp1),only_local_eff annot)))
-and norm_lexp (lexp : 'a lexp) : ('a exp,'a lexp) cont =
+and n_lexp (lexp : 'a lexp) (k : 'a lexp -> 'a exp) : 'a exp =
let (LEXP_aux (lexp_aux,annot)) = lexp in
match lexp_aux with
- | LEXP_id _ -> return lexp
+ | LEXP_id _ -> k lexp
| LEXP_memory (id,es) ->
- norm_exp_to_nameL es >>= fun es ->
- return (LEXP_aux (LEXP_memory (id,es),only_local_eff annot))
- | LEXP_cast (typ,id) -> return (LEXP_aux (LEXP_cast (typ,id),only_local_eff annot))
+ n_exp_nameL es (fun es -> k (LEXP_aux (LEXP_memory (id,es),only_local_eff annot)))
+ | LEXP_cast (typ,id) -> k (LEXP_aux (LEXP_cast (typ,id),only_local_eff annot))
| LEXP_vector (lexp,id) ->
- norm_lexp lexp >>= fun lexp ->
- return (LEXP_aux (LEXP_vector (lexp,id),only_local_eff annot))
+ n_lexp lexp (fun lexp -> k (LEXP_aux (LEXP_vector (lexp,id),only_local_eff annot)))
| LEXP_vector_range (lexp,exp1,exp2) ->
- norm_lexp lexp >>= fun lexp ->
- norm_exp_to_name exp1 >>= fun exp1 ->
- norm_exp_to_name exp2 >>= fun exp2 ->
- return (LEXP_aux (LEXP_vector_range (lexp,exp1,exp2),only_local_eff annot))
+ n_lexp lexp (fun lexp ->
+ n_exp_name exp1 (fun exp1 ->
+ n_exp_name exp2 (fun exp2 ->
+ k (LEXP_aux (LEXP_vector_range (lexp,exp1,exp2),only_local_eff annot)))))
| LEXP_field (lexp,id) ->
- norm_lexp lexp >>= fun lexp ->
- return (LEXP_aux (LEXP_field (lexp,id),only_local_eff annot))
-
-and norm_exp (exp : 'a exp) : ('a exp,'a exp) cont =
+ n_lexp lexp (fun lexp ->
+ k (LEXP_aux (LEXP_field (lexp,id),only_local_eff annot)))
+
+and n_exp_term (new_return : bool) (exp : 'a exp) : 'a exp =
+ let (E_aux (_,annot)) = exp in
+ let exp = if new_return then E_aux (E_internal_return exp,annot) else exp in
+ n_exp exp (fun exp -> exp)
+
+and n_exp (exp : 'a exp) (k : 'a exp -> 'a exp) : 'a exp =
let (E_aux (exp_aux,annot)) = exp in
@@ -1052,124 +1034,125 @@ and norm_exp (exp : 'a exp) : ('a exp,'a exp) cont =
match exp_aux with
| E_block _ -> failwith "E_block should have been removed till now"
| E_nondet _ -> failwith "E_nondet not supported"
- | E_id id -> if value exp then return exp else letbind exp
- | E_lit _ -> return exp
+ | E_id id -> k exp (* if value exp then return exp else letbind exp *)
+ | E_lit _ -> k exp
| E_cast (typ,exp') ->
- norm_exp_to_name exp' >>= fun exp' ->
- return (rewrap_localeff (E_cast (typ,exp')))
+ n_exp_name exp' (fun exp' ->
+ k (rewrap_localeff (E_cast (typ,exp'))))
| E_app (id,exps) ->
- norm_exp_to_nameL exps >>= fun exps ->
- return (rewrap_localeff (E_app (id,exps)))
+ n_exp_nameL exps (fun exps ->
+ k (rewrap_localeff (E_app (id,exps))))
| E_app_infix (exp1,id,exp2) ->
- norm_exp_to_name exp1 >>= fun exp1 ->
- norm_exp_to_name exp2 >>= fun exp2 ->
- return (rewrap_localeff (E_app_infix (exp1,id,exp2)))
+ n_exp_name exp1 (fun exp1 ->
+ n_exp_name exp2 (fun exp2 ->
+ k (rewrap_localeff (E_app_infix (exp1,id,exp2)))))
| E_tuple exps ->
- norm_exp_to_nameL exps >>= fun exps ->
- return (rewrap_localeff (E_tuple exps))
+ n_exp_nameL exps (fun exps ->
+ k (rewrap_localeff (E_tuple exps)))
| E_if (exp1,exp2,exp3) ->
- norm_exp_to_name exp1 >>= fun exp1 ->
- let exp2 = norm_exp_to_term exp2 in
- let exp3 = norm_exp_to_term exp3 in
- return (rewrap_effs (eff_union exp2 exp3) (E_if (exp1,exp2,exp3)))
+ n_exp_name exp1 (fun exp1 ->
+ let (E_aux (_,annot2)) = exp2 in
+ let (E_aux (_,annot3)) = exp3 in
+ let new_return = effectful exp2 || effectful exp3 in
+ let exp2 = n_exp_term new_return exp2 in
+ let exp3 = n_exp_term new_return exp3 in
+ k (rewrap_effs (eff_union exp2 exp3) (E_if (exp1,exp2,exp3))))
| E_for (id,start,stop,by,dir,body) ->
- norm_exp_to_name start >>= fun start ->
- norm_exp_to_name stop >>= fun stop ->
- norm_exp_to_name by >>= fun by ->
- let body = norm_exp_to_term body in
- return (rewrap_effs (geteffs body) (E_for (id,start,stop,by,dir,body)))
+ n_exp_name start (fun start ->
+ n_exp_name stop (fun stop ->
+ n_exp_name by (fun by ->
+ let body = n_exp_term (effectful body) body in
+ k (rewrap_effs (geteffs body) (E_for (id,start,stop,by,dir,body))))))
| E_vector exps ->
- norm_exp_to_nameL exps >>= fun exps ->
- return (rewrap_localeff (E_vector exps))
+ n_exp_nameL exps (fun exps ->
+ k (rewrap_localeff (E_vector exps)))
| E_vector_indexed (exps,opt_default) ->
let (is,exps) = List.split exps in
- norm_exp_to_nameL exps >>= fun exps ->
- norm_opt_default opt_default >>= fun opt_default ->
+ n_exp_nameL exps (fun exps ->
+ n_opt_default opt_default (fun opt_default ->
let exps = List.combine is exps in
- return (rewrap_localeff (E_vector_indexed (exps,opt_default)))
+ k (rewrap_localeff (E_vector_indexed (exps,opt_default)))))
| E_vector_access (exp1,exp2) ->
- norm_exp_to_name exp1 >>= fun exp1 ->
- norm_exp_to_name exp2 >>= fun exp2 ->
- return (rewrap_localeff (E_vector_access (exp1,exp2)))
+ n_exp_name exp1 (fun exp1 ->
+ n_exp_name exp2 (fun exp2 ->
+ k (rewrap_localeff (E_vector_access (exp1,exp2)))))
| E_vector_subrange (exp1,exp2,exp3) ->
- norm_exp_to_name exp1 >>= fun exp1 ->
- norm_exp_to_name exp2 >>= fun exp2 ->
- norm_exp_to_name exp3 >>= fun exp3 ->
- return (rewrap_localeff (E_vector_subrange (exp1,exp2,exp3)))
+ n_exp_name exp1 (fun exp1 ->
+ n_exp_name exp2 (fun exp2 ->
+ n_exp_name exp3 (fun exp3 ->
+ k (rewrap_localeff (E_vector_subrange (exp1,exp2,exp3))))))
| E_vector_update (exp1,exp2,exp3) ->
- norm_exp_to_name exp1 >>= fun exp1 ->
- norm_exp_to_name exp2 >>= fun exp2 ->
- norm_exp_to_name exp3 >>= fun exp3 ->
- return (rewrap_localeff (E_vector_update (exp1,exp2,exp3)))
+ n_exp_name exp1 (fun exp1 ->
+ n_exp_name exp2 (fun exp2 ->
+ n_exp_name exp3 (fun exp3 ->
+ k (rewrap_localeff (E_vector_update (exp1,exp2,exp3))))))
| E_vector_update_subrange (exp1,exp2,exp3,exp4) ->
- norm_exp_to_name exp1 >>= fun exp1 ->
- norm_exp_to_name exp2 >>= fun exp2 ->
- norm_exp_to_name exp3 >>= fun exp3 ->
- norm_exp_to_name exp4 >>= fun exp4 ->
- return (rewrap_localeff (E_vector_update_subrange (exp1,exp2,exp3,exp4)))
+ n_exp_name exp1 (fun exp1 ->
+ n_exp_name exp2 (fun exp2 ->
+ n_exp_name exp3 (fun exp3 ->
+ n_exp_name exp4 (fun exp4 ->
+ k (rewrap_localeff (E_vector_update_subrange (exp1,exp2,exp3,exp4)))))))
| E_vector_append (exp1,exp2) ->
- norm_exp_to_name exp1 >>= fun exp1 ->
- norm_exp_to_name exp2 >>= fun exp2 ->
- return (rewrap_localeff (E_vector_append (exp1,exp2)))
+ n_exp_name exp1 (fun exp1 ->
+ n_exp_name exp2 (fun exp2 ->
+ k (rewrap_localeff (E_vector_append (exp1,exp2)))))
| E_list exps ->
- norm_exp_to_nameL exps >>= fun exps ->
- return (rewrap_localeff (E_list exps))
+ n_exp_nameL exps (fun exps ->
+ k (rewrap_localeff (E_list exps)))
| E_cons (exp1,exp2) ->
- norm_exp_to_name exp1 >>= fun exp1 ->
- norm_exp_to_name exp2 >>= fun exp2 ->
- return (rewrap_localeff (E_cons (exp1,exp2)))
+ n_exp_name exp1 (fun exp1 ->
+ n_exp_name exp2 (fun exp2 ->
+ k (rewrap_localeff (E_cons (exp1,exp2)))))
| E_record fexps ->
- norm_fexps fexps >>= fun fexps ->
- return (rewrap_localeff (E_record fexps))
+ n_fexps fexps (fun fexps ->
+ k (rewrap_localeff (E_record fexps)))
| E_record_update (exp1,fexps) ->
- norm_exp_to_name exp1 >>= fun exp1 ->
- norm_fexps fexps >>= fun fexps ->
- return (rewrap_localeff (E_record fexps))
+ n_exp_name exp1 (fun exp1 ->
+ n_fexps fexps (fun fexps ->
+ k (rewrap_localeff (E_record fexps))))
| E_field (exp1,id) ->
- norm_exp_to_name exp1 >>= fun exp1 ->
- return (rewrap_localeff (E_field (exp1,id)))
+ n_exp_name exp1 (fun exp1 ->
+ k (rewrap_localeff (E_field (exp1,id))))
| E_case (exp1,pexps) ->
- norm_exp_to_name exp1 >>= fun exp1 ->
- norm_pexpL pexps >>= fun pexps ->
+ n_exp_name exp1 (fun exp1 ->
+ n_pexpL pexps (fun pexps ->
let geteffs (Pat_aux (_,(_,Base (_,_,_,_,{effect = Eset effs},_)))) = effs in
let effsum = {effect = Eset (dedup (List.flatten (List.map geteffs pexps)))} in
- return (rewrap_effs effsum (E_case (exp1,pexps)))
+ k (rewrap_effs effsum (E_case (exp1,pexps)))))
| E_let (lb,body) ->
- norm_lb lb >>= fun lb ->
+ n_lb lb (fun lb ->
(match lb with
| LB_aux (LB_val_explicit (_,pat,exp'),annot')
| LB_aux (LB_val_implicit (pat,exp'),annot') ->
- let body = norm_exp_to_term body in
if effectful exp'
- then return (rewrap_effs (eff_union exp' body) (E_internal_plet (pat,exp',body)))
- else return (rewrap_effs (geteffs body) (E_let (lb,body)))
- )
+ then (rewrap_effs (eff_union exp' body) (E_internal_plet (pat,exp',n_exp body k)))
+ else (rewrap_effs (geteffs body) (E_let (lb,n_exp body k)))
+ ))
| E_assign (lexp,exp1) ->
- norm_lexp lexp >>= fun lexp ->
- norm_exp_to_name exp1 >>= fun exp1 ->
- return (rewrap_localeff (E_assign (lexp,exp1)))
- | E_exit exp' -> return (E_aux (E_exit (norm_exp_to_term exp'),annot))
+ n_lexp lexp (fun lexp ->
+ n_exp_name exp1 (fun exp1 ->
+ k (rewrap_localeff (E_assign (lexp,exp1)))))
+ | E_exit exp' -> k (E_aux (E_exit (n_exp_term (effectful exp') exp'),annot))
| E_internal_cast (annot',exp') ->
- norm_exp_to_name exp' >>= fun exp' ->
- return (rewrap_localeff (E_internal_cast (annot',exp')))
- | E_internal_exp _ -> return exp
- | E_internal_exp_user _ -> return exp
+ n_exp_name exp' (fun exp' ->
+ k (rewrap_localeff (E_internal_cast (annot',exp'))))
+ | E_internal_exp _ -> k exp
+ | E_internal_exp_user _ -> k exp
| E_internal_let (lexp,exp1,exp2) ->
(if effectful exp1
- then norm_exp_to_name exp1
- else norm_exp exp1) >>= fun exp1 ->
- let exp2 = norm_exp_to_term exp2 in
- return (rewrap_effs (geteffs exp2) (E_internal_let (lexp,exp1,exp2)))
+ then n_exp_name exp1
+ else n_exp exp1) (fun exp1 ->
+ n_exp exp2 (fun exp2 ->
+ k (rewrap_effs (geteffs exp2) (E_internal_let (lexp,exp1,exp2)))))
| E_internal_return exp1 ->
- norm_exp_to_name exp1 >>= fun exp1 ->
- return (rewrap_localeff (E_internal_return exp1))
+ n_exp_name exp1 (fun exp1 ->
+ k (rewrap_localeff (E_internal_return exp1)))
| E_internal_plet _ -> failwith "E_internal_plet should not be here yet"
let rewrite_defs_a_normalise =
let rewrite_exp _ _ e =
- if effectful e then norm_exp_to_term (fold_exp remove_blocks_exp_alg e)
- else e in
+ n_exp_term (effectful e) (fold_exp remove_blocks_exp_alg e) in
rewrite_defs_base
{rewrite_exp = rewrite_exp
; rewrite_pat = rewrite_pat