summaryrefslogtreecommitdiff
path: root/src/rewriter.ml
diff options
context:
space:
mode:
authorThomas Bauereiss2018-12-18 15:16:36 +0000
committerThomas Bauereiss2018-12-18 15:16:36 +0000
commit1766bf5e3628b5c45290a3353bec05823661b9d3 (patch)
treecae2f596d135074399cd304bb8e3dca1330a2aa8 /src/rewriter.ml
parentdf0e02bc0c8259962f25d4c175fa950391695ab6 (diff)
parent07a332c856b3ee9fe26a9cd47ea6005f9d579810 (diff)
Merge branch 'sail2' into monads
Diffstat (limited to 'src/rewriter.ml')
-rw-r--r--src/rewriter.ml73
1 files changed, 35 insertions, 38 deletions
diff --git a/src/rewriter.ml b/src/rewriter.ml
index a7505ca7..a70f6fab 100644
--- a/src/rewriter.ml
+++ b/src/rewriter.ml
@@ -64,11 +64,10 @@ type 'a rewriters = {
rewrite_defs : 'a rewriters -> 'a defs -> 'a defs;
}
-
let effect_of_fpat (FP_aux (_,(_,a))) = effect_of_annot a
let effect_of_lexp (LEXP_aux (_,(_,a))) = effect_of_annot a
let effect_of_fexp (FE_aux (_,(_,a))) = effect_of_annot a
-let effect_of_fexps (FES_aux (FES_Fexps (fexps,_),_)) =
+let effect_of_fexps fexps =
List.fold_left union_effects no_effect (List.map effect_of_fexp fexps)
let effect_of_opt_default (Def_val_aux (_,(_,a))) = effect_of_annot a
(* The typechecker does not seem to annotate pexps themselves *)
@@ -95,7 +94,7 @@ let lookup_generated_kid env kid =
let generated_kids typ = KidSet.filter is_kid_generated (tyvars_of_typ typ)
let resolve_generated_kids env typ =
- let subst_kid kid typ = typ_subst_kid kid (lookup_generated_kid env kid) typ in
+ let subst_kid kid typ = subst_kid typ_subst kid (lookup_generated_kid env kid) typ in
KidSet.fold subst_kid (generated_kids typ) typ
let rec remove_p_typ = function
@@ -103,7 +102,7 @@ let rec remove_p_typ = function
| pat -> pat
let add_p_typ typ (P_aux (paux, annot) as pat) =
- let typ' = resolve_generated_kids (pat_env_of pat) typ in
+ let typ' = resolve_generated_kids (env_of_pat pat) typ in
if KidSet.is_empty (generated_kids typ') then
P_aux (P_typ (typ', remove_p_typ pat), annot)
else pat
@@ -295,16 +294,14 @@ let rewrite_exp rewriters (E_aux (exp,(l,annot)) as orig_exp) =
| E_vector_append (v1,v2) -> rewrap (E_vector_append (rewrite v1,rewrite v2))
| E_list exps -> rewrap (E_list (List.map rewrite exps))
| E_cons(h,t) -> rewrap (E_cons (rewrite h,rewrite t))
- | E_record (FES_aux (FES_Fexps(fexps, bool),fannot)) ->
+ | E_record fexps ->
rewrap (E_record
- (FES_aux (FES_Fexps
- (List.map (fun (FE_aux(FE_Fexp(id,e),fannot)) ->
- FE_aux(FE_Fexp(id,rewrite e),fannot)) fexps, bool), fannot)))
- | E_record_update (re,(FES_aux (FES_Fexps(fexps, bool),fannot))) ->
+ (List.map (fun (FE_aux(FE_Fexp(id,e),fannot)) ->
+ FE_aux(FE_Fexp(id,rewrite e),fannot)) fexps))
+ | E_record_update (re, fexps) ->
rewrap (E_record_update ((rewrite re),
- (FES_aux (FES_Fexps
- (List.map (fun (FE_aux(FE_Fexp(id,e),fannot)) ->
- FE_aux(FE_Fexp(id,rewrite e),fannot)) fexps, bool), fannot))))
+ (List.map (fun (FE_aux(FE_Fexp(id,e),fannot)) ->
+ FE_aux(FE_Fexp(id,rewrite e),fannot)) fexps)))
| E_field(exp,id) -> rewrap (E_field(rewrite exp,id))
| E_case (exp,pexps) ->
rewrap (E_case (rewrite exp, List.map (rewrite_pexp rewriters) pexps))
@@ -319,8 +316,8 @@ let rewrite_exp rewriters (E_aux (exp,(l,annot)) as orig_exp) =
| E_assert(e1,e2) -> rewrap (E_assert(rewrite e1,rewrite e2))
| E_var (lexp, e1, e2) ->
rewrap (E_var (rewriters.rewrite_lexp rewriters lexp, rewriters.rewrite_exp rewriters e1, rewriters.rewrite_exp rewriters e2))
- | E_internal_return _ -> raise (Reporting_basic.err_unreachable l __POS__ "Internal return found before it should have been introduced")
- | E_internal_plet _ -> raise (Reporting_basic.err_unreachable l __POS__ " Internal plet found before it should have been introduced")
+ | E_internal_return _ -> raise (Reporting.err_unreachable l __POS__ "Internal return found before it should have been introduced")
+ | E_internal_plet _ -> raise (Reporting.err_unreachable l __POS__ " Internal plet found before it should have been introduced")
| _ -> rewrap exp
let rewrite_let rewriters (LB_aux(letbind,(l,annot))) =
@@ -349,7 +346,14 @@ let rewrite_lexp rewriters (LEXP_aux(lexp,(l,annot))) =
let rewrite_fun rewriters (FD_aux (FD_function(recopt,tannotopt,effectopt,funcls),(l,fdannot))) =
let rewrite_funcl (FCL_aux (FCL_Funcl(id,pexp),(l,annot))) =
(FCL_aux (FCL_Funcl (id, rewrite_pexp rewriters pexp),(l,annot)))
- in FD_aux (FD_function(recopt,tannotopt,effectopt,List.map rewrite_funcl funcls),(l,fdannot))
+ in
+ let recopt = match recopt with
+ | Rec_aux (Rec_nonrec, l) -> Rec_aux (Rec_nonrec, l)
+ | Rec_aux (Rec_rec, l) -> Rec_aux (Rec_rec, l)
+ | Rec_aux (Rec_measure (pat,exp),l) ->
+ Rec_aux (Rec_measure (rewrite_pat rewriters pat, rewrite_exp rewriters exp),l)
+ in
+ FD_aux (FD_function(recopt,tannotopt,effectopt,List.map rewrite_funcl funcls),(l,fdannot))
let rewrite_def rewriters d = match d with
| DEF_reg_dec (DEC_aux (DEC_config (id, typ, exp), annot)) ->
@@ -358,7 +362,8 @@ let rewrite_def rewriters d = match d with
| DEF_fundef fdef -> DEF_fundef (rewriters.rewrite_fun rewriters fdef)
| DEF_internal_mutrec fdefs -> DEF_internal_mutrec (List.map (rewriters.rewrite_fun rewriters) fdefs)
| DEF_val letbind -> DEF_val (rewriters.rewrite_let rewriters letbind)
- | DEF_scattered _ -> raise (Reporting_basic.err_unreachable Parse_ast.Unknown __POS__ "DEF_scattered survived to rewritter")
+ | DEF_pragma (pragma, arg, l) -> DEF_pragma (pragma, arg, l)
+ | DEF_scattered _ -> raise (Reporting.err_unreachable Parse_ast.Unknown __POS__ "DEF_scattered survived to rewritter")
let rewrite_defs_base rewriters (Defs defs) =
let rec rewrite ds = match ds with
@@ -474,9 +479,9 @@ let id_pat_alg : ('a,'a pat, 'a pat_aux, 'a fpat, 'a fpat_aux) pat_alg =
; fP_Fpat = (fun (id,pat) -> FP_Fpat (id,pat))
}
-type ('a,'exp,'exp_aux,'lexp,'lexp_aux,'fexp,'fexp_aux,'fexps,'fexps_aux,
+type ('a,'exp,'exp_aux,'lexp,'lexp_aux,'fexp,'fexp_aux,
'opt_default_aux,'opt_default,'pexp,'pexp_aux,'letbind_aux,'letbind,
- 'pat,'pat_aux,'fpat,'fpat_aux) exp_alg =
+ 'pat,'pat_aux,'fpat,'fpat_aux) exp_alg =
{ e_block : 'exp list -> 'exp_aux
; e_nondet : 'exp list -> 'exp_aux
; e_id : id -> 'exp_aux
@@ -497,8 +502,8 @@ type ('a,'exp,'exp_aux,'lexp,'lexp_aux,'fexp,'fexp_aux,'fexps,'fexps_aux,
; e_vector_append : 'exp * 'exp -> 'exp_aux
; e_list : 'exp list -> 'exp_aux
; e_cons : 'exp * 'exp -> 'exp_aux
- ; e_record : 'fexps -> 'exp_aux
- ; e_record_update : 'exp * 'fexps -> 'exp_aux
+ ; e_record : 'fexp list -> 'exp_aux
+ ; e_record_update : 'exp * 'fexp list -> 'exp_aux
; e_field : 'exp * id -> 'exp_aux
; e_case : 'exp * 'pexp list -> 'exp_aux
; e_try : 'exp * 'pexp list -> 'exp_aux
@@ -527,8 +532,6 @@ type ('a,'exp,'exp_aux,'lexp,'lexp_aux,'fexp,'fexp_aux,'fexps,'fexps_aux,
; lEXP_aux : 'lexp_aux * 'a annot -> 'lexp
; fE_Fexp : id * 'exp -> 'fexp_aux
; fE_aux : 'fexp_aux * 'a annot -> 'fexp
- ; fES_Fexps : 'fexp list * bool -> 'fexps_aux
- ; fES_aux : 'fexps_aux * 'a annot -> 'fexps
; def_val_empty : 'opt_default_aux
; def_val_dec : 'exp -> 'opt_default_aux
; def_val_aux : 'opt_default_aux * 'a annot -> 'opt_default
@@ -566,8 +569,8 @@ let rec fold_exp_aux alg = function
| E_vector_append (e1,e2) -> alg.e_vector_append (fold_exp alg e1, fold_exp alg e2)
| E_list es -> alg.e_list (List.map (fold_exp alg) es)
| E_cons (e1,e2) -> alg.e_cons (fold_exp alg e1, fold_exp alg e2)
- | E_record fexps -> alg.e_record (fold_fexps alg fexps)
- | E_record_update (e,fexps) -> alg.e_record_update (fold_exp alg e, fold_fexps alg fexps)
+ | E_record fexps -> alg.e_record (List.map (fold_fexp alg) fexps)
+ | E_record_update (e,fexps) -> alg.e_record_update (fold_exp alg e, List.map (fold_fexp alg) fexps)
| E_field (e,id) -> alg.e_field (fold_exp alg e, id)
| E_case (e,pexps) -> alg.e_case (fold_exp alg e, List.map (fold_pexp alg) pexps)
| E_try (e,pexps) -> alg.e_try (fold_exp alg e, List.map (fold_pexp alg) pexps)
@@ -601,8 +604,6 @@ and fold_lexp alg (LEXP_aux (lexp_aux,annot)) =
alg.lEXP_aux (fold_lexp_aux alg lexp_aux, annot)
and fold_fexp_aux alg (FE_Fexp (id,e)) = alg.fE_Fexp (id, fold_exp alg e)
and fold_fexp alg (FE_aux (fexp_aux,annot)) = alg.fE_aux (fold_fexp_aux alg fexp_aux,annot)
-and fold_fexps_aux alg (FES_Fexps (fexps,b)) = alg.fES_Fexps (List.map (fold_fexp alg) fexps, b)
-and fold_fexps alg (FES_aux (fexps_aux,annot)) = alg.fES_aux (fold_fexps_aux alg fexps_aux, annot)
and fold_opt_default_aux alg = function
| Def_val_empty -> alg.def_val_empty
| Def_val_dec e -> alg.def_val_dec (fold_exp alg e)
@@ -673,8 +674,6 @@ let id_exp_alg =
; lEXP_aux = (fun (lexp,annot) -> LEXP_aux (lexp,annot))
; fE_Fexp = (fun (id,e) -> FE_Fexp (id,e))
; fE_aux = (fun (fexp,annot) -> FE_aux (fexp,annot))
- ; fES_Fexps = (fun (fexps,b) -> FES_Fexps (fexps,b))
- ; fES_aux = (fun (fexp,annot) -> FES_aux (fexp,annot))
; def_val_empty = Def_val_empty
; def_val_dec = (fun e -> Def_val_dec e)
; def_val_aux = (fun (defval,aux) -> Def_val_aux (defval,aux))
@@ -741,8 +740,12 @@ let compute_exp_alg bot join =
; e_vector_append = (fun ((v1,e1),(v2,e2)) -> (join v1 v2, E_vector_append (e1,e2)))
; e_list = split_join (fun es -> E_list es)
; e_cons = (fun ((v1,e1),(v2,e2)) -> (join v1 v2, E_cons (e1,e2)))
- ; e_record = (fun (vs,fexps) -> (vs, E_record fexps))
- ; e_record_update = (fun ((v1,e1),(vf,fexp)) -> (join v1 vf, E_record_update (e1,fexp)))
+ ; e_record = (fun fexps ->
+ let vs, fexps = List.split fexps in
+ (join_list vs, E_record fexps))
+ ; e_record_update = (fun ((v1,e1),fexps) ->
+ let (vps,fexps) = List.split fexps in
+ (join_list (v1::vps), E_record_update (e1,fexps)))
; e_field = (fun ((v1,e1),id) -> (v1, E_field (e1,id)))
; e_case = (fun ((v1,e1),pexps) ->
let (vps,pexps) = List.split pexps in
@@ -782,10 +785,6 @@ let compute_exp_alg bot join =
; lEXP_aux = (fun ((vl,lexp),annot) -> (vl, LEXP_aux (lexp,annot)))
; fE_Fexp = (fun (id,(v,e)) -> (v, FE_Fexp (id,e)))
; fE_aux = (fun ((vf,fexp),annot) -> (vf, FE_aux (fexp,annot)))
- ; fES_Fexps = (fun (fexps,b) ->
- let (vs,fexps) = List.split fexps in
- (join_list vs, FES_Fexps (fexps,b)))
- ; fES_aux = (fun ((vf,fexp),annot) -> (vf, FES_aux (fexp,annot)))
; def_val_empty = (bot, Def_val_empty)
; def_val_dec = (fun (v,e) -> (v, Def_val_dec e))
; def_val_aux = (fun ((v,defval),aux) -> (v, Def_val_aux (defval,aux)))
@@ -842,8 +841,8 @@ let pure_exp_alg bot join =
; e_vector_append = (fun (v1,v2) -> join v1 v2)
; e_list = join_list
; e_cons = (fun (v1,v2) -> join v1 v2)
- ; e_record = (fun vs -> vs)
- ; e_record_update = (fun (v1,vf) -> join v1 vf)
+ ; e_record = (fun vs -> join_list vs)
+ ; e_record_update = (fun (v1,vf) -> join_list (v1::vf))
; e_field = (fun (v1,id) -> v1)
; e_case = (fun (v1,vps) -> join_list (v1::vps))
; e_try = (fun (v1,vps) -> join_list (v1::vps))
@@ -872,8 +871,6 @@ let pure_exp_alg bot join =
; lEXP_aux = (fun (vl,annot) -> vl)
; fE_Fexp = (fun (id,v) -> v)
; fE_aux = (fun (vf,annot) -> vf)
- ; fES_Fexps = (fun (vs,b) -> join_list vs)
- ; fES_aux = (fun (vf,annot) -> vf)
; def_val_empty = bot
; def_val_dec = (fun v -> v)
; def_val_aux = (fun (v,aux) -> v)