diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/monomorphise.ml | 421 |
1 files changed, 395 insertions, 26 deletions
diff --git a/src/monomorphise.ml b/src/monomorphise.ml index 14ea30ba..dd6e32a7 100644 --- a/src/monomorphise.ml +++ b/src/monomorphise.ml @@ -2,6 +2,8 @@ open Parse_ast open Ast open Type_internal +let disable_const_propagation = ref false + (* TODO: put this somewhere common *) let id_to_string (Id_aux(id,l)) = @@ -16,7 +18,67 @@ let optmap v f = | None -> None | Some v -> Some (f v) -let disable_const_propagation = ref false +(* We need to be able to substitute in source types to monomorphise + datatype constructors *) + +let rec nexp_to_snexp ne = + let mkne n = Nexp_aux (n,Unknown) in + match ne.nexp with + | Nvar s -> mkne (Nexp_var (Kid_aux (Var s,Unknown))) + | Nid (s,ne) -> mkne (Nexp_id (Id_aux (Id s,Unknown))) + | Nconst i -> mkne (Nexp_constant (Big_int.int_of_big_int i)) + | Npos_inf + | Nneg_inf -> + raise (Reporting_basic.err_general Unknown ("Can't translate inf nexps back into source nexps")) + | Nadd (n1,n2) -> mkne (Nexp_sum (nexp_to_snexp n1, nexp_to_snexp n2)) + | Nsub (n1,n2) -> mkne (Nexp_minus (nexp_to_snexp n1, nexp_to_snexp n2)) + | Nmult (n1,n2) -> mkne (Nexp_times (nexp_to_snexp n1, nexp_to_snexp n2)) + | N2n (ne,_) -> mkne (Nexp_exp (nexp_to_snexp ne)) + | Npow _ -> + raise (Reporting_basic.err_general Unknown ("Can't translate pow nexps back into source nexps")) + | Nneg ne -> mkne (Nexp_neg (nexp_to_snexp ne)) + | Ninexact -> + raise (Reporting_basic.err_general Unknown ("Can't translate inexact nexps back into source nexps")) + | Nuvar _ -> + raise (Reporting_basic.err_general Unknown ("Can't translate uvar nexps back into source nexps")) + + +let subst_src_typ substs t = + let rec s_snexp (Nexp_aux (ne,l) as nexp) = + let re ne = Nexp_aux (ne,l) in + match ne with + | Nexp_var (Kid_aux (Var i,l)) -> + (match Envmap.apply substs i with + | Some (TA_nexp ne) -> nexp_to_snexp ne + | _ -> re (Nexp_var (Kid_aux(Var i,Generated l)))) + | Nexp_id _ + | Nexp_constant _ -> nexp + | Nexp_times (n1,n2) -> re (Nexp_times (s_snexp n1, s_snexp n2)) + | Nexp_sum (n1,n2) -> re (Nexp_sum (s_snexp n1, s_snexp n2)) + | Nexp_minus (n1,n2) -> re (Nexp_minus (s_snexp n1, s_snexp n2)) + | Nexp_exp ne -> re (Nexp_exp (s_snexp ne)) + | Nexp_neg ne -> re (Nexp_neg (s_snexp ne)) + in + let rec s_styp ((Typ_aux (t,l)) as ty) = + let re t = Typ_aux (t,l) in + match t with + | Typ_wild + | Typ_id _ + | Typ_var _ + -> ty + | Typ_fn (t1,t2,e) -> re (Typ_fn (s_styp t1, s_styp t2,e)) + | Typ_tup ts -> re (Typ_tup (List.map s_styp ts)) + | Typ_app (id,tas) -> re (Typ_app (id,List.map s_starg tas)) + and s_starg (Typ_arg_aux (ta,l) as targ) = + match ta with + | Typ_arg_nexp ne -> Typ_arg_aux (Typ_arg_nexp (s_snexp ne),l) + | Typ_arg_typ t -> Typ_arg_aux (Typ_arg_typ (s_styp t),l) + | Typ_arg_order _ + | Typ_arg_effect _ -> targ + in s_styp t + + + (* Based on current type checker's behaviour *) let pat_id_is_variable t_env id = @@ -25,9 +87,165 @@ let pat_id_is_variable t_env id = | Some (Base(_,Enum _,_,_,_,_)) -> false | _ -> true - -let nexp_subst substs exp = +let rec is_value t_env (E_aux (e,_)) = + match e with + | E_id id -> not (pat_id_is_variable t_env (id_to_string id)) + | E_lit _ -> true + | E_tuple es -> List.for_all (is_value t_env) es +(* TODO: more? *) + | _ -> false + +let is_pure (Effect_opt_aux (e,_)) = + match e with + | Effect_opt_pure -> true + | Effect_opt_effect (Effect_aux (Effect_set [],_)) -> true + | _ -> false + +let rec list_extract f = function + | [] -> None + | h::t -> match f h with None -> list_extract f t | Some v -> Some v + +let rec cross = function + | [] -> failwith "cross" + | [(x,l)] -> List.map (fun y -> [(x,y)]) l + | (x,l)::t -> + let t' = cross t in + List.concat (List.map (fun y -> List.map (fun l' -> (x,y)::l') t') l) + +(* Given a type for a constructor, work out which refinements we ought to produce *) +(* TODO collision avoidance *) +let split_src_type i ty (TypQ_aux (q,ql)) = + let rec size_nvars_nexp (Nexp_aux (ne,_)) = + match ne with + | Nexp_var (Kid_aux (Var v,_)) -> [v] + | Nexp_id _ + | Nexp_constant _ + -> [] + | Nexp_times (n1,n2) + | Nexp_sum (n1,n2) + | Nexp_minus (n1,n2) + -> size_nvars_nexp n1 @ size_nvars_nexp n2 + | Nexp_exp n + | Nexp_neg n + -> size_nvars_nexp n + in + let rec size_nvars_ty (Typ_aux (ty,l)) = + match ty with + | Typ_wild + | Typ_id _ + | Typ_var _ + -> [] + | Typ_fn _ -> + raise (Reporting_basic.err_general l ("Function type in constructor " ^ i)) + | Typ_tup ts -> List.concat (List.map size_nvars_ty ts) + | Typ_app (Id_aux (Id "vector",_), + [_;Typ_arg_aux (Typ_arg_nexp sz,_); + _;Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)]) -> + size_nvars_nexp sz + | Typ_app (_, tas) -> + [] (* We only support sizes for bitvectors mentioned explicitly, not any buried + inside another type *) + in + let nvars = List.sort_uniq (String.compare) (size_nvars_ty ty) in + match nvars with + | [] -> None + | sample::__ -> + (* Only check for constraints if we found a size to constrain *) + let qs = + match q with + | TypQ_no_forall -> + raise (Reporting_basic.err_general ql + ("No set constraint for variable " ^ sample ^ " in constructor " ^ i)) + | TypQ_tq qs -> qs + in + let find_set nvar = + match list_extract (function + | QI_aux (QI_const (NC_aux (NC_nat_set_bounded (Kid_aux (Var nvar',_),vals),_)),_) + -> if nvar = nvar' then Some vals else None + | _ -> None) qs with + | None -> + raise (Reporting_basic.err_general ql + ("No set constraint for variable " ^ nvar ^ " in constructor " ^ i)) + | Some vals -> (nvar,vals) + in + let nvar_sets = List.map find_set nvars in + let total_variants = List.fold_left ( * ) 1 (List.map (fun (_,l) -> List.length l) nvar_sets) in + let limit = 8 in + let () = if total_variants > limit then + raise (Reporting_basic.err_general ql + (string_of_int total_variants ^ "variants for constructor " ^ i ^ + "bigger than limit " ^ string_of_int limit)) else () + in + let variants = cross nvar_sets in + let name l = String.concat "_" (i::(List.map (fun (v,i) -> v ^ string_of_int i) l)) in + Some (List.map (fun l -> (l, name l)) variants) + +(* TODO: maybe fold this into subst_src_typ? *) +let inst_src_type insts ty = + let insts = List.map (fun (v,i) -> (v,TA_nexp {nexp=Nconst (Big_int.big_int_of_int i); imp_param=false})) insts in + let subst = Envmap.from_list insts in + subst_src_typ subst ty + +let reduce_nexp subst ne = + let ne = subst_n_with_env subst ne in + let rec eval ne = + match ne.nexp with + | Nconst i -> i + | Nadd (n1,n2) -> Big_int.add_big_int (eval n1) (eval n2) + | Nsub (n1,n2) -> Big_int.sub_big_int (eval n1) (eval n2) + | Nmult (n1,n2) -> Big_int.mult_big_int (eval n1) (eval n2) + | N2n (n,_) -> Big_int.power_int_positive_big_int 2 (eval n) + | Npow (n,i) -> Big_int.power_big_int_positive_int (eval n) i + | Nneg n -> Big_int.minus_big_int (eval n) + | _ -> + raise (Reporting_basic.err_general Unknown ("Couldn't turn nexp " ^ + n_to_string ne ^ " into concrete value")) + in Big_int.int_of_big_int (eval ne) + +(* Check to see if we need to monomorphise a use of a constructor. Currently + assumes that bitvector sizes are always given as a variable; don't yet handle + more general cases (e.g., 8 * var) *) + +let refine_constructor refinements i substs (E_aux (_,(l,_)) as arg) t = + let rec derive_vars t (E_aux (e,(l,tannot)) as exp) = + match t.t with + | Tapp ("vector", [_;TA_nexp {nexp = Nvar v};_;TA_typ {t=Tid "bit"}]) -> + (match tannot with + | Base ((_,{t=Tapp ("vector", [_;TA_nexp ne;_;TA_typ {t=Tid "bit"}])}),_,_,_,_,_) -> + [(v,reduce_nexp substs ne)] + | _ -> []) + | Tvar _ + | Tid _ + | Tfn _ + | Tapp _ + | Tuvar _ + -> [] + | Tabbrev (_,t) + | Toptions (t,_) + -> derive_vars t exp + | Ttup ts -> + match e with + | E_tuple es -> List.concat (List.map2 derive_vars ts es) + | _ -> [] (* TODO? *) + in + try + let irefinements = List.assoc i refinements in + let vars = List.sort_uniq (fun x y -> String.compare (fst x) (fst y)) (derive_vars t arg) in + try + Some (List.assoc vars irefinements) + with Not_found -> + (Reporting_basic.print_err false true l "Monomorphisation" + ("Failed to find a monomorphic constructor for " ^ i ^ " instance " ^ + match vars with [] -> "<empty>" + | _ -> String.concat "," (List.map (fun (x,y) -> x ^ "=" ^ string_of_int y) vars)); None) + with Not_found -> None + + +(* Substitute found nexps for variables in an expression, and rename constructors to reflect + specialisation *) + +let nexp_subst_fns t_env substs refinements = let s_t t = typ_subst substs true t in (* let s_typschm (TypSchm_aux (TypSchm_ts (q,t),l)) = TypSchm_aux (TypSchm_ts (q,s_t t),l) in hopefully don't need this anyway *) @@ -68,7 +286,23 @@ let nexp_subst substs exp = | E_internal_exp_user ((l1,annot1),(l2,annot2)) -> re (E_internal_exp_user ((l1, s_tannot annot1),(l2, s_tannot annot2))) | E_cast (t,e') -> re (E_cast (t, s_exp e')) - | E_app (id,es) -> re (E_app (id,List.map s_exp es)) + | E_app (id,es) -> + let es' = List.map s_exp es in + let arg = + match es' with + | [] -> E_aux (E_lit (L_aux (L_unit,Unknown)),(l,simple_annot unit_t)) + | [e] -> e + | _ -> E_aux (E_tuple es',(l,NoTyp)) + in + let i = id_to_string id in + let id' = + match Envmap.apply t_env i with + | Some (Base((t_params,{t=Tfn(inty,outty,_,_)}),Constructor _,_,_,_,_)) -> + (match refine_constructor refinements i substs arg inty with + | None -> id + | Some i -> Id_aux (Id i,Generated l)) + | _ -> id + in re (E_app (id',es')) | E_app_infix (e1,id,e2) -> re (E_app_infix (s_exp e1,id,s_exp e2)) | E_tuple es -> re (E_tuple (List.map s_exp es)) | E_if (e1,e2,e3) -> re (E_if (s_exp e1, s_exp e2, s_exp e3)) @@ -123,7 +357,9 @@ let nexp_subst substs exp = | LEXP_vector (le,e) -> re (LEXP_vector (s_lexp le, s_exp e)) | LEXP_vector_range (le,e1,e2) -> re (LEXP_vector_range (s_lexp le, s_exp e1, s_exp e2)) | LEXP_field (le,id) -> re (LEXP_field (s_lexp le, id)) - in s_exp exp + in (s_pat,s_exp) +let nexp_subst_pat t_env substs refinements = fst (nexp_subst_fns t_env substs refinements) +let nexp_subst_exp t_env substs refinements = snd (nexp_subst_fns t_env substs refinements) let bindings_from_pat t_env p = let rec aux_pat (P_aux (p,annot)) = @@ -151,10 +387,47 @@ let remove_bound t_env env pat = let bound = bindings_from_pat t_env pat in List.fold_left (fun sub v -> Envmap.remove env v) env bound +(* We may need to split up a pattern match if (1) we've been told to case split + on a variable by the user, or (2) we monomorphised a constructor that's used + in the pattern. *) +type split = + | NoSplit + | VarSplit of (tannot pat * (string * tannot Ast.exp)) list + | ConstrSplit of (tannot pat * t_arg Envmap.t) list let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs = + let split_constructors (Defs defs) = + let sc_type_union q (Tu_aux (tu,l) as tua) = + match tu with + | Tu_id id -> [],[tua] + | Tu_ty_id (ty,id) -> + let i = id_to_string id in + (match split_src_type i ty q with + | None -> ([],[Tu_aux (Tu_ty_id (ty,id),l)]) + | Some variants -> + ([(i,variants)], + List.map (fun (insts, i') -> Tu_aux (Tu_ty_id (inst_src_type insts ty,Id_aux (Id i',Generated l)),Generated l)) variants)) + in + let sc_type_def ((TD_aux (tda,annot)) as td) = + match tda with + | TD_variant (id,nscm,quant,tus,flag) -> + let (refinements, tus') = List.split (List.map (sc_type_union quant) tus) in + (List.concat refinements, TD_aux (TD_variant (id,nscm,quant,List.concat tus',flag),annot)) + | _ -> ([],td) + in + let sc_def d = + match d with + | DEF_type td -> let (refinements,td') = sc_type_def td in (refinements, DEF_type td') + | _ -> ([], d) + in + let (refinements, defs') = List.split (List.map sc_def defs) + in (List.concat refinements, Defs defs') + in - let can_match (E_aux (e,(l,annot)) as exp) cases = + let (refinements, defs') = split_constructors defs in + + (* Attempt simple pattern matches *) + let can_match (E_aux (e,(l,annot)) as exp0) cases = match e with | E_id id -> let i = id_to_string id in @@ -164,10 +437,15 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs = match cases with | [] -> (Reporting_basic.print_err false true l "Monomorphisation" ("Failed to find a case for " ^ i); None) - | [Pat_aux (Pat_exp (P_aux (P_wild,_),exp),_)] -> Some exp + | [Pat_aux (Pat_exp (P_aux (P_wild,_),exp),_)] -> Some (exp,[]) + | (Pat_aux (Pat_exp (P_aux (P_typ (_,p),_),exp),ann))::tl -> + findpat ((Pat_aux (Pat_exp (p,exp),ann))::tl) + | (Pat_aux (Pat_exp (P_aux (P_id id',_),exp),_))::tl + when pat_id_is_variable t_env (id_to_string id') -> + Some (exp, [(id_to_string id', exp0)]) | (Pat_aux (Pat_exp (P_aux (P_id id',_),exp),_))::tl | (Pat_aux (Pat_exp (P_aux (P_app (id',[]),_),exp),_))::tl -> - if i = id_to_string id' then Some exp else findpat tl + if i = id_to_string id' then Some (exp,[]) else findpat tl | (Pat_aux (Pat_exp (P_aux (_,(l',_)),_),_))::_ -> (Reporting_basic.print_err false true l' "Monomorphisation" "Unexpected kind of pattern for enumeration"; None) @@ -179,11 +457,16 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs = match cases with | [] -> (Reporting_basic.print_err false true l "Monomorphisation" ("Failed to find a case for bit"); None) - | [Pat_aux (Pat_exp (P_aux (P_wild,_),exp),_)] -> Some exp + | [Pat_aux (Pat_exp (P_aux (P_wild,_),exp),_)] -> Some (exp,[]) + | (Pat_aux (Pat_exp (P_aux (P_typ (_,p),_),exp),ann))::tl -> + findpat ((Pat_aux (Pat_exp (p,exp),ann))::tl) + | (Pat_aux (Pat_exp (P_aux (P_id id',_),exp),_))::tl + when pat_id_is_variable t_env (id_to_string id') -> + Some (exp, [(id_to_string id', exp0)]) | (Pat_aux (Pat_exp (P_aux (P_lit (L_aux (lit, _)),_),exp),_))::tl -> (match bit,lit with - | (L_zero | L_false), (L_zero | L_false) -> Some exp - | (L_one | L_true ), (L_one | L_true ) -> Some exp + | (L_zero | L_false), (L_zero | L_false) -> Some (exp,[]) + | (L_one | L_true ), (L_one | L_true ) -> Some (exp,[]) | _ -> findpat tl) | (Pat_aux (Pat_exp (P_aux (_,(l',_)),_),_))::_ -> (Reporting_basic.print_err false true l' "Monomorphisation" @@ -192,6 +475,7 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs = | _ -> None in + (* Similarly, simple conditionals *) (* TODO: doublecheck *) let lit_eq (L_aux (l1,_)) (L_aux (l2,_)) = match l1,l2 with @@ -215,6 +499,7 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs = | _ -> None in + (* Extract nvar substitution by comparing two types *) let build_nexp_subst l t1 t2 = let rec from_types t1 t2 = let t1 = match t1.t with Tabbrev(_,t) -> t | _ -> t1 in @@ -260,7 +545,8 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs = let rec const_prop_exp substs ((E_aux (e,(l,annot))) as exp) = let re e = E_aux (e,(l,annot)) in match e with - (* TODO: are there circumstances in which we should get rid of these? *) + (* TODO: are there more circumstances in which we should get rid of these? *) + | E_block [e] -> const_prop_exp substs e | E_block es -> re (E_block (List.map (const_prop_exp substs) es)) | E_nondet es -> re (E_nondet (List.map (const_prop_exp substs) es)) @@ -276,7 +562,11 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs = | E_comment _ -> exp | E_cast (t,e') -> re (E_cast (t, const_prop_exp substs e')) - | E_app (id,es) -> re (E_app (id,List.map (const_prop_exp substs) es)) + | E_app (id,es) -> + let es' = List.map (const_prop_exp substs) es in + (match const_prop_try_fn (id,es') with + | None -> re (E_app (id,es')) + | Some r -> r) | E_app_infix (e1,id,e2) -> let e1',e2' = const_prop_exp substs e1,const_prop_exp substs e2 in (match try_app_infix (l,annot) e1' (id_to_string id) e2' with @@ -311,9 +601,10 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs = let e' = const_prop_exp substs e in (match can_match e' cases with | None -> re (E_case (e', List.map (const_prop_pexp substs) cases)) - | Some (E_aux (_,(_,annot')) as exp) -> + | Some (E_aux (_,(_,annot')) as exp,newbindings) -> + let substs' = Envmap.union substs (Envmap.from_list newbindings) in nexp_substs := build_nexp_subst l annot annot' @ !nexp_substs; - const_prop_exp substs exp) + const_prop_exp substs' exp) | E_let (lb,e) -> let (lb',substs') = const_prop_letbind substs lb in re (E_let (lb', const_prop_exp substs' e)) @@ -357,6 +648,39 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs = | LEXP_vector (le,e) -> re (LEXP_vector (const_prop_lexp substs le, const_prop_exp substs e)) | LEXP_vector_range (le,e1,e2) -> re (LEXP_vector_range (const_prop_lexp substs le, const_prop_exp substs e1, const_prop_exp substs e2)) | LEXP_field (le,id) -> re (LEXP_field (const_prop_lexp substs le, id)) + (* Reduce a function when + 1. all arguments are values, + 2. the function is pure, + 3. the result is a value + (and 4. the function is not scattered, but that's not terribly important) + to try and keep execution time and the results managable. + *) + and const_prop_try_fn (id,args) = + if not (List.for_all (is_value t_env) args) then + None + else + let i = id_to_string id in + let Defs ds = defs in + match list_extract (function + | (DEF_fundef (FD_aux (FD_function (_,_,eff,((FCL_aux (FCL_Funcl (id',_,_),_))::_ as fcls)),_))) + -> if i = id_to_string id' then Some (eff,fcls) else None + | _ -> None) ds with + | None -> None + | Some (eff,_) when not (is_pure eff) -> None + | Some (_,fcls) -> + let arg = match args with + | [] -> E_aux (E_lit (L_aux (L_unit,Unknown)),(Unknown,simple_annot unit_t)) + | [e] -> e + | _ -> E_aux (E_tuple args,(Unknown,NoTyp)) in + let cases = List.map (function + | FCL_aux (FCL_Funcl (_,pat,exp), ann) -> Pat_aux (Pat_exp (pat,exp),ann)) + fcls in + match can_match arg cases with + | Some (exp,bindings) -> + let substs = Envmap.from_list bindings in + let result = const_prop_exp substs exp in + if is_value t_env result then Some result else None + | None -> None in let subst_exp subst exp = @@ -374,7 +698,7 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs = (* Substitute what we've learned about nvars into the term *) let nsubsts = Envmap.from_list (List.map (fun (id,ne) -> (id,TA_nexp ne)) !nexp_substs) in let () = nexp_substs := [] in - nexp_subst nsubsts exp' + nexp_subst_exp t_env nsubsts refinements exp' in @@ -428,7 +752,7 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs = Filename.basename p.Lexing.pos_fname = filename && p.Lexing.pos_lnum <= line && line <= q.Lexing.pos_lnum) ls in - + let split_pat var p = let rec list f = function | [] -> None @@ -485,14 +809,45 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs = relist spl (fun ps -> P_list ps) ps in spl p in - - let map_pat (P_aux (_,(l,_)) as p) = + + let map_pat_by_loc (P_aux (p,(l,_)) as pat) = match match_l l with | [] -> None - | [(_,var)] -> split_pat var p + | [(_,var)] -> split_pat var pat | lvs -> raise (Reporting_basic.err_general l ("Multiple variables to split on: " ^ String.concat ", " (List.map snd lvs))) in + let map_pat (P_aux (p,(l,tannot)) as pat) = + match map_pat_by_loc pat with + | Some l -> VarSplit l + | None -> + match p with + | P_app (id,args) -> + (try + let i = id_to_string id in + let variants = List.assoc i refinements in + let constr_tannot = + match Envmap.apply t_env i with + | Some (Base ((_,{t=Tfn(_,outt,_,_)}),_,_,_,_,_)) -> simple_annot outt + | _ -> raise (Reporting_basic.err_general l + ("Constructor missing from environment: " ^ i)) + in + let varmap = build_nexp_subst l constr_tannot tannot in + let map_inst (insts,i') = + let insts = List.map (fun (v,i) -> + ((match List.assoc v varmap with + | {nexp=Nvar s} -> s + | _ -> raise (Reporting_basic.err_general l + ("Constructor parameter not a variable: " ^ v))), + TA_nexp {nexp=Nconst (Big_int.big_int_of_int i);imp_param=false})) + insts in + P_aux (P_app (Id_aux (Id i',Generated l),args),(Generated l,tannot)), + Envmap.from_list insts + in + ConstrSplit (List.map map_inst variants) + with Not_found -> NoSplit) + | _ -> NoSplit + in let check_single_pat (P_aux (_,(l,_)) as p) = match match_l l with @@ -560,12 +915,19 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs = FE_aux (FE_Fexp (id,map_exp e),annot) and map_pexp (Pat_aux (Pat_exp (p,e),l)) = match map_pat p with - | None -> [Pat_aux (Pat_exp (p,map_exp e),l)] - | Some patsubsts -> + | NoSplit -> [Pat_aux (Pat_exp (p,map_exp e),l)] + | VarSplit patsubsts -> List.map (fun (pat',subst) -> let exp' = subst_exp subst e in Pat_aux (Pat_exp (pat', map_exp exp'),l)) patsubsts + | ConstrSplit patnsubsts -> + List.map (fun (pat',nsubst) -> + (* Leave refinements to later *) + let pat' = nexp_subst_pat t_env nsubst [] pat' in + let exp' = nexp_subst_exp t_env nsubst [] e in + Pat_aux (Pat_exp (pat', map_exp exp'),l) + ) patnsubsts and map_letbind (LB_aux (lb,annot)) = match lb with | LB_val_explicit (tysch,p,e) -> LB_aux (LB_val_explicit (tysch,check_single_pat p,map_exp e), annot) @@ -585,12 +947,19 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs = let map_funcl (FCL_aux (FCL_Funcl (id,pat,exp),annot)) = match map_pat pat with - | None -> [FCL_aux (FCL_Funcl (id, pat, map_exp exp), annot)] - | Some patsubsts -> + | NoSplit -> [FCL_aux (FCL_Funcl (id, pat, map_exp exp), annot)] + | VarSplit patsubsts -> List.map (fun (pat',subst) -> let exp' = subst_exp subst exp in FCL_aux (FCL_Funcl (id, pat', map_exp exp'), annot)) patsubsts + | ConstrSplit patnsubsts -> + List.map (fun (pat',nsubst) -> + (* Leave refinements to later *) + let pat' = nexp_subst_pat t_env nsubst [] pat' in + let exp' = nexp_subst_exp t_env nsubst [] exp in + FCL_aux (FCL_Funcl (id, pat', map_exp exp'), annot) + ) patnsubsts in let map_fundef (FD_aux (FD_function (r,t,e,fcls),annot)) = @@ -617,5 +986,5 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs = in Defs (List.concat (List.map map_def defs)) - - in map_locs splits defs + in + map_locs splits defs' |
