diff options
| author | Brian Campbell | 2017-07-07 11:05:02 +0100 |
|---|---|---|
| committer | Brian Campbell | 2017-07-07 14:07:53 +0100 |
| commit | 10caa78f7d11bae716c714587e059d18cee51476 (patch) | |
| tree | f5b0e200b4e1f53d38e2eded87dd3ab6541f7152 /src | |
| parent | 9cb879efde58abfd5cc4ae8b2d0344902c983cde (diff) | |
Implement basic monomorphisation of constructors
Diffstat (limited to 'src')
| -rw-r--r-- | src/monomorphise.ml | 319 |
1 files changed, 302 insertions, 17 deletions
diff --git a/src/monomorphise.ml b/src/monomorphise.ml index 14ea30ba..2e78b6b5 100644 --- a/src/monomorphise.ml +++ b/src/monomorphise.ml @@ -2,6 +2,8 @@ open Parse_ast open Ast open Type_internal +let disable_const_propagation = ref false + (* TODO: put this somewhere common *) let id_to_string (Id_aux(id,l)) = @@ -16,7 +18,67 @@ let optmap v f = | None -> None | Some v -> Some (f v) -let disable_const_propagation = ref false +(* We need to be able to substitute in source types to monomorphise + datatype constructors *) + +let rec nexp_to_snexp ne = + let mkne n = Nexp_aux (n,Unknown) in + match ne.nexp with + | Nvar s -> mkne (Nexp_var (Kid_aux (Var s,Unknown))) + | Nid (s,ne) -> mkne (Nexp_id (Id_aux (Id s,Unknown))) + | Nconst i -> mkne (Nexp_constant (Big_int.int_of_big_int i)) + | Npos_inf + | Nneg_inf -> + raise (Reporting_basic.err_general Unknown ("Can't translate inf nexps back into source nexps")) + | Nadd (n1,n2) -> mkne (Nexp_sum (nexp_to_snexp n1, nexp_to_snexp n2)) + | Nsub (n1,n2) -> mkne (Nexp_minus (nexp_to_snexp n1, nexp_to_snexp n2)) + | Nmult (n1,n2) -> mkne (Nexp_times (nexp_to_snexp n1, nexp_to_snexp n2)) + | N2n (ne,_) -> mkne (Nexp_exp (nexp_to_snexp ne)) + | Npow _ -> + raise (Reporting_basic.err_general Unknown ("Can't translate pow nexps back into source nexps")) + | Nneg ne -> mkne (Nexp_neg (nexp_to_snexp ne)) + | Ninexact -> + raise (Reporting_basic.err_general Unknown ("Can't translate inexact nexps back into source nexps")) + | Nuvar _ -> + raise (Reporting_basic.err_general Unknown ("Can't translate uvar nexps back into source nexps")) + + +let subst_src_typ substs t = + let rec s_snexp (Nexp_aux (ne,l) as nexp) = + let re ne = Nexp_aux (ne,l) in + match ne with + | Nexp_var (Kid_aux (Var i,l)) -> + (match Envmap.apply substs i with + | Some (TA_nexp ne) -> nexp_to_snexp ne + | _ -> re (Nexp_var (Kid_aux(Var i,Generated l)))) + | Nexp_id _ + | Nexp_constant _ -> nexp + | Nexp_times (n1,n2) -> re (Nexp_times (s_snexp n1, s_snexp n2)) + | Nexp_sum (n1,n2) -> re (Nexp_sum (s_snexp n1, s_snexp n2)) + | Nexp_minus (n1,n2) -> re (Nexp_minus (s_snexp n1, s_snexp n2)) + | Nexp_exp ne -> re (Nexp_exp (s_snexp ne)) + | Nexp_neg ne -> re (Nexp_neg (s_snexp ne)) + in + let rec s_styp ((Typ_aux (t,l)) as ty) = + let re t = Typ_aux (t,l) in + match t with + | Typ_wild + | Typ_id _ + | Typ_var _ + -> ty + | Typ_fn (t1,t2,e) -> re (Typ_fn (s_styp t1, s_styp t2,e)) + | Typ_tup ts -> re (Typ_tup (List.map s_styp ts)) + | Typ_app (id,tas) -> re (Typ_app (id,List.map s_starg tas)) + and s_starg (Typ_arg_aux (ta,l) as targ) = + match ta with + | Typ_arg_nexp ne -> Typ_arg_aux (Typ_arg_nexp (s_snexp ne),l) + | Typ_arg_typ t -> Typ_arg_aux (Typ_arg_typ (s_styp t),l) + | Typ_arg_order _ + | Typ_arg_effect _ -> targ + in s_styp t + + + (* Based on current type checker's behaviour *) let pat_id_is_variable t_env id = @@ -25,9 +87,145 @@ let pat_id_is_variable t_env id = | Some (Base(_,Enum _,_,_,_,_)) -> false | _ -> true - -let nexp_subst substs exp = +let rec list_extract f = function + | [] -> None + | h::t -> match f h with None -> list_extract f t | Some v -> Some v + +let rec cross = function + | [] -> failwith "cross" + | [(x,l)] -> List.map (fun y -> [(x,y)]) l + | (x,l)::t -> + let t' = cross t in + List.concat (List.map (fun y -> List.map (fun l' -> (x,y)::l') t') l) + +(* Given a type for a constructor, work out which refinements we ought to produce *) +(* TODO collision avoidance *) +let split_src_type i ty (TypQ_aux (q,ql)) = + let rec size_nvars_nexp (Nexp_aux (ne,_)) = + match ne with + | Nexp_var (Kid_aux (Var v,_)) -> [v] + | Nexp_id _ + | Nexp_constant _ + -> [] + | Nexp_times (n1,n2) + | Nexp_sum (n1,n2) + | Nexp_minus (n1,n2) + -> size_nvars_nexp n1 @ size_nvars_nexp n2 + | Nexp_exp n + | Nexp_neg n + -> size_nvars_nexp n + in + let rec size_nvars_ty (Typ_aux (ty,l)) = + match ty with + | Typ_wild + | Typ_id _ + | Typ_var _ + -> [] + | Typ_fn _ -> + raise (Reporting_basic.err_general l ("Function type in constructor " ^ i)) + | Typ_tup ts -> List.concat (List.map size_nvars_ty ts) + | Typ_app (Id_aux (Id "vector",_), + [_;Typ_arg_aux (Typ_arg_nexp sz,_); + _;Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)]) -> + size_nvars_nexp sz + | Typ_app (_, tas) -> + [] (* We only support sizes for bitvectors mentioned explicitly, not any buried + inside another type *) + in + let nvars = List.sort_uniq (String.compare) (size_nvars_ty ty) in + match nvars with + | [] -> None + | sample::__ -> + (* Only check for constraints if we found a size to constrain *) + let qs = + match q with + | TypQ_no_forall -> + raise (Reporting_basic.err_general ql + ("No set constraint for variable " ^ sample ^ " in constructor " ^ i)) + | TypQ_tq qs -> qs + in + let find_set nvar = + match list_extract (function + | QI_aux (QI_const (NC_aux (NC_nat_set_bounded (Kid_aux (Var nvar',_),vals),_)),_) + -> if nvar = nvar' then Some vals else None + | _ -> None) qs with + | None -> + raise (Reporting_basic.err_general ql + ("No set constraint for variable " ^ nvar ^ " in constructor " ^ i)) + | Some vals -> (nvar,vals) + in + let nvar_sets = List.map find_set nvars in + let total_variants = List.fold_left ( * ) 1 (List.map (fun (_,l) -> List.length l) nvar_sets) in + let limit = 8 in + let () = if total_variants > limit then + raise (Reporting_basic.err_general ql + (string_of_int total_variants ^ "variants for constructor " ^ i ^ + "bigger than limit " ^ string_of_int limit)) else () + in + let variants = cross nvar_sets in + let name l = String.concat "_" (i::(List.map (fun (v,i) -> v ^ string_of_int i) l)) in + Some (List.map (fun l -> (l, name l)) variants) + +(* TODO: maybe fold this into subst_src_typ? *) +let inst_src_type insts ty = + let insts = List.map (fun (v,i) -> (v,TA_nexp {nexp=Nconst (Big_int.big_int_of_int i); imp_param=false})) insts in + let subst = Envmap.from_list insts in + subst_src_typ subst ty + +let reduce_nexp subst ne = + let ne = subst_n_with_env subst ne in + let rec eval ne = + match ne.nexp with + | Nconst i -> i + | Nadd (n1,n2) -> Big_int.add_big_int (eval n1) (eval n2) + | Nsub (n1,n2) -> Big_int.sub_big_int (eval n1) (eval n2) + | Nmult (n1,n2) -> Big_int.mult_big_int (eval n1) (eval n2) + | N2n (n,_) -> Big_int.power_int_positive_big_int 2 (eval n) + | Npow (n,i) -> Big_int.power_big_int_positive_int (eval n) i + | Nneg n -> Big_int.minus_big_int (eval n) + | _ -> + raise (Reporting_basic.err_general Unknown ("Couldn't turn nexp " ^ + n_to_string ne ^ " into concrete value")) + in Big_int.int_of_big_int (eval ne) + +(* Check to see if we need to monomorphise a use of a constructor. Currently + assumes that bitvector sizes are always given as a variable; don't yet handle + more general cases (e.g., 8 * var) *) + +let refine_constructor refinements i substs arg t = + let rec derive_vars t (E_aux (e,(l,tannot)) as exp) = + match t.t with + | Tapp ("vector", [_;TA_nexp {nexp = Nvar v};_;TA_typ {t=Tid "bit"}]) -> + (match tannot with + | Base ((_,{t=Tapp ("vector", [_;TA_nexp ne;_;TA_typ {t=Tid "bit"}])}),_,_,_,_,_) -> + [(v,reduce_nexp substs ne)] + | _ -> []) + | Tvar _ + | Tid _ + | Tfn _ + | Tapp _ + | Tuvar _ + -> [] + | Tabbrev (_,t) + | Toptions (t,_) + -> derive_vars t exp + | Ttup ts -> + match e with + | E_tuple es -> List.concat (List.map2 derive_vars ts es) + | _ -> [] (* TODO? *) + in + try + let irefinements = List.assoc i refinements in + let vars = List.sort_uniq (fun x y -> String.compare (fst x) (fst y)) (derive_vars t arg) in + Some (List.assoc vars irefinements) + with Not_found -> None + + +(* Substitute found nexps for variables in an expression, and rename constructors to reflect + specialisation *) + +let nexp_subst_fns t_env substs refinements = let s_t t = typ_subst substs true t in (* let s_typschm (TypSchm_aux (TypSchm_ts (q,t),l)) = TypSchm_aux (TypSchm_ts (q,s_t t),l) in hopefully don't need this anyway *) @@ -68,7 +266,23 @@ let nexp_subst substs exp = | E_internal_exp_user ((l1,annot1),(l2,annot2)) -> re (E_internal_exp_user ((l1, s_tannot annot1),(l2, s_tannot annot2))) | E_cast (t,e') -> re (E_cast (t, s_exp e')) - | E_app (id,es) -> re (E_app (id,List.map s_exp es)) + | E_app (id,es) -> + let es' = List.map s_exp es in + let arg = + match es' with + | [] -> E_aux (E_lit (L_aux (L_unit,Unknown)),(Unknown,simple_annot unit_t)) + | [e] -> e + | _ -> E_aux (E_tuple es',(Unknown,NoTyp)) + in + let i = id_to_string id in + let id' = + match Envmap.apply t_env i with + | Some (Base((t_params,{t=Tfn(inty,outty,_,_)}),Constructor _,_,_,_,_)) -> + (match refine_constructor refinements i substs arg inty with + | None -> id + | Some i -> Id_aux (Id i,Generated l)) + | _ -> id + in re (E_app (id',es')) | E_app_infix (e1,id,e2) -> re (E_app_infix (s_exp e1,id,s_exp e2)) | E_tuple es -> re (E_tuple (List.map s_exp es)) | E_if (e1,e2,e3) -> re (E_if (s_exp e1, s_exp e2, s_exp e3)) @@ -123,7 +337,9 @@ let nexp_subst substs exp = | LEXP_vector (le,e) -> re (LEXP_vector (s_lexp le, s_exp e)) | LEXP_vector_range (le,e1,e2) -> re (LEXP_vector_range (s_lexp le, s_exp e1, s_exp e2)) | LEXP_field (le,id) -> re (LEXP_field (s_lexp le, id)) - in s_exp exp + in (s_pat,s_exp) +let nexp_subst_pat t_env substs refinements = fst (nexp_subst_fns t_env substs refinements) +let nexp_subst_exp t_env substs refinements = snd (nexp_subst_fns t_env substs refinements) let bindings_from_pat t_env p = let rec aux_pat (P_aux (p,annot)) = @@ -151,10 +367,47 @@ let remove_bound t_env env pat = let bound = bindings_from_pat t_env pat in List.fold_left (fun sub v -> Envmap.remove env v) env bound +(* We may need to split up a pattern match if (1) we've been told to case split + on a variable by the user, or (2) we monomorphised a constructor that's used + in the pattern. *) +type split = + | NoSplit + | VarSplit of (tannot pat * (string * tannot Ast.exp)) list + | ConstrSplit of (tannot pat * t_arg Envmap.t) list let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs = + let split_constructors (Defs defs) = + let sc_type_union q (Tu_aux (tu,l) as tua) = + match tu with + | Tu_id id -> [],[tua] + | Tu_ty_id (ty,id) -> + let i = id_to_string id in + (match split_src_type i ty q with + | None -> ([],[Tu_aux (Tu_ty_id (ty,id),l)]) + | Some variants -> + ([(i,variants)], + List.map (fun (insts, i') -> Tu_aux (Tu_ty_id (inst_src_type insts ty,Id_aux (Id i',Generated l)),Generated l)) variants)) + in + let sc_type_def ((TD_aux (tda,annot)) as td) = + match tda with + | TD_variant (id,nscm,quant,tus,flag) -> + let (refinements, tus') = List.split (List.map (sc_type_union quant) tus) in + (List.concat refinements, TD_aux (TD_variant (id,nscm,quant,List.concat tus',flag),annot)) + | _ -> ([],td) + in + let sc_def d = + match d with + | DEF_type td -> let (refinements,td') = sc_type_def td in (refinements, DEF_type td') + | _ -> ([], d) + in + let (refinements, defs') = List.split (List.map sc_def defs) + in (List.concat refinements, Defs defs') + in + + let (refinements, defs') = split_constructors defs in - let can_match (E_aux (e,(l,annot)) as exp) cases = + (* Attempt simple pattern matches *) + let can_match (E_aux (e,(l,annot))) cases = match e with | E_id id -> let i = id_to_string id in @@ -192,6 +445,7 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs = | _ -> None in + (* Similarly, simple conditionals *) (* TODO: doublecheck *) let lit_eq (L_aux (l1,_)) (L_aux (l2,_)) = match l1,l2 with @@ -215,6 +469,7 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs = | _ -> None in + (* Extract nvar substitution by comparing two types *) let build_nexp_subst l t1 t2 = let rec from_types t1 t2 = let t1 = match t1.t with Tabbrev(_,t) -> t | _ -> t1 in @@ -374,7 +629,7 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs = (* Substitute what we've learned about nvars into the term *) let nsubsts = Envmap.from_list (List.map (fun (id,ne) -> (id,TA_nexp ne)) !nexp_substs) in let () = nexp_substs := [] in - nexp_subst nsubsts exp' + nexp_subst_exp t_env nsubsts refinements exp' in @@ -428,7 +683,7 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs = Filename.basename p.Lexing.pos_fname = filename && p.Lexing.pos_lnum <= line && line <= q.Lexing.pos_lnum) ls in - + let split_pat var p = let rec list f = function | [] -> None @@ -485,14 +740,30 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs = relist spl (fun ps -> P_list ps) ps in spl p in - - let map_pat (P_aux (_,(l,_)) as p) = + + let map_pat_by_loc (P_aux (p,(l,_)) as pat) = match match_l l with | [] -> None - | [(_,var)] -> split_pat var p + | [(_,var)] -> split_pat var pat | lvs -> raise (Reporting_basic.err_general l ("Multiple variables to split on: " ^ String.concat ", " (List.map snd lvs))) in + let map_pat (P_aux (p,(l,tannot)) as pat) = + match map_pat_by_loc pat with + | Some l -> VarSplit l + | None -> + match p with + | P_app (id,args) -> + (try + let variants = List.assoc (id_to_string id) refinements in + let map_inst (insts,i') = + P_aux (P_app (Id_aux (Id i',Generated l),args),(Generated l,tannot)), + Envmap.from_list (List.map (fun (v,i) -> (v,TA_nexp {nexp=Nconst (Big_int.big_int_of_int i);imp_param=false})) insts) + in + ConstrSplit (List.map map_inst variants) + with Not_found -> NoSplit) + | _ -> NoSplit + in let check_single_pat (P_aux (_,(l,_)) as p) = match match_l l with @@ -560,12 +831,19 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs = FE_aux (FE_Fexp (id,map_exp e),annot) and map_pexp (Pat_aux (Pat_exp (p,e),l)) = match map_pat p with - | None -> [Pat_aux (Pat_exp (p,map_exp e),l)] - | Some patsubsts -> + | NoSplit -> [Pat_aux (Pat_exp (p,map_exp e),l)] + | VarSplit patsubsts -> List.map (fun (pat',subst) -> let exp' = subst_exp subst e in Pat_aux (Pat_exp (pat', map_exp exp'),l)) patsubsts + | ConstrSplit patnsubsts -> + List.map (fun (pat',nsubst) -> + (* Leave refinements to later *) + let pat' = nexp_subst_pat t_env nsubst [] pat' in + let exp' = nexp_subst_exp t_env nsubst [] e in + Pat_aux (Pat_exp (pat', map_exp exp'),l) + ) patnsubsts and map_letbind (LB_aux (lb,annot)) = match lb with | LB_val_explicit (tysch,p,e) -> LB_aux (LB_val_explicit (tysch,check_single_pat p,map_exp e), annot) @@ -585,12 +863,19 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs = let map_funcl (FCL_aux (FCL_Funcl (id,pat,exp),annot)) = match map_pat pat with - | None -> [FCL_aux (FCL_Funcl (id, pat, map_exp exp), annot)] - | Some patsubsts -> + | NoSplit -> [FCL_aux (FCL_Funcl (id, pat, map_exp exp), annot)] + | VarSplit patsubsts -> List.map (fun (pat',subst) -> let exp' = subst_exp subst exp in FCL_aux (FCL_Funcl (id, pat', map_exp exp'), annot)) patsubsts + | ConstrSplit patnsubsts -> + List.map (fun (pat',nsubst) -> + (* Leave refinements to later *) + let pat' = nexp_subst_pat t_env nsubst [] pat' in + let exp' = nexp_subst_exp t_env nsubst [] exp in + FCL_aux (FCL_Funcl (id, pat', map_exp exp'), annot) + ) patnsubsts in let map_fundef (FD_aux (FD_function (r,t,e,fcls),annot)) = @@ -617,5 +902,5 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs = in Defs (List.concat (List.map map_def defs)) - - in map_locs splits defs + in + map_locs splits defs' |
