diff options
Diffstat (limited to 'src/rewriter.ml')
| -rw-r--r-- | src/rewriter.ml | 343 |
1 files changed, 179 insertions, 164 deletions
diff --git a/src/rewriter.ml b/src/rewriter.ml index c4abde43..94ce67f4 100644 --- a/src/rewriter.ml +++ b/src/rewriter.ml @@ -852,14 +852,14 @@ let rewrite_defs_ocaml defs = let defs_lifted_assign = rewrite_defs_exp_lift_assign defs_vec_concat_removed in defs_lifted_assign + + let geteffs_annot = function - | (_,Base(_,_,_,_,effs,_)) -> effs + | (_,Base (_,_,_,_,effs,_)) -> effs | (_,NoTyp) -> failwith "no effect information" | _ -> failwith "a_normalise doesn't support Overload" let geteffs (E_aux (_,a)) = geteffs_annot a -let geteffslist_pexp (Pat_aux (_,a)) = - let {effect = Eset effs} = geteffs_annot a in effs let gettype (E_aux (_,(_,a))) = match a with @@ -905,9 +905,17 @@ let remove_blocks_exp_alg = { id_exp_alg with e_aux = e_aux } -let a_normalise_counter = ref 0 +type ('r,'a) cont = ('a -> 'r) -> 'r -let compose f g x = f (g x) +let return : 'a -> ('r,'a) cont = + fun a -> (fun f -> f a) + +let bind : ('r,'a) cont -> ('a -> ('r,'b) cont) -> ('r,'b) cont = + fun m f -> (fun k -> m (fun a -> (f a) k)) + +let (>>=) = bind + +let a_normalise_counter = ref 0 let letbind (v : 'a exp) (body : 'a exp -> 'a exp) : 'a exp = let fresh_id () = @@ -922,14 +930,12 @@ let letbind (v : 'a exp) (body : 'a exp -> 'a exp) : 'a exp = let annot_pat = (Parse_ast.Unknown,simple_annot_efr (gettype v) (geteffs v)) in let annot_lb = annot_pat in - let annot_let = - (Parse_ast.Unknown,simple_annot_efr {t = Tid "unit"} (eff_union v (body eid))) in + let annot_let = (Parse_ast.Unknown,simple_annot_efr {t = Tid "unit"} (eff_union v (body eid))) in + let pat = P_aux (P_id freshid,annot_pat) in - if effectful v then - E_aux (E_internal_plet (P_aux (P_id freshid,annot_pat),v,body eid),annot_let) - else - E_aux (E_let (LB_aux (LB_val_implicit (P_aux (P_id freshid,annot_pat),v), - annot_lb),body eid),annot_let) + if effectful v + then E_aux (E_internal_plet (pat,v,body eid),annot_let) + else E_aux (E_let (LB_aux (LB_val_implicit (pat,v),annot_lb),body eid),annot_let) let rec value ((E_aux (exp_aux,_)) as exp) = not (effectful exp) && @@ -944,201 +950,210 @@ let rec value ((E_aux (exp_aux,_)) as exp) = let only_local_eff (l,(Base ((t_params,t),tag,nexps,eff,effsum,bounds))) = (l,Base ((t_params,t),tag,nexps,eff,eff,bounds)) - -let rec norm_list : ('b -> ('b -> 'a exp) -> 'a exp) -> 'b list -> ('b list -> 'a exp) -> 'a exp = - fun normf l k -> +let rec norm_list : ('b -> ('annot exp,'b) cont) -> ('b list -> ('a exp,'b list) cont) = + fun normf l -> match l with - | [] -> k [] - | e :: es -> normf e (fun e -> norm_list normf es (fun es -> k (e :: es))) - -let rec norm_exp_to_name : 'a exp -> ('a exp -> 'a exp) -> 'a exp = - fun exp k -> norm_exp exp (fun exp -> if value exp then k exp else letbind exp k) + | [] -> return [] + | e :: es -> + normf e >>= fun e -> + norm_list normf es >>= fun es -> + return (e :: es) + +let rec norm_exp_to_name exp : ('a exp,'a exp) cont = + norm_exp exp >>= fun exp -> + if value exp then return exp else letbind exp -and norm_exp_to_nameL : ('a exp list -> ('a exp list -> 'a exp) -> 'a exp) = - fun exps k -> norm_list norm_exp_to_name exps k +and norm_exp_to_nameL (exps : 'a exp list) : ('a exp,'a exp list) cont = + norm_list norm_exp_to_name exps -and norm_fexp : 'a fexp -> ('a fexp -> 'a exp) -> 'a exp = - fun (FE_aux (FE_Fexp (id,exp),annot)) k -> - norm_exp_to_name exp (fun exp -> k (FE_aux (FE_Fexp (id,exp),annot))) +and norm_fexp (fexp : 'a fexp) : ('a exp,'a fexp) cont = + let (FE_aux (FE_Fexp (id,exp),annot)) = fexp in + norm_exp_to_name exp >>= fun exp -> + return (FE_aux (FE_Fexp (id,exp),annot)) -and norm_fexpL : 'a fexp list -> ('a fexp list -> 'a exp) -> 'a exp = - fun fexps k -> norm_list norm_fexp fexps k +and norm_fexpL (fexps : 'a fexp list) : ('a exp,'a fexp list) cont = + norm_list norm_fexp fexps -and norm_pexpL : 'a pexp list -> ('a pexp list -> 'a exp) -> 'a exp = - fun pexps k -> norm_list norm_pexp pexps k +and norm_pexpL (pexps : 'a pexp list) : ('a exp, 'a pexp list) cont = + norm_list norm_pexp pexps -and norm_exp_to_term : 'a exp -> 'a exp = - fun exp -> norm_exp exp (fun exp -> exp) +and norm_exp_to_term (exp : 'a exp) : 'a exp = + norm_exp exp (fun exp -> exp) -and norm_fexps : 'a fexps -> ('a fexps -> 'a exp) -> 'a exp = - fun (FES_aux (FES_Fexps (fexps,b),annot)) k -> - norm_fexpL fexps (fun fexps -> k (FES_aux (FES_Fexps (fexps,b),only_local_eff annot))) +and norm_fexps (fexps :'a fexps) : ('a exp,'a fexps) cont = + let (FES_aux (FES_Fexps (fexps_aux,b),annot)) = fexps in + norm_fexpL fexps_aux >>= fun fexps_aux -> + return (FES_aux (FES_Fexps (fexps_aux,b),only_local_eff annot)) -and norm_pexp : 'a pexp -> ('a pexp -> 'a exp) -> 'a exp = - fun (Pat_aux (Pat_exp (pat,exp),annot)) k -> - k (Pat_aux (Pat_exp (pat,norm_exp_to_term exp), annot)) +and norm_pexp (pexp : 'a pexp) : ('a exp,'a pexp) cont = + let (Pat_aux (Pat_exp (pat,exp),annot)) = pexp in + return (Pat_aux (Pat_exp (pat,norm_exp_to_term exp), annot)) -and norm_opt_default : 'a opt_default -> ('a opt_default -> 'a exp) -> 'a exp = - fun (Def_val_aux (opt_default,annot)) k -> +and norm_opt_default (opt_default : 'a opt_default) : ('a exp,'a opt_default) cont = + let (Def_val_aux (opt_default,annot)) = opt_default in match opt_default with - | Def_val_empty -> k (Def_val_aux (Def_val_empty,annot)) + | Def_val_empty -> return (Def_val_aux (Def_val_empty,annot)) | Def_val_dec exp' -> - norm_exp_to_name exp' - (fun exp' -> k (Def_val_aux (Def_val_dec exp',only_local_eff annot))) + norm_exp_to_name exp' >>= fun exp' -> + return (Def_val_aux (Def_val_dec exp',only_local_eff annot)) -and norm_lb : 'a letbind -> ('a letbind -> 'a exp) -> 'a exp = - fun (LB_aux (lb,annot)) k -> +and norm_lb (lb : 'a letbind) : ('a exp,'a letbind) cont = + let (LB_aux (lb,annot)) = lb in match lb with | LB_val_explicit (typ,pat,exp1) -> - norm_exp exp1 - (fun exp1 -> k (LB_aux (LB_val_explicit (typ,pat,exp1),only_local_eff annot))) + norm_exp exp1 >>= fun exp1 -> + return (LB_aux (LB_val_explicit (typ,pat,exp1),only_local_eff annot)) | LB_val_implicit (pat,exp1) -> - norm_exp exp1 - (fun exp1 -> k (LB_aux (LB_val_implicit (pat,exp1),only_local_eff annot))) - + norm_exp exp1 >>= fun exp1 -> + return (LB_aux (LB_val_implicit (pat,exp1),only_local_eff annot)) -and norm_lexp : 'a lexp -> ('a lexp -> 'a exp) -> 'a exp = - fun ((LEXP_aux (lexp_aux,annot)) as lexp) k -> +and norm_lexp (lexp : 'a lexp) : ('a exp,'a lexp) cont = + let (LEXP_aux (lexp_aux,annot)) = lexp in match lexp_aux with - | LEXP_id _ -> k lexp + | LEXP_id _ -> return lexp | LEXP_memory (id,es) -> - norm_exp_to_nameL es (fun es -> - k (LEXP_aux (LEXP_memory (id,es),only_local_eff annot))) - | LEXP_cast (typ,id) -> k (LEXP_aux (LEXP_cast (typ,id),only_local_eff annot)) + norm_exp_to_nameL es >>= fun es -> + return (LEXP_aux (LEXP_memory (id,es),only_local_eff annot)) + | LEXP_cast (typ,id) -> return (LEXP_aux (LEXP_cast (typ,id),only_local_eff annot)) | LEXP_vector (lexp,id) -> - norm_lexp lexp (fun lexp -> k (LEXP_aux (LEXP_vector (lexp,id),only_local_eff annot))) + norm_lexp lexp >>= fun lexp -> + return (LEXP_aux (LEXP_vector (lexp,id),only_local_eff annot)) | LEXP_vector_range (lexp,exp1,exp2) -> - norm_lexp lexp - (fun lexp -> - norm_exp_to_name exp1 - (fun exp1 -> - norm_exp_to_name exp2 - (fun exp2 -> k (LEXP_aux (LEXP_vector_range (lexp,exp1,exp2),only_local_eff annot))))) + norm_lexp lexp >>= fun lexp -> + norm_exp_to_name exp1 >>= fun exp1 -> + norm_exp_to_name exp2 >>= fun exp2 -> + return (LEXP_aux (LEXP_vector_range (lexp,exp1,exp2),only_local_eff annot)) | LEXP_field (lexp,id) -> - norm_lexp lexp (fun lexp -> k (LEXP_aux (LEXP_field (lexp,id),only_local_eff annot))) + norm_lexp lexp >>= fun lexp -> + return (LEXP_aux (LEXP_field (lexp,id),only_local_eff annot)) +and norm_exp (exp : 'a exp) : ('a exp,'a exp) cont = + + let (E_aux (exp_aux,annot)) = exp in -and norm_exp : 'a exp -> ('a exp -> 'a exp) -> 'a exp = - fun (E_aux (exp_aux,annot) as exp) k -> let rewrap_effs effsum exp_aux = (* explicitly give effect sum *) let (l,Base ((t_params,t),tag,nexps,eff,effsum,bounds)) = annot in E_aux (exp_aux, (l,Base ((t_params,t),tag,nexps,eff,effsum,bounds))) in - let rewrap exp_aux = (* give exp_aux the local effect as the effect sum *) + let rewrap_localeff exp_aux = (* give exp_aux the local effect as the effect sum *) E_aux (exp_aux,only_local_eff annot) in match exp_aux with | E_block _ -> failwith "E_block should have been removed till now" | E_nondet _ -> failwith "E_nondet not supported" - | E_id id -> if value exp then k exp else letbind exp k - | E_lit _ -> k exp - | E_cast (typ,exp') -> norm_exp_to_name exp' (fun exp' -> k (rewrap (E_cast (typ,exp')))) - | E_app (id,exps) -> norm_exp_to_nameL exps (fun exps -> k (rewrap (E_app (id,exps)))) + | E_id id -> if value exp then return exp else letbind exp + | E_lit _ -> return exp + | E_cast (typ,exp') -> + norm_exp_to_name exp' >>= fun exp' -> + return (rewrap_localeff (E_cast (typ,exp'))) + | E_app (id,exps) -> + norm_exp_to_nameL exps >>= fun exps -> + return (rewrap_localeff (E_app (id,exps))) | E_app_infix (exp1,id,exp2) -> - norm_exp_to_name exp1 - (fun exp1 -> - norm_exp_to_name exp2 - (fun exp2 -> k (rewrap (E_app_infix (exp1,id,exp2))))) - | E_tuple exps -> norm_exp_to_nameL exps (fun exps -> k (rewrap (E_tuple exps))) + norm_exp_to_name exp1 >>= fun exp1 -> + norm_exp_to_name exp2 >>= fun exp2 -> + return (rewrap_localeff (E_app_infix (exp1,id,exp2))) + | E_tuple exps -> + norm_exp_to_nameL exps >>= fun exps -> + return (rewrap_localeff (E_tuple exps)) | E_if (exp1,exp2,exp3) -> - norm_exp_to_name exp1 - (fun exp1 -> - let exp2 = norm_exp_to_term exp2 in - let exp3 = norm_exp_to_term exp3 in - k (rewrap_effs (eff_union exp2 exp3) (E_if (exp1,exp2,exp3)))) + norm_exp_to_name exp1 >>= fun exp1 -> + let exp2 = norm_exp_to_term exp2 in + let exp3 = norm_exp_to_term exp3 in + return (rewrap_effs (eff_union exp2 exp3) (E_if (exp1,exp2,exp3))) | E_for (id,start,stop,by,dir,body) -> - norm_exp_to_name start - (fun start -> - norm_exp_to_name stop - (fun stop -> - norm_exp_to_name by - (fun by -> - let body = norm_exp_to_term body in - k (rewrap_effs (geteffs body) (E_for (id,start,stop,by,dir,body)))))) - | E_vector exps -> norm_exp_to_nameL exps (fun exps -> k (rewrap (E_vector exps))) + norm_exp_to_name start >>= fun start -> + norm_exp_to_name stop >>= fun stop -> + norm_exp_to_name by >>= fun by -> + let body = norm_exp_to_term body in + return (rewrap_effs (geteffs body) (E_for (id,start,stop,by,dir,body))) + | E_vector exps -> + norm_exp_to_nameL exps >>= fun exps -> + return (rewrap_localeff (E_vector exps)) | E_vector_indexed (exps,opt_default) -> - let (is,exps) = List.split exps in - norm_exp_to_nameL exps - (fun exps -> - norm_opt_default opt_default - (fun opt_default -> rewrap (E_vector_indexed (List.combine is exps,opt_default)))) + let (is,exps) = List.split exps in + norm_exp_to_nameL exps >>= fun exps -> + norm_opt_default opt_default >>= fun opt_default -> + let exps = List.combine is exps in + return (rewrap_localeff (E_vector_indexed (exps,opt_default))) | E_vector_access (exp1,exp2) -> - norm_exp_to_name exp1 - (fun exp1 -> norm_exp_to_name exp2 (fun exp2 -> k (rewrap (E_vector_access (exp1,exp2))))) + norm_exp_to_name exp1 >>= fun exp1 -> + norm_exp_to_name exp2 >>= fun exp2 -> + return (rewrap_localeff (E_vector_access (exp1,exp2))) | E_vector_subrange (exp1,exp2,exp3) -> - norm_exp_to_name exp1 - (fun exp1 -> - norm_exp_to_name exp2 - (fun exp2 -> - norm_exp_to_name exp3 (fun exp3 -> k (rewrap (E_vector_subrange (exp1,exp2,exp3)))))) + norm_exp_to_name exp1 >>= fun exp1 -> + norm_exp_to_name exp2 >>= fun exp2 -> + norm_exp_to_name exp3 >>= fun exp3 -> + return (rewrap_localeff (E_vector_subrange (exp1,exp2,exp3))) | E_vector_update (exp1,exp2,exp3) -> - norm_exp_to_name exp1 - (fun exp1 -> - norm_exp_to_name exp2 - (fun exp2 -> - norm_exp_to_name exp3 (fun exp3 -> k (rewrap (E_vector_update (exp1,exp2,exp3)))))) + norm_exp_to_name exp1 >>= fun exp1 -> + norm_exp_to_name exp2 >>= fun exp2 -> + norm_exp_to_name exp3 >>= fun exp3 -> + return (rewrap_localeff (E_vector_update (exp1,exp2,exp3))) | E_vector_update_subrange (exp1,exp2,exp3,exp4) -> - norm_exp_to_name exp1 - (fun exp1 -> - norm_exp_to_name exp2 - (fun exp2 -> - norm_exp_to_name exp3 - (fun exp3 -> - norm_exp_to_name exp4 - (fun exp4 -> k (rewrap (E_vector_update_subrange (exp1,exp2,exp3,exp4))))))) + norm_exp_to_name exp1 >>= fun exp1 -> + norm_exp_to_name exp2 >>= fun exp2 -> + norm_exp_to_name exp3 >>= fun exp3 -> + norm_exp_to_name exp4 >>= fun exp4 -> + return (rewrap_localeff (E_vector_update_subrange (exp1,exp2,exp3,exp4))) | E_vector_append (exp1,exp2) -> - norm_exp_to_name exp1 - (fun exp1 -> norm_exp_to_name exp2 (fun exp2 -> k (rewrap (E_vector_append (exp1,exp2))))) - | E_list exps -> norm_exp_to_nameL exps (fun exps -> k (rewrap (E_list exps))) - | E_cons (exp1,exp2) -> - norm_exp_to_name exp1 - (fun exp1 -> norm_exp_to_name exp2 (fun exp2 -> k (rewrap (E_cons (exp1,exp2))))) - | E_record fexps -> norm_fexps fexps (fun fexps -> k (rewrap (E_record fexps))) - | E_record_update (exp1,fexps) -> - norm_exp_to_name exp1 (fun exp1 -> norm_fexps fexps (fun fexps -> k (rewrap (E_record fexps)))) - | E_field (exp1,id) -> norm_exp_to_name exp1 (fun exp1 -> k (rewrap (E_field (exp1,id)))) - | E_case (exp1,pexps) -> - norm_exp_to_name exp1 (fun exp1 -> - norm_pexpL pexps - (fun pexps -> - let effsum = List.fold_left - (fun effs pat -> dedup (effs @ geteffslist_pexp pat)) [] pexps in - let effsum = {effect = Eset effsum} in - k (rewrap_effs effsum (E_case (exp1,pexps))))) - | E_let (lb,body) -> - norm_lb lb - (fun lb -> - match lb with - | LB_aux (LB_val_explicit (typ,pat,exp'),annot') -> - if effectful_effs (geteffs_annot annot') - then k (rewrap_effs (eff_union exp' body) (E_internal_plet (pat,exp',norm_exp_to_term body))) - else k (rewrap_effs (geteffs body) (E_let (lb,norm_exp_to_term body))) - | LB_aux (LB_val_implicit (pat,exp'),annot') -> - if effectful_effs (geteffs_annot annot') - then k (rewrap_effs (eff_union exp' body) (E_internal_plet (pat,exp',norm_exp_to_term body))) - else k (rewrap_effs (geteffs body) (E_let (lb,norm_exp_to_term body))) - ) - | E_assign (lexp,exp1) -> - norm_lexp lexp (fun lexp -> - norm_exp_to_name exp1 (fun exp1 -> k (rewrap (E_assign (lexp,exp1))))) - | E_exit exp' -> k (E_aux (E_exit (norm_exp_to_term exp'),annot)) - | E_internal_cast (annot',exp') -> - norm_exp_to_name exp' (fun exp' -> k (rewrap (E_internal_cast (annot',exp')))) - | E_internal_exp annot' -> k (rewrap (E_internal_exp annot')) - | E_internal_exp_user (annot1,annot2) -> k (rewrap (E_internal_exp_user (annot1,annot2))) - | E_internal_let (lexp,exp1,exp2) -> - (if effectful exp1 then norm_exp_to_name exp1 else norm_exp exp1) - (fun exp1 -> k (rewrap_effs (geteffs exp2) (E_internal_let (lexp,exp1,norm_exp_to_term exp2)))) - | E_internal_plet _ -> failwith "E_internal_plet should not be here yet" - - -let rec a_normalise exp = - let exp = fold_exp remove_blocks_exp_alg exp in - norm_exp_to_term exp - + norm_exp_to_name exp1 >>= fun exp1 -> + norm_exp_to_name exp2 >>= fun exp2 -> + return (rewrap_localeff (E_vector_append (exp1,exp2))) + | E_list exps -> + norm_exp_to_nameL exps >>= fun exps -> + return (rewrap_localeff (E_list exps)) + | E_cons (exp1,exp2) -> + norm_exp_to_name exp1 >>= fun exp1 -> + norm_exp_to_name exp2 >>= fun exp2 -> + return (rewrap_localeff (E_cons (exp1,exp2))) + | E_record fexps -> + norm_fexps fexps >>= fun fexps -> + return (rewrap_localeff (E_record fexps)) + | E_record_update (exp1,fexps) -> + norm_exp_to_name exp1 >>= fun exp1 -> + norm_fexps fexps >>= fun fexps -> + return (rewrap_localeff (E_record fexps)) + | E_field (exp1,id) -> + norm_exp_to_name exp1 >>= fun exp1 -> + return (rewrap_localeff (E_field (exp1,id))) + | E_case (exp1,pexps) -> + norm_exp_to_name exp1 >>= fun exp1 -> + norm_pexpL pexps >>= fun pexps -> + let geteffs (Pat_aux (_,(_,Base (_,_,_,_,{effect = Eset effs},_)))) = effs in + let effsum = {effect = Eset (dedup (List.flatten (List.map geteffs pexps)))} in + return (rewrap_effs effsum (E_case (exp1,pexps))) + | E_let (lb,body) -> + norm_lb lb >>= fun lb -> + (match lb with + | LB_aux (LB_val_explicit (_,pat,exp'),annot') + | LB_aux (LB_val_implicit (pat,exp'),annot') -> + let body = norm_exp_to_term body in + if effectful exp' + then return (rewrap_effs (eff_union exp' body) (E_internal_plet (pat,exp',body))) + else return (rewrap_effs (geteffs body) (E_let (lb,body))) + ) + | E_assign (lexp,exp1) -> + norm_lexp lexp >>= fun lexp -> + norm_exp_to_name exp1 >>= fun exp1 -> + return (rewrap_localeff (E_assign (lexp,exp1))) + | E_exit exp' -> return (E_aux (E_exit (norm_exp_to_term exp'),annot)) + | E_internal_cast (annot',exp') -> + norm_exp_to_name exp' >>= fun exp' -> + return (rewrap_localeff (E_internal_cast (annot',exp'))) + | E_internal_exp _ -> return exp + | E_internal_exp_user _ -> return exp + | E_internal_let (lexp,exp1,exp2) -> + (if effectful exp1 + then norm_exp_to_name exp1 + else norm_exp exp1) >>= fun exp1 -> + let exp2 = norm_exp_to_term exp2 in + return (rewrap_effs (geteffs exp2) (E_internal_let (lexp,exp1,exp2))) + | E_internal_plet _ -> failwith "E_internal_plet should not be here yet" + let rewrite_defs_a_normalise (Defs defs) = rewrite_defs_base - {rewrite_exp = (fun _ _ e -> a_normalise e); + {rewrite_exp = (fun _ _ e -> norm_exp_to_term (fold_exp remove_blocks_exp_alg e)); rewrite_pat = rewrite_pat; rewrite_let = rewrite_let; rewrite_lexp = rewrite_lexp; |
