diff options
| author | Thomas Bauereiss | 2018-12-18 15:16:36 +0000 |
|---|---|---|
| committer | Thomas Bauereiss | 2018-12-18 15:16:36 +0000 |
| commit | 1766bf5e3628b5c45290a3353bec05823661b9d3 (patch) | |
| tree | cae2f596d135074399cd304bb8e3dca1330a2aa8 /src/rewriter.ml | |
| parent | df0e02bc0c8259962f25d4c175fa950391695ab6 (diff) | |
| parent | 07a332c856b3ee9fe26a9cd47ea6005f9d579810 (diff) | |
Merge branch 'sail2' into monads
Diffstat (limited to 'src/rewriter.ml')
| -rw-r--r-- | src/rewriter.ml | 73 |
1 files changed, 35 insertions, 38 deletions
diff --git a/src/rewriter.ml b/src/rewriter.ml index a7505ca7..a70f6fab 100644 --- a/src/rewriter.ml +++ b/src/rewriter.ml @@ -64,11 +64,10 @@ type 'a rewriters = { rewrite_defs : 'a rewriters -> 'a defs -> 'a defs; } - let effect_of_fpat (FP_aux (_,(_,a))) = effect_of_annot a let effect_of_lexp (LEXP_aux (_,(_,a))) = effect_of_annot a let effect_of_fexp (FE_aux (_,(_,a))) = effect_of_annot a -let effect_of_fexps (FES_aux (FES_Fexps (fexps,_),_)) = +let effect_of_fexps fexps = List.fold_left union_effects no_effect (List.map effect_of_fexp fexps) let effect_of_opt_default (Def_val_aux (_,(_,a))) = effect_of_annot a (* The typechecker does not seem to annotate pexps themselves *) @@ -95,7 +94,7 @@ let lookup_generated_kid env kid = let generated_kids typ = KidSet.filter is_kid_generated (tyvars_of_typ typ) let resolve_generated_kids env typ = - let subst_kid kid typ = typ_subst_kid kid (lookup_generated_kid env kid) typ in + let subst_kid kid typ = subst_kid typ_subst kid (lookup_generated_kid env kid) typ in KidSet.fold subst_kid (generated_kids typ) typ let rec remove_p_typ = function @@ -103,7 +102,7 @@ let rec remove_p_typ = function | pat -> pat let add_p_typ typ (P_aux (paux, annot) as pat) = - let typ' = resolve_generated_kids (pat_env_of pat) typ in + let typ' = resolve_generated_kids (env_of_pat pat) typ in if KidSet.is_empty (generated_kids typ') then P_aux (P_typ (typ', remove_p_typ pat), annot) else pat @@ -295,16 +294,14 @@ let rewrite_exp rewriters (E_aux (exp,(l,annot)) as orig_exp) = | E_vector_append (v1,v2) -> rewrap (E_vector_append (rewrite v1,rewrite v2)) | E_list exps -> rewrap (E_list (List.map rewrite exps)) | E_cons(h,t) -> rewrap (E_cons (rewrite h,rewrite t)) - | E_record (FES_aux (FES_Fexps(fexps, bool),fannot)) -> + | E_record fexps -> rewrap (E_record - (FES_aux (FES_Fexps - (List.map (fun (FE_aux(FE_Fexp(id,e),fannot)) -> - FE_aux(FE_Fexp(id,rewrite e),fannot)) fexps, bool), fannot))) - | E_record_update (re,(FES_aux (FES_Fexps(fexps, bool),fannot))) -> + (List.map (fun (FE_aux(FE_Fexp(id,e),fannot)) -> + FE_aux(FE_Fexp(id,rewrite e),fannot)) fexps)) + | E_record_update (re, fexps) -> rewrap (E_record_update ((rewrite re), - (FES_aux (FES_Fexps - (List.map (fun (FE_aux(FE_Fexp(id,e),fannot)) -> - FE_aux(FE_Fexp(id,rewrite e),fannot)) fexps, bool), fannot)))) + (List.map (fun (FE_aux(FE_Fexp(id,e),fannot)) -> + FE_aux(FE_Fexp(id,rewrite e),fannot)) fexps))) | E_field(exp,id) -> rewrap (E_field(rewrite exp,id)) | E_case (exp,pexps) -> rewrap (E_case (rewrite exp, List.map (rewrite_pexp rewriters) pexps)) @@ -319,8 +316,8 @@ let rewrite_exp rewriters (E_aux (exp,(l,annot)) as orig_exp) = | E_assert(e1,e2) -> rewrap (E_assert(rewrite e1,rewrite e2)) | E_var (lexp, e1, e2) -> rewrap (E_var (rewriters.rewrite_lexp rewriters lexp, rewriters.rewrite_exp rewriters e1, rewriters.rewrite_exp rewriters e2)) - | E_internal_return _ -> raise (Reporting_basic.err_unreachable l __POS__ "Internal return found before it should have been introduced") - | E_internal_plet _ -> raise (Reporting_basic.err_unreachable l __POS__ " Internal plet found before it should have been introduced") + | E_internal_return _ -> raise (Reporting.err_unreachable l __POS__ "Internal return found before it should have been introduced") + | E_internal_plet _ -> raise (Reporting.err_unreachable l __POS__ " Internal plet found before it should have been introduced") | _ -> rewrap exp let rewrite_let rewriters (LB_aux(letbind,(l,annot))) = @@ -349,7 +346,14 @@ let rewrite_lexp rewriters (LEXP_aux(lexp,(l,annot))) = let rewrite_fun rewriters (FD_aux (FD_function(recopt,tannotopt,effectopt,funcls),(l,fdannot))) = let rewrite_funcl (FCL_aux (FCL_Funcl(id,pexp),(l,annot))) = (FCL_aux (FCL_Funcl (id, rewrite_pexp rewriters pexp),(l,annot))) - in FD_aux (FD_function(recopt,tannotopt,effectopt,List.map rewrite_funcl funcls),(l,fdannot)) + in + let recopt = match recopt with + | Rec_aux (Rec_nonrec, l) -> Rec_aux (Rec_nonrec, l) + | Rec_aux (Rec_rec, l) -> Rec_aux (Rec_rec, l) + | Rec_aux (Rec_measure (pat,exp),l) -> + Rec_aux (Rec_measure (rewrite_pat rewriters pat, rewrite_exp rewriters exp),l) + in + FD_aux (FD_function(recopt,tannotopt,effectopt,List.map rewrite_funcl funcls),(l,fdannot)) let rewrite_def rewriters d = match d with | DEF_reg_dec (DEC_aux (DEC_config (id, typ, exp), annot)) -> @@ -358,7 +362,8 @@ let rewrite_def rewriters d = match d with | DEF_fundef fdef -> DEF_fundef (rewriters.rewrite_fun rewriters fdef) | DEF_internal_mutrec fdefs -> DEF_internal_mutrec (List.map (rewriters.rewrite_fun rewriters) fdefs) | DEF_val letbind -> DEF_val (rewriters.rewrite_let rewriters letbind) - | DEF_scattered _ -> raise (Reporting_basic.err_unreachable Parse_ast.Unknown __POS__ "DEF_scattered survived to rewritter") + | DEF_pragma (pragma, arg, l) -> DEF_pragma (pragma, arg, l) + | DEF_scattered _ -> raise (Reporting.err_unreachable Parse_ast.Unknown __POS__ "DEF_scattered survived to rewritter") let rewrite_defs_base rewriters (Defs defs) = let rec rewrite ds = match ds with @@ -474,9 +479,9 @@ let id_pat_alg : ('a,'a pat, 'a pat_aux, 'a fpat, 'a fpat_aux) pat_alg = ; fP_Fpat = (fun (id,pat) -> FP_Fpat (id,pat)) } -type ('a,'exp,'exp_aux,'lexp,'lexp_aux,'fexp,'fexp_aux,'fexps,'fexps_aux, +type ('a,'exp,'exp_aux,'lexp,'lexp_aux,'fexp,'fexp_aux, 'opt_default_aux,'opt_default,'pexp,'pexp_aux,'letbind_aux,'letbind, - 'pat,'pat_aux,'fpat,'fpat_aux) exp_alg = + 'pat,'pat_aux,'fpat,'fpat_aux) exp_alg = { e_block : 'exp list -> 'exp_aux ; e_nondet : 'exp list -> 'exp_aux ; e_id : id -> 'exp_aux @@ -497,8 +502,8 @@ type ('a,'exp,'exp_aux,'lexp,'lexp_aux,'fexp,'fexp_aux,'fexps,'fexps_aux, ; e_vector_append : 'exp * 'exp -> 'exp_aux ; e_list : 'exp list -> 'exp_aux ; e_cons : 'exp * 'exp -> 'exp_aux - ; e_record : 'fexps -> 'exp_aux - ; e_record_update : 'exp * 'fexps -> 'exp_aux + ; e_record : 'fexp list -> 'exp_aux + ; e_record_update : 'exp * 'fexp list -> 'exp_aux ; e_field : 'exp * id -> 'exp_aux ; e_case : 'exp * 'pexp list -> 'exp_aux ; e_try : 'exp * 'pexp list -> 'exp_aux @@ -527,8 +532,6 @@ type ('a,'exp,'exp_aux,'lexp,'lexp_aux,'fexp,'fexp_aux,'fexps,'fexps_aux, ; lEXP_aux : 'lexp_aux * 'a annot -> 'lexp ; fE_Fexp : id * 'exp -> 'fexp_aux ; fE_aux : 'fexp_aux * 'a annot -> 'fexp - ; fES_Fexps : 'fexp list * bool -> 'fexps_aux - ; fES_aux : 'fexps_aux * 'a annot -> 'fexps ; def_val_empty : 'opt_default_aux ; def_val_dec : 'exp -> 'opt_default_aux ; def_val_aux : 'opt_default_aux * 'a annot -> 'opt_default @@ -566,8 +569,8 @@ let rec fold_exp_aux alg = function | E_vector_append (e1,e2) -> alg.e_vector_append (fold_exp alg e1, fold_exp alg e2) | E_list es -> alg.e_list (List.map (fold_exp alg) es) | E_cons (e1,e2) -> alg.e_cons (fold_exp alg e1, fold_exp alg e2) - | E_record fexps -> alg.e_record (fold_fexps alg fexps) - | E_record_update (e,fexps) -> alg.e_record_update (fold_exp alg e, fold_fexps alg fexps) + | E_record fexps -> alg.e_record (List.map (fold_fexp alg) fexps) + | E_record_update (e,fexps) -> alg.e_record_update (fold_exp alg e, List.map (fold_fexp alg) fexps) | E_field (e,id) -> alg.e_field (fold_exp alg e, id) | E_case (e,pexps) -> alg.e_case (fold_exp alg e, List.map (fold_pexp alg) pexps) | E_try (e,pexps) -> alg.e_try (fold_exp alg e, List.map (fold_pexp alg) pexps) @@ -601,8 +604,6 @@ and fold_lexp alg (LEXP_aux (lexp_aux,annot)) = alg.lEXP_aux (fold_lexp_aux alg lexp_aux, annot) and fold_fexp_aux alg (FE_Fexp (id,e)) = alg.fE_Fexp (id, fold_exp alg e) and fold_fexp alg (FE_aux (fexp_aux,annot)) = alg.fE_aux (fold_fexp_aux alg fexp_aux,annot) -and fold_fexps_aux alg (FES_Fexps (fexps,b)) = alg.fES_Fexps (List.map (fold_fexp alg) fexps, b) -and fold_fexps alg (FES_aux (fexps_aux,annot)) = alg.fES_aux (fold_fexps_aux alg fexps_aux, annot) and fold_opt_default_aux alg = function | Def_val_empty -> alg.def_val_empty | Def_val_dec e -> alg.def_val_dec (fold_exp alg e) @@ -673,8 +674,6 @@ let id_exp_alg = ; lEXP_aux = (fun (lexp,annot) -> LEXP_aux (lexp,annot)) ; fE_Fexp = (fun (id,e) -> FE_Fexp (id,e)) ; fE_aux = (fun (fexp,annot) -> FE_aux (fexp,annot)) - ; fES_Fexps = (fun (fexps,b) -> FES_Fexps (fexps,b)) - ; fES_aux = (fun (fexp,annot) -> FES_aux (fexp,annot)) ; def_val_empty = Def_val_empty ; def_val_dec = (fun e -> Def_val_dec e) ; def_val_aux = (fun (defval,aux) -> Def_val_aux (defval,aux)) @@ -741,8 +740,12 @@ let compute_exp_alg bot join = ; e_vector_append = (fun ((v1,e1),(v2,e2)) -> (join v1 v2, E_vector_append (e1,e2))) ; e_list = split_join (fun es -> E_list es) ; e_cons = (fun ((v1,e1),(v2,e2)) -> (join v1 v2, E_cons (e1,e2))) - ; e_record = (fun (vs,fexps) -> (vs, E_record fexps)) - ; e_record_update = (fun ((v1,e1),(vf,fexp)) -> (join v1 vf, E_record_update (e1,fexp))) + ; e_record = (fun fexps -> + let vs, fexps = List.split fexps in + (join_list vs, E_record fexps)) + ; e_record_update = (fun ((v1,e1),fexps) -> + let (vps,fexps) = List.split fexps in + (join_list (v1::vps), E_record_update (e1,fexps))) ; e_field = (fun ((v1,e1),id) -> (v1, E_field (e1,id))) ; e_case = (fun ((v1,e1),pexps) -> let (vps,pexps) = List.split pexps in @@ -782,10 +785,6 @@ let compute_exp_alg bot join = ; lEXP_aux = (fun ((vl,lexp),annot) -> (vl, LEXP_aux (lexp,annot))) ; fE_Fexp = (fun (id,(v,e)) -> (v, FE_Fexp (id,e))) ; fE_aux = (fun ((vf,fexp),annot) -> (vf, FE_aux (fexp,annot))) - ; fES_Fexps = (fun (fexps,b) -> - let (vs,fexps) = List.split fexps in - (join_list vs, FES_Fexps (fexps,b))) - ; fES_aux = (fun ((vf,fexp),annot) -> (vf, FES_aux (fexp,annot))) ; def_val_empty = (bot, Def_val_empty) ; def_val_dec = (fun (v,e) -> (v, Def_val_dec e)) ; def_val_aux = (fun ((v,defval),aux) -> (v, Def_val_aux (defval,aux))) @@ -842,8 +841,8 @@ let pure_exp_alg bot join = ; e_vector_append = (fun (v1,v2) -> join v1 v2) ; e_list = join_list ; e_cons = (fun (v1,v2) -> join v1 v2) - ; e_record = (fun vs -> vs) - ; e_record_update = (fun (v1,vf) -> join v1 vf) + ; e_record = (fun vs -> join_list vs) + ; e_record_update = (fun (v1,vf) -> join_list (v1::vf)) ; e_field = (fun (v1,id) -> v1) ; e_case = (fun (v1,vps) -> join_list (v1::vps)) ; e_try = (fun (v1,vps) -> join_list (v1::vps)) @@ -872,8 +871,6 @@ let pure_exp_alg bot join = ; lEXP_aux = (fun (vl,annot) -> vl) ; fE_Fexp = (fun (id,v) -> v) ; fE_aux = (fun (vf,annot) -> vf) - ; fES_Fexps = (fun (vs,b) -> join_list vs) - ; fES_aux = (fun (vf,annot) -> vf) ; def_val_empty = bot ; def_val_dec = (fun v -> v) ; def_val_aux = (fun (v,aux) -> v) |
