summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/monomorphise.ml421
1 files changed, 395 insertions, 26 deletions
diff --git a/src/monomorphise.ml b/src/monomorphise.ml
index 14ea30ba..dd6e32a7 100644
--- a/src/monomorphise.ml
+++ b/src/monomorphise.ml
@@ -2,6 +2,8 @@ open Parse_ast
open Ast
open Type_internal
+let disable_const_propagation = ref false
+
(* TODO: put this somewhere common *)
let id_to_string (Id_aux(id,l)) =
@@ -16,7 +18,67 @@ let optmap v f =
| None -> None
| Some v -> Some (f v)
-let disable_const_propagation = ref false
+(* We need to be able to substitute in source types to monomorphise
+ datatype constructors *)
+
+let rec nexp_to_snexp ne =
+ let mkne n = Nexp_aux (n,Unknown) in
+ match ne.nexp with
+ | Nvar s -> mkne (Nexp_var (Kid_aux (Var s,Unknown)))
+ | Nid (s,ne) -> mkne (Nexp_id (Id_aux (Id s,Unknown)))
+ | Nconst i -> mkne (Nexp_constant (Big_int.int_of_big_int i))
+ | Npos_inf
+ | Nneg_inf ->
+ raise (Reporting_basic.err_general Unknown ("Can't translate inf nexps back into source nexps"))
+ | Nadd (n1,n2) -> mkne (Nexp_sum (nexp_to_snexp n1, nexp_to_snexp n2))
+ | Nsub (n1,n2) -> mkne (Nexp_minus (nexp_to_snexp n1, nexp_to_snexp n2))
+ | Nmult (n1,n2) -> mkne (Nexp_times (nexp_to_snexp n1, nexp_to_snexp n2))
+ | N2n (ne,_) -> mkne (Nexp_exp (nexp_to_snexp ne))
+ | Npow _ ->
+ raise (Reporting_basic.err_general Unknown ("Can't translate pow nexps back into source nexps"))
+ | Nneg ne -> mkne (Nexp_neg (nexp_to_snexp ne))
+ | Ninexact ->
+ raise (Reporting_basic.err_general Unknown ("Can't translate inexact nexps back into source nexps"))
+ | Nuvar _ ->
+ raise (Reporting_basic.err_general Unknown ("Can't translate uvar nexps back into source nexps"))
+
+
+let subst_src_typ substs t =
+ let rec s_snexp (Nexp_aux (ne,l) as nexp) =
+ let re ne = Nexp_aux (ne,l) in
+ match ne with
+ | Nexp_var (Kid_aux (Var i,l)) ->
+ (match Envmap.apply substs i with
+ | Some (TA_nexp ne) -> nexp_to_snexp ne
+ | _ -> re (Nexp_var (Kid_aux(Var i,Generated l))))
+ | Nexp_id _
+ | Nexp_constant _ -> nexp
+ | Nexp_times (n1,n2) -> re (Nexp_times (s_snexp n1, s_snexp n2))
+ | Nexp_sum (n1,n2) -> re (Nexp_sum (s_snexp n1, s_snexp n2))
+ | Nexp_minus (n1,n2) -> re (Nexp_minus (s_snexp n1, s_snexp n2))
+ | Nexp_exp ne -> re (Nexp_exp (s_snexp ne))
+ | Nexp_neg ne -> re (Nexp_neg (s_snexp ne))
+ in
+ let rec s_styp ((Typ_aux (t,l)) as ty) =
+ let re t = Typ_aux (t,l) in
+ match t with
+ | Typ_wild
+ | Typ_id _
+ | Typ_var _
+ -> ty
+ | Typ_fn (t1,t2,e) -> re (Typ_fn (s_styp t1, s_styp t2,e))
+ | Typ_tup ts -> re (Typ_tup (List.map s_styp ts))
+ | Typ_app (id,tas) -> re (Typ_app (id,List.map s_starg tas))
+ and s_starg (Typ_arg_aux (ta,l) as targ) =
+ match ta with
+ | Typ_arg_nexp ne -> Typ_arg_aux (Typ_arg_nexp (s_snexp ne),l)
+ | Typ_arg_typ t -> Typ_arg_aux (Typ_arg_typ (s_styp t),l)
+ | Typ_arg_order _
+ | Typ_arg_effect _ -> targ
+ in s_styp t
+
+
+
(* Based on current type checker's behaviour *)
let pat_id_is_variable t_env id =
@@ -25,9 +87,165 @@ let pat_id_is_variable t_env id =
| Some (Base(_,Enum _,_,_,_,_))
-> false
| _ -> true
-
-let nexp_subst substs exp =
+let rec is_value t_env (E_aux (e,_)) =
+ match e with
+ | E_id id -> not (pat_id_is_variable t_env (id_to_string id))
+ | E_lit _ -> true
+ | E_tuple es -> List.for_all (is_value t_env) es
+(* TODO: more? *)
+ | _ -> false
+
+let is_pure (Effect_opt_aux (e,_)) =
+ match e with
+ | Effect_opt_pure -> true
+ | Effect_opt_effect (Effect_aux (Effect_set [],_)) -> true
+ | _ -> false
+
+let rec list_extract f = function
+ | [] -> None
+ | h::t -> match f h with None -> list_extract f t | Some v -> Some v
+
+let rec cross = function
+ | [] -> failwith "cross"
+ | [(x,l)] -> List.map (fun y -> [(x,y)]) l
+ | (x,l)::t ->
+ let t' = cross t in
+ List.concat (List.map (fun y -> List.map (fun l' -> (x,y)::l') t') l)
+
+(* Given a type for a constructor, work out which refinements we ought to produce *)
+(* TODO collision avoidance *)
+let split_src_type i ty (TypQ_aux (q,ql)) =
+ let rec size_nvars_nexp (Nexp_aux (ne,_)) =
+ match ne with
+ | Nexp_var (Kid_aux (Var v,_)) -> [v]
+ | Nexp_id _
+ | Nexp_constant _
+ -> []
+ | Nexp_times (n1,n2)
+ | Nexp_sum (n1,n2)
+ | Nexp_minus (n1,n2)
+ -> size_nvars_nexp n1 @ size_nvars_nexp n2
+ | Nexp_exp n
+ | Nexp_neg n
+ -> size_nvars_nexp n
+ in
+ let rec size_nvars_ty (Typ_aux (ty,l)) =
+ match ty with
+ | Typ_wild
+ | Typ_id _
+ | Typ_var _
+ -> []
+ | Typ_fn _ ->
+ raise (Reporting_basic.err_general l ("Function type in constructor " ^ i))
+ | Typ_tup ts -> List.concat (List.map size_nvars_ty ts)
+ | Typ_app (Id_aux (Id "vector",_),
+ [_;Typ_arg_aux (Typ_arg_nexp sz,_);
+ _;Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)]) ->
+ size_nvars_nexp sz
+ | Typ_app (_, tas) ->
+ [] (* We only support sizes for bitvectors mentioned explicitly, not any buried
+ inside another type *)
+ in
+ let nvars = List.sort_uniq (String.compare) (size_nvars_ty ty) in
+ match nvars with
+ | [] -> None
+ | sample::__ ->
+ (* Only check for constraints if we found a size to constrain *)
+ let qs =
+ match q with
+ | TypQ_no_forall ->
+ raise (Reporting_basic.err_general ql
+ ("No set constraint for variable " ^ sample ^ " in constructor " ^ i))
+ | TypQ_tq qs -> qs
+ in
+ let find_set nvar =
+ match list_extract (function
+ | QI_aux (QI_const (NC_aux (NC_nat_set_bounded (Kid_aux (Var nvar',_),vals),_)),_)
+ -> if nvar = nvar' then Some vals else None
+ | _ -> None) qs with
+ | None ->
+ raise (Reporting_basic.err_general ql
+ ("No set constraint for variable " ^ nvar ^ " in constructor " ^ i))
+ | Some vals -> (nvar,vals)
+ in
+ let nvar_sets = List.map find_set nvars in
+ let total_variants = List.fold_left ( * ) 1 (List.map (fun (_,l) -> List.length l) nvar_sets) in
+ let limit = 8 in
+ let () = if total_variants > limit then
+ raise (Reporting_basic.err_general ql
+ (string_of_int total_variants ^ "variants for constructor " ^ i ^
+ "bigger than limit " ^ string_of_int limit)) else ()
+ in
+ let variants = cross nvar_sets in
+ let name l = String.concat "_" (i::(List.map (fun (v,i) -> v ^ string_of_int i) l)) in
+ Some (List.map (fun l -> (l, name l)) variants)
+
+(* TODO: maybe fold this into subst_src_typ? *)
+let inst_src_type insts ty =
+ let insts = List.map (fun (v,i) -> (v,TA_nexp {nexp=Nconst (Big_int.big_int_of_int i); imp_param=false})) insts in
+ let subst = Envmap.from_list insts in
+ subst_src_typ subst ty
+
+let reduce_nexp subst ne =
+ let ne = subst_n_with_env subst ne in
+ let rec eval ne =
+ match ne.nexp with
+ | Nconst i -> i
+ | Nadd (n1,n2) -> Big_int.add_big_int (eval n1) (eval n2)
+ | Nsub (n1,n2) -> Big_int.sub_big_int (eval n1) (eval n2)
+ | Nmult (n1,n2) -> Big_int.mult_big_int (eval n1) (eval n2)
+ | N2n (n,_) -> Big_int.power_int_positive_big_int 2 (eval n)
+ | Npow (n,i) -> Big_int.power_big_int_positive_int (eval n) i
+ | Nneg n -> Big_int.minus_big_int (eval n)
+ | _ ->
+ raise (Reporting_basic.err_general Unknown ("Couldn't turn nexp " ^
+ n_to_string ne ^ " into concrete value"))
+ in Big_int.int_of_big_int (eval ne)
+
+(* Check to see if we need to monomorphise a use of a constructor. Currently
+ assumes that bitvector sizes are always given as a variable; don't yet handle
+ more general cases (e.g., 8 * var) *)
+
+let refine_constructor refinements i substs (E_aux (_,(l,_)) as arg) t =
+ let rec derive_vars t (E_aux (e,(l,tannot)) as exp) =
+ match t.t with
+ | Tapp ("vector", [_;TA_nexp {nexp = Nvar v};_;TA_typ {t=Tid "bit"}]) ->
+ (match tannot with
+ | Base ((_,{t=Tapp ("vector", [_;TA_nexp ne;_;TA_typ {t=Tid "bit"}])}),_,_,_,_,_) ->
+ [(v,reduce_nexp substs ne)]
+ | _ -> [])
+ | Tvar _
+ | Tid _
+ | Tfn _
+ | Tapp _
+ | Tuvar _
+ -> []
+ | Tabbrev (_,t)
+ | Toptions (t,_)
+ -> derive_vars t exp
+ | Ttup ts ->
+ match e with
+ | E_tuple es -> List.concat (List.map2 derive_vars ts es)
+ | _ -> [] (* TODO? *)
+ in
+ try
+ let irefinements = List.assoc i refinements in
+ let vars = List.sort_uniq (fun x y -> String.compare (fst x) (fst y)) (derive_vars t arg) in
+ try
+ Some (List.assoc vars irefinements)
+ with Not_found ->
+ (Reporting_basic.print_err false true l "Monomorphisation"
+ ("Failed to find a monomorphic constructor for " ^ i ^ " instance " ^
+ match vars with [] -> "<empty>"
+ | _ -> String.concat "," (List.map (fun (x,y) -> x ^ "=" ^ string_of_int y) vars)); None)
+ with Not_found -> None
+
+
+(* Substitute found nexps for variables in an expression, and rename constructors to reflect
+ specialisation *)
+
+let nexp_subst_fns t_env substs refinements =
let s_t t = typ_subst substs true t in
(* let s_typschm (TypSchm_aux (TypSchm_ts (q,t),l)) = TypSchm_aux (TypSchm_ts (q,s_t t),l) in
hopefully don't need this anyway *)
@@ -68,7 +286,23 @@ let nexp_subst substs exp =
| E_internal_exp_user ((l1,annot1),(l2,annot2)) ->
re (E_internal_exp_user ((l1, s_tannot annot1),(l2, s_tannot annot2)))
| E_cast (t,e') -> re (E_cast (t, s_exp e'))
- | E_app (id,es) -> re (E_app (id,List.map s_exp es))
+ | E_app (id,es) ->
+ let es' = List.map s_exp es in
+ let arg =
+ match es' with
+ | [] -> E_aux (E_lit (L_aux (L_unit,Unknown)),(l,simple_annot unit_t))
+ | [e] -> e
+ | _ -> E_aux (E_tuple es',(l,NoTyp))
+ in
+ let i = id_to_string id in
+ let id' =
+ match Envmap.apply t_env i with
+ | Some (Base((t_params,{t=Tfn(inty,outty,_,_)}),Constructor _,_,_,_,_)) ->
+ (match refine_constructor refinements i substs arg inty with
+ | None -> id
+ | Some i -> Id_aux (Id i,Generated l))
+ | _ -> id
+ in re (E_app (id',es'))
| E_app_infix (e1,id,e2) -> re (E_app_infix (s_exp e1,id,s_exp e2))
| E_tuple es -> re (E_tuple (List.map s_exp es))
| E_if (e1,e2,e3) -> re (E_if (s_exp e1, s_exp e2, s_exp e3))
@@ -123,7 +357,9 @@ let nexp_subst substs exp =
| LEXP_vector (le,e) -> re (LEXP_vector (s_lexp le, s_exp e))
| LEXP_vector_range (le,e1,e2) -> re (LEXP_vector_range (s_lexp le, s_exp e1, s_exp e2))
| LEXP_field (le,id) -> re (LEXP_field (s_lexp le, id))
- in s_exp exp
+ in (s_pat,s_exp)
+let nexp_subst_pat t_env substs refinements = fst (nexp_subst_fns t_env substs refinements)
+let nexp_subst_exp t_env substs refinements = snd (nexp_subst_fns t_env substs refinements)
let bindings_from_pat t_env p =
let rec aux_pat (P_aux (p,annot)) =
@@ -151,10 +387,47 @@ let remove_bound t_env env pat =
let bound = bindings_from_pat t_env pat in
List.fold_left (fun sub v -> Envmap.remove env v) env bound
+(* We may need to split up a pattern match if (1) we've been told to case split
+ on a variable by the user, or (2) we monomorphised a constructor that's used
+ in the pattern. *)
+type split =
+ | NoSplit
+ | VarSplit of (tannot pat * (string * tannot Ast.exp)) list
+ | ConstrSplit of (tannot pat * t_arg Envmap.t) list
let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs =
+ let split_constructors (Defs defs) =
+ let sc_type_union q (Tu_aux (tu,l) as tua) =
+ match tu with
+ | Tu_id id -> [],[tua]
+ | Tu_ty_id (ty,id) ->
+ let i = id_to_string id in
+ (match split_src_type i ty q with
+ | None -> ([],[Tu_aux (Tu_ty_id (ty,id),l)])
+ | Some variants ->
+ ([(i,variants)],
+ List.map (fun (insts, i') -> Tu_aux (Tu_ty_id (inst_src_type insts ty,Id_aux (Id i',Generated l)),Generated l)) variants))
+ in
+ let sc_type_def ((TD_aux (tda,annot)) as td) =
+ match tda with
+ | TD_variant (id,nscm,quant,tus,flag) ->
+ let (refinements, tus') = List.split (List.map (sc_type_union quant) tus) in
+ (List.concat refinements, TD_aux (TD_variant (id,nscm,quant,List.concat tus',flag),annot))
+ | _ -> ([],td)
+ in
+ let sc_def d =
+ match d with
+ | DEF_type td -> let (refinements,td') = sc_type_def td in (refinements, DEF_type td')
+ | _ -> ([], d)
+ in
+ let (refinements, defs') = List.split (List.map sc_def defs)
+ in (List.concat refinements, Defs defs')
+ in
- let can_match (E_aux (e,(l,annot)) as exp) cases =
+ let (refinements, defs') = split_constructors defs in
+
+ (* Attempt simple pattern matches *)
+ let can_match (E_aux (e,(l,annot)) as exp0) cases =
match e with
| E_id id ->
let i = id_to_string id in
@@ -164,10 +437,15 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs =
match cases with
| [] -> (Reporting_basic.print_err false true l "Monomorphisation"
("Failed to find a case for " ^ i); None)
- | [Pat_aux (Pat_exp (P_aux (P_wild,_),exp),_)] -> Some exp
+ | [Pat_aux (Pat_exp (P_aux (P_wild,_),exp),_)] -> Some (exp,[])
+ | (Pat_aux (Pat_exp (P_aux (P_typ (_,p),_),exp),ann))::tl ->
+ findpat ((Pat_aux (Pat_exp (p,exp),ann))::tl)
+ | (Pat_aux (Pat_exp (P_aux (P_id id',_),exp),_))::tl
+ when pat_id_is_variable t_env (id_to_string id') ->
+ Some (exp, [(id_to_string id', exp0)])
| (Pat_aux (Pat_exp (P_aux (P_id id',_),exp),_))::tl
| (Pat_aux (Pat_exp (P_aux (P_app (id',[]),_),exp),_))::tl ->
- if i = id_to_string id' then Some exp else findpat tl
+ if i = id_to_string id' then Some (exp,[]) else findpat tl
| (Pat_aux (Pat_exp (P_aux (_,(l',_)),_),_))::_ ->
(Reporting_basic.print_err false true l' "Monomorphisation"
"Unexpected kind of pattern for enumeration"; None)
@@ -179,11 +457,16 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs =
match cases with
| [] -> (Reporting_basic.print_err false true l "Monomorphisation"
("Failed to find a case for bit"); None)
- | [Pat_aux (Pat_exp (P_aux (P_wild,_),exp),_)] -> Some exp
+ | [Pat_aux (Pat_exp (P_aux (P_wild,_),exp),_)] -> Some (exp,[])
+ | (Pat_aux (Pat_exp (P_aux (P_typ (_,p),_),exp),ann))::tl ->
+ findpat ((Pat_aux (Pat_exp (p,exp),ann))::tl)
+ | (Pat_aux (Pat_exp (P_aux (P_id id',_),exp),_))::tl
+ when pat_id_is_variable t_env (id_to_string id') ->
+ Some (exp, [(id_to_string id', exp0)])
| (Pat_aux (Pat_exp (P_aux (P_lit (L_aux (lit, _)),_),exp),_))::tl ->
(match bit,lit with
- | (L_zero | L_false), (L_zero | L_false) -> Some exp
- | (L_one | L_true ), (L_one | L_true ) -> Some exp
+ | (L_zero | L_false), (L_zero | L_false) -> Some (exp,[])
+ | (L_one | L_true ), (L_one | L_true ) -> Some (exp,[])
| _ -> findpat tl)
| (Pat_aux (Pat_exp (P_aux (_,(l',_)),_),_))::_ ->
(Reporting_basic.print_err false true l' "Monomorphisation"
@@ -192,6 +475,7 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs =
| _ -> None
in
+ (* Similarly, simple conditionals *)
(* TODO: doublecheck *)
let lit_eq (L_aux (l1,_)) (L_aux (l2,_)) =
match l1,l2 with
@@ -215,6 +499,7 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs =
| _ -> None
in
+ (* Extract nvar substitution by comparing two types *)
let build_nexp_subst l t1 t2 =
let rec from_types t1 t2 =
let t1 = match t1.t with Tabbrev(_,t) -> t | _ -> t1 in
@@ -260,7 +545,8 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs =
let rec const_prop_exp substs ((E_aux (e,(l,annot))) as exp) =
let re e = E_aux (e,(l,annot)) in
match e with
- (* TODO: are there circumstances in which we should get rid of these? *)
+ (* TODO: are there more circumstances in which we should get rid of these? *)
+ | E_block [e] -> const_prop_exp substs e
| E_block es -> re (E_block (List.map (const_prop_exp substs) es))
| E_nondet es -> re (E_nondet (List.map (const_prop_exp substs) es))
@@ -276,7 +562,11 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs =
| E_comment _
-> exp
| E_cast (t,e') -> re (E_cast (t, const_prop_exp substs e'))
- | E_app (id,es) -> re (E_app (id,List.map (const_prop_exp substs) es))
+ | E_app (id,es) ->
+ let es' = List.map (const_prop_exp substs) es in
+ (match const_prop_try_fn (id,es') with
+ | None -> re (E_app (id,es'))
+ | Some r -> r)
| E_app_infix (e1,id,e2) ->
let e1',e2' = const_prop_exp substs e1,const_prop_exp substs e2 in
(match try_app_infix (l,annot) e1' (id_to_string id) e2' with
@@ -311,9 +601,10 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs =
let e' = const_prop_exp substs e in
(match can_match e' cases with
| None -> re (E_case (e', List.map (const_prop_pexp substs) cases))
- | Some (E_aux (_,(_,annot')) as exp) ->
+ | Some (E_aux (_,(_,annot')) as exp,newbindings) ->
+ let substs' = Envmap.union substs (Envmap.from_list newbindings) in
nexp_substs := build_nexp_subst l annot annot' @ !nexp_substs;
- const_prop_exp substs exp)
+ const_prop_exp substs' exp)
| E_let (lb,e) ->
let (lb',substs') = const_prop_letbind substs lb in
re (E_let (lb', const_prop_exp substs' e))
@@ -357,6 +648,39 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs =
| LEXP_vector (le,e) -> re (LEXP_vector (const_prop_lexp substs le, const_prop_exp substs e))
| LEXP_vector_range (le,e1,e2) -> re (LEXP_vector_range (const_prop_lexp substs le, const_prop_exp substs e1, const_prop_exp substs e2))
| LEXP_field (le,id) -> re (LEXP_field (const_prop_lexp substs le, id))
+ (* Reduce a function when
+ 1. all arguments are values,
+ 2. the function is pure,
+ 3. the result is a value
+ (and 4. the function is not scattered, but that's not terribly important)
+ to try and keep execution time and the results managable.
+ *)
+ and const_prop_try_fn (id,args) =
+ if not (List.for_all (is_value t_env) args) then
+ None
+ else
+ let i = id_to_string id in
+ let Defs ds = defs in
+ match list_extract (function
+ | (DEF_fundef (FD_aux (FD_function (_,_,eff,((FCL_aux (FCL_Funcl (id',_,_),_))::_ as fcls)),_)))
+ -> if i = id_to_string id' then Some (eff,fcls) else None
+ | _ -> None) ds with
+ | None -> None
+ | Some (eff,_) when not (is_pure eff) -> None
+ | Some (_,fcls) ->
+ let arg = match args with
+ | [] -> E_aux (E_lit (L_aux (L_unit,Unknown)),(Unknown,simple_annot unit_t))
+ | [e] -> e
+ | _ -> E_aux (E_tuple args,(Unknown,NoTyp)) in
+ let cases = List.map (function
+ | FCL_aux (FCL_Funcl (_,pat,exp), ann) -> Pat_aux (Pat_exp (pat,exp),ann))
+ fcls in
+ match can_match arg cases with
+ | Some (exp,bindings) ->
+ let substs = Envmap.from_list bindings in
+ let result = const_prop_exp substs exp in
+ if is_value t_env result then Some result else None
+ | None -> None
in
let subst_exp subst exp =
@@ -374,7 +698,7 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs =
(* Substitute what we've learned about nvars into the term *)
let nsubsts = Envmap.from_list (List.map (fun (id,ne) -> (id,TA_nexp ne)) !nexp_substs) in
let () = nexp_substs := [] in
- nexp_subst nsubsts exp'
+ nexp_subst_exp t_env nsubsts refinements exp'
in
@@ -428,7 +752,7 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs =
Filename.basename p.Lexing.pos_fname = filename &&
p.Lexing.pos_lnum <= line && line <= q.Lexing.pos_lnum) ls
in
-
+
let split_pat var p =
let rec list f = function
| [] -> None
@@ -485,14 +809,45 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs =
relist spl (fun ps -> P_list ps) ps
in spl p
in
-
- let map_pat (P_aux (_,(l,_)) as p) =
+
+ let map_pat_by_loc (P_aux (p,(l,_)) as pat) =
match match_l l with
| [] -> None
- | [(_,var)] -> split_pat var p
+ | [(_,var)] -> split_pat var pat
| lvs -> raise (Reporting_basic.err_general l
("Multiple variables to split on: " ^ String.concat ", " (List.map snd lvs)))
in
+ let map_pat (P_aux (p,(l,tannot)) as pat) =
+ match map_pat_by_loc pat with
+ | Some l -> VarSplit l
+ | None ->
+ match p with
+ | P_app (id,args) ->
+ (try
+ let i = id_to_string id in
+ let variants = List.assoc i refinements in
+ let constr_tannot =
+ match Envmap.apply t_env i with
+ | Some (Base ((_,{t=Tfn(_,outt,_,_)}),_,_,_,_,_)) -> simple_annot outt
+ | _ -> raise (Reporting_basic.err_general l
+ ("Constructor missing from environment: " ^ i))
+ in
+ let varmap = build_nexp_subst l constr_tannot tannot in
+ let map_inst (insts,i') =
+ let insts = List.map (fun (v,i) ->
+ ((match List.assoc v varmap with
+ | {nexp=Nvar s} -> s
+ | _ -> raise (Reporting_basic.err_general l
+ ("Constructor parameter not a variable: " ^ v))),
+ TA_nexp {nexp=Nconst (Big_int.big_int_of_int i);imp_param=false}))
+ insts in
+ P_aux (P_app (Id_aux (Id i',Generated l),args),(Generated l,tannot)),
+ Envmap.from_list insts
+ in
+ ConstrSplit (List.map map_inst variants)
+ with Not_found -> NoSplit)
+ | _ -> NoSplit
+ in
let check_single_pat (P_aux (_,(l,_)) as p) =
match match_l l with
@@ -560,12 +915,19 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs =
FE_aux (FE_Fexp (id,map_exp e),annot)
and map_pexp (Pat_aux (Pat_exp (p,e),l)) =
match map_pat p with
- | None -> [Pat_aux (Pat_exp (p,map_exp e),l)]
- | Some patsubsts ->
+ | NoSplit -> [Pat_aux (Pat_exp (p,map_exp e),l)]
+ | VarSplit patsubsts ->
List.map (fun (pat',subst) ->
let exp' = subst_exp subst e in
Pat_aux (Pat_exp (pat', map_exp exp'),l))
patsubsts
+ | ConstrSplit patnsubsts ->
+ List.map (fun (pat',nsubst) ->
+ (* Leave refinements to later *)
+ let pat' = nexp_subst_pat t_env nsubst [] pat' in
+ let exp' = nexp_subst_exp t_env nsubst [] e in
+ Pat_aux (Pat_exp (pat', map_exp exp'),l)
+ ) patnsubsts
and map_letbind (LB_aux (lb,annot)) =
match lb with
| LB_val_explicit (tysch,p,e) -> LB_aux (LB_val_explicit (tysch,check_single_pat p,map_exp e), annot)
@@ -585,12 +947,19 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs =
let map_funcl (FCL_aux (FCL_Funcl (id,pat,exp),annot)) =
match map_pat pat with
- | None -> [FCL_aux (FCL_Funcl (id, pat, map_exp exp), annot)]
- | Some patsubsts ->
+ | NoSplit -> [FCL_aux (FCL_Funcl (id, pat, map_exp exp), annot)]
+ | VarSplit patsubsts ->
List.map (fun (pat',subst) ->
let exp' = subst_exp subst exp in
FCL_aux (FCL_Funcl (id, pat', map_exp exp'), annot))
patsubsts
+ | ConstrSplit patnsubsts ->
+ List.map (fun (pat',nsubst) ->
+ (* Leave refinements to later *)
+ let pat' = nexp_subst_pat t_env nsubst [] pat' in
+ let exp' = nexp_subst_exp t_env nsubst [] exp in
+ FCL_aux (FCL_Funcl (id, pat', map_exp exp'), annot)
+ ) patnsubsts
in
let map_fundef (FD_aux (FD_function (r,t,e,fcls),annot)) =
@@ -617,5 +986,5 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs =
in
Defs (List.concat (List.map map_def defs))
-
- in map_locs splits defs
+ in
+ map_locs splits defs'