diff options
Diffstat (limited to 'src/monomorphise.ml')
| -rw-r--r-- | src/monomorphise.ml | 311 |
1 files changed, 234 insertions, 77 deletions
diff --git a/src/monomorphise.ml b/src/monomorphise.ml index d9ee73b8..cc68fbe3 100644 --- a/src/monomorphise.ml +++ b/src/monomorphise.ml @@ -28,8 +28,9 @@ let isubst_union s1 s2 = | _, _ -> None) s1 s2 let subst_src_typ substs t = - let rec s_snexp (Nexp_aux (ne,l) as nexp) = + let rec s_snexp substs (Nexp_aux (ne,l) as nexp) = let re ne = Nexp_aux (ne,l) in + let s_snexp = s_snexp substs in match ne with | Nexp_var (Kid_aux (_,l) as kid) -> (try KSubst.find kid substs @@ -42,22 +43,25 @@ let subst_src_typ substs t = | Nexp_exp ne -> re (Nexp_exp (s_snexp ne)) | Nexp_neg ne -> re (Nexp_neg (s_snexp ne)) in - let rec s_styp ((Typ_aux (t,l)) as ty) = + let rec s_styp substs ((Typ_aux (t,l)) as ty) = let re t = Typ_aux (t,l) in match t with | Typ_wild | Typ_id _ | Typ_var _ -> ty - | Typ_fn (t1,t2,e) -> re (Typ_fn (s_styp t1, s_styp t2,e)) - | Typ_tup ts -> re (Typ_tup (List.map s_styp ts)) - | Typ_app (id,tas) -> re (Typ_app (id,List.map s_starg tas)) - and s_starg (Typ_arg_aux (ta,l) as targ) = + | Typ_fn (t1,t2,e) -> re (Typ_fn (s_styp substs t1, s_styp substs t2,e)) + | Typ_tup ts -> re (Typ_tup (List.map (s_styp substs) ts)) + | Typ_app (id,tas) -> re (Typ_app (id,List.map (s_starg substs) tas)) + | Typ_exist (kids,nc,t) -> + let substs = List.fold_left (fun sub v -> KSubst.remove v sub) substs kids in + re (Typ_exist (kids,nc,s_styp substs t)) + and s_starg substs (Typ_arg_aux (ta,l) as targ) = match ta with - | Typ_arg_nexp ne -> Typ_arg_aux (Typ_arg_nexp (s_snexp ne),l) - | Typ_arg_typ t -> Typ_arg_aux (Typ_arg_typ (s_styp t),l) + | Typ_arg_nexp ne -> Typ_arg_aux (Typ_arg_nexp (s_snexp substs ne),l) + | Typ_arg_typ t -> Typ_arg_aux (Typ_arg_typ (s_styp substs t),l) | Typ_arg_order _ -> targ - in s_styp t + in s_styp substs t let make_vector_lit sz i = let f j = if (i lsr (sz-j-1)) mod 2 = 0 then '0' else '1' in @@ -128,6 +132,111 @@ let rec cross = function let t' = cross t in List.concat (List.map (fun y -> List.map (fun l' -> (x,y)::l') t') l) +let rec cross' = function + | [] -> [[]] + | (h::t) -> + let t' = cross' t in + List.concat (List.map (fun x -> List.map (List.cons x) t') h) + +let rec cross'' = function + | [] -> [[]] + | (k,None)::t -> List.map (List.cons (k,None)) (cross'' t) + | (k,Some h)::t -> + let t' = cross'' t in + List.concat (List.map (fun x -> List.map (List.cons (k,Some x)) t') h) + +let kidset_bigunion = function + | [] -> KidSet.empty + | h::t -> List.fold_left KidSet.union h t + +(* TODO: deal with non-set constraints, intersections, etc somehow *) +let extract_set_nc var (NC_aux (_,l) as nc) = + let rec aux (NC_aux (nc,l)) = + let re nc = NC_aux (nc,l) in + match nc with + | NC_nat_set_bounded (id,is) when Kid.compare id var = 0 -> Some (is,re NC_true) + | NC_and (nc1,nc2) -> + (match aux nc1, aux nc2 with + | None, None -> None + | None, Some (is,nc2') -> Some (is, re (NC_and (nc1,nc2'))) + | Some (is,nc1'), None -> Some (is, re (NC_and (nc1',nc2))) + | Some _, Some _ -> + raise (Reporting_basic.err_general l ("Multiple set constraints for " ^ string_of_kid var))) + | _ -> None + in match aux nc with + | Some is -> is + | None -> + raise (Reporting_basic.err_general l ("No set constraint for " ^ string_of_kid var)) + +let rec peel = function + | [], l -> ([], l) + | h1::t1, h2::t2 -> let (l1,l2) = peel (t1, t2) in ((h1,h2)::l1,l2) + | _,_ -> assert false + +let rec split_insts = function + | [] -> [],[] + | (k,None)::t -> let l1,l2 = split_insts t in l1,k::l2 + | (k,Some v)::t -> let l1,l2 = split_insts t in (k,v)::l1,l2 + +let apply_kid_insts kid_insts t = + let kid_insts, kids' = split_insts kid_insts in + let kid_insts = List.map (fun (v,i) -> (v,Nexp_aux (Nexp_constant i,Generated Unknown))) kid_insts in + let subst = ksubst_from_list kid_insts in + kids', subst_src_typ subst t + +let rec inst_src_type insts (Typ_aux (ty,l) as typ) = + match ty with + | Typ_wild + | Typ_id _ + | Typ_var _ + -> insts,typ + | Typ_fn _ -> + raise (Reporting_basic.err_general l "Function type in constructor") + | Typ_tup ts -> + let insts,ts = + List.fold_right + (fun typ (insts,ts) -> let insts,typ = inst_src_type insts typ in insts,typ::ts) + ts (insts,[]) + in insts, Typ_aux (Typ_tup ts,l) + | Typ_app (id,args) -> + let insts,ts = + List.fold_right + (fun arg (insts,args) -> let insts,arg = inst_src_typ_arg insts arg in insts,arg::args) + args (insts,[]) + in insts, Typ_aux (Typ_app (id,ts),l) + | Typ_exist (kids, nc, t) -> + let kid_insts, insts' = peel (kids,insts) in + let kids', t' = apply_kid_insts kid_insts t in + (* TODO: subst in nc *) + match kids' with + | [] -> insts', t' + | _ -> insts', Typ_aux (Typ_exist (kids', nc, t'), l) +and inst_src_typ_arg insts (Typ_arg_aux (ta,l) as tyarg) = + match ta with + | Typ_arg_nexp _ + | Typ_arg_order _ + -> insts, tyarg + | Typ_arg_typ typ -> + let insts', typ' = inst_src_type insts typ in + insts', Typ_arg_aux (Typ_arg_typ typ',l) + +let rec contains_exist (Typ_aux (ty,_)) = + match ty with + | Typ_wild + | Typ_id _ + | Typ_var _ + -> false + | Typ_fn (t1,t2,_) -> contains_exist t1 || contains_exist t2 + | Typ_tup ts -> List.exists contains_exist ts + | Typ_app (_,args) -> List.exists contains_exist_arg args + | Typ_exist _ -> true +and contains_exist_arg (Typ_arg_aux (arg,_)) = + match arg with + | Typ_arg_nexp _ + | Typ_arg_order _ + -> false + | Typ_arg_typ typ -> contains_exist typ + (* Given a type for a constructor, work out which refinements we ought to produce *) (* TODO collision avoidance *) let split_src_type id ty (TypQ_aux (q,ql)) = @@ -146,65 +255,87 @@ let split_src_type id ty (TypQ_aux (q,ql)) = | Nexp_neg n -> size_nvars_nexp n in - let rec size_nvars_ty (Typ_aux (ty,l)) = + (* This was originally written for the general case, but I cut it down to the + more manageable prenex-form below *) + let rec size_nvars_ty (Typ_aux (ty,l) as typ) = match ty with | Typ_wild | Typ_id _ | Typ_var _ - -> [] + -> (KidSet.empty,[[],typ]) | Typ_fn _ -> raise (Reporting_basic.err_general l ("Function type in constructor " ^ i)) - | Typ_tup ts -> List.concat (List.map size_nvars_ty ts) + | Typ_tup ts -> + let (vars,tys) = List.split (List.map size_nvars_ty ts) in + let insttys = List.map (fun x -> let (insts,tys) = List.split x in + List.concat insts, Typ_aux (Typ_tup tys,l)) (cross' tys) in + (kidset_bigunion vars, insttys) | Typ_app (Id_aux (Id "vector",_), [_;Typ_arg_aux (Typ_arg_nexp sz,_); _;Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)]) -> - size_nvars_nexp sz + (KidSet.of_list (size_nvars_nexp sz), [[],typ]) | Typ_app (_, tas) -> - [] (* We only support sizes for bitvectors mentioned explicitly, not any buried - inside another type *) + (KidSet.empty,[[],typ]) (* We only support sizes for bitvectors mentioned explicitly, not any buried + inside another type *) + | Typ_exist (kids, nc, t) -> + let (vars,tys) = size_nvars_ty t in + let find_insts k (insts,nc) = + let inst,nc' = + if KidSet.mem k vars then + let is,nc' = extract_set_nc k nc in + Some is,nc' + else None,nc + in (k,inst)::insts,nc' + in + let (insts,nc') = List.fold_right find_insts kids ([],nc) in + let insts = cross'' insts in + let ty_and_inst (inst0,ty) inst = + let kids, ty = apply_kid_insts inst ty in + let ty = Typ_aux (Typ_exist (kids, nc', ty),l) in + inst@inst0, ty + in + let tys = List.concat (List.map (fun instty -> List.map (ty_and_inst instty) insts) tys) in + let free = List.fold_left (fun vars k -> KidSet.remove k vars) vars kids in + (free,tys) + in + (* Only single-variable prenex-form for now *) + let size_nvars_ty (Typ_aux (ty,l) as typ) = + match ty with + | Typ_exist (kids,_,t) -> + begin + match snd (size_nvars_ty typ) with + | [] -> [] + | tys -> + if contains_exist t then + raise (Reporting_basic.err_general l + "Only prenex types in unions are supported by monomorphisation") + else if List.length kids > 1 then + raise (Reporting_basic.err_general l + "Only single-variable existential types in unions are currently supported by monomorphisation") + else tys + end + | _ -> [] in - let nvars = List.sort_uniq Kid.compare (size_nvars_ty ty) in - match nvars with + (* TODO: reject universally quantification or monomorphise it *) + let variants = size_nvars_ty ty in + match variants with | [] -> None | sample::__ -> - (* Only check for constraints if we found a size to constrain *) - let qs = - match q with - | TypQ_no_forall -> - raise (Reporting_basic.err_general ql - ("No set constraint for variable " ^ string_of_kid sample ^ " in constructor " ^ i)) - | TypQ_tq qs -> qs - in - let find_set (Kid_aux (Var nvar,_) as kid) = - match list_extract (function - | QI_aux (QI_const (NC_aux (NC_nat_set_bounded (Kid_aux (Var nvar',_),vals),_)),_) - -> if nvar = nvar' then Some vals else None - | _ -> None) qs with - | None -> - raise (Reporting_basic.err_general ql - ("No set constraint for variable " ^ nvar ^ " in constructor " ^ i)) - | Some vals -> (kid,vals) - in - let nvar_sets = List.map find_set nvars in - let total_variants = List.fold_left ( * ) 1 (List.map (fun (_,l) -> List.length l) nvar_sets) in - let () = if total_variants > size_set_limit then + let () = if List.length variants > size_set_limit then raise (Reporting_basic.err_general ql - (string_of_int total_variants ^ "variants for constructor " ^ i ^ + (string_of_int (List.length variants) ^ "variants for constructor " ^ i ^ "bigger than limit " ^ string_of_int size_set_limit)) else () in - let variants = cross nvar_sets in let wrap = match id with | Id_aux (Id i,l) -> (fun f -> Id_aux (Id (f i),Generated l)) | Id_aux (DeIid i,l) -> (fun f -> Id_aux (DeIid (f i),l)) in - let name l i = String.concat "_" (i::(List.map (fun (v,i) -> string_of_kid v ^ string_of_int i) l)) in - Some (List.map (fun l -> (l, wrap (name l))) variants) - -(* TODO: maybe fold this into subst_src_typ? *) -let inst_src_type insts ty = - let insts = List.map (fun (v,i) -> (v,Nexp_aux (Nexp_constant i,Generated Unknown))) insts in - let subst = ksubst_from_list insts in - subst_src_typ subst ty + let name_seg = function + | (_,None) -> "" + | (k,Some i) -> string_of_kid k ^ string_of_int i + in + let name l i = String.concat "_" (i::(List.map name_seg l)) in + Some (List.map (fun (l,ty) -> (l, wrap (name l),ty)) variants) let reduce_nexp subst ne = let rec eval (Nexp_aux (ne,_) as nexp) = @@ -394,6 +525,7 @@ let bindings_from_pat p = | P_typ (_,p) -> aux_pat p | P_id id -> if pat_id_is_variable env id then [id] else [] + | P_var kid -> [id_of_kid kid] | P_vector ps | P_vector_concat ps | P_app (_,ps) @@ -408,7 +540,7 @@ let bindings_from_pat p = let remove_bound env pat = let bound = bindings_from_pat pat in - List.fold_left (fun sub v -> ISubst.remove v env) env bound + List.fold_left (fun sub v -> ISubst.remove v sub) env bound (* Remove explicit existential types from the AST, so that the sizes of bitvectors will be filled in throughout. @@ -461,6 +593,8 @@ let rec deexist_exp (E_aux (e,(l,(annot : Type_check.tannot))) as exp) = | E_let (lb,e1) -> re (E_let (deexist_letbind lb, deexist_exp e1)) | E_assign (le,e1) -> re (E_assign (deexist_lexp le, deexist_exp e1)) | E_exit e1 -> re (E_exit (deexist_exp e1)) + | E_throw e1 -> re (E_throw (deexist_exp e1)) + | E_try (e1,cases) -> re (E_try (deexist_exp e1, List.map deexist_pexp cases)) | E_return e1 -> re (E_return (deexist_exp e1)) | E_assert (e1,e2) -> re (E_assert (deexist_exp e1,deexist_exp e2)) and deexist_pexp (Pat_aux (pe,(l,annot))) = @@ -548,6 +682,13 @@ let can_match (E_aux (e,(l,annot)) as exp0) cases = in findpat_generic checkpat "bit" cases | _ -> None +(* Remove top-level casts from an expression. Useful when we need to look at + subexpressions to reduce something, but could break type-checking if we used + it everywhere. *) +let rec drop_casts = function + | E_aux (E_cast (_,e),_) -> drop_casts e + | exp -> exp + (* Similarly, simple conditionals *) let lit_eq (L_aux (l1,_)) (L_aux (l2,_)) = @@ -566,8 +707,8 @@ let neq_fns = [Id "neq_anything"] let try_app (l,ann) (Id_aux (id,_),args) = let is_eq = List.mem id eq_fns in let is_neq = (not is_eq) && List.mem id neq_fns in + let new_l = Generated l in if is_eq || is_neq then - let new_l = Generated l in match args with | [E_aux (E_lit l1,_); E_aux (E_lit l2,_)] -> let lit b = if b then L_true else L_false in @@ -576,6 +717,11 @@ let try_app (l,ann) (Id_aux (id,_),args) = | None -> None | Some b -> Some (E_aux (E_lit (L_aux (lit b,new_l)),(l,ann)))) | _ -> None + else if id = Id "cast_bit_bool" then + match args with + | [E_aux (E_lit L_aux (L_zero,_),_)] -> Some (E_aux (E_lit (L_aux (L_false,new_l)),(l,ann))) + | [E_aux (E_lit L_aux (L_one ,_),_)] -> Some (E_aux (E_lit (L_aux (L_true ,new_l)),(l,ann))) + | _ -> None else None @@ -591,7 +737,6 @@ let try_app_infix (l,ann) (E_aux (e1,ann1)) (Id_aux (id,_)) (E_aux (e2,ann2)) = | None -> None) | _ -> None - (* We may need to split up a pattern match if (1) we've been told to case split on a variable by the user, or (2) we monomorphised a constructor that's used in the pattern. *) @@ -610,7 +755,7 @@ let split_defs splits defs = | None -> ([],[Tu_aux (Tu_ty_id (ty,id),l)]) | Some variants -> ([(id,variants)], - List.map (fun (insts, id') -> Tu_aux (Tu_ty_id (inst_src_type insts ty,id'),Generated l)) variants)) + List.map (fun (insts, id', ty) -> Tu_aux (Tu_ty_id (ty,id'),Generated l)) variants)) in let sc_type_def ((TD_aux (tda,annot)) as td) = match tda with @@ -710,7 +855,7 @@ let split_defs splits defs = | E_if (e1,e2,e3) -> let e1' = const_prop_exp substs e1 in let e2',e3' = const_prop_exp substs e2, const_prop_exp substs e3 in - (match e1' with + (match drop_casts e1' with | E_aux (E_lit (L_aux ((L_true|L_false) as lit ,_)),_) -> let e' = match lit with L_true -> e2' | _ -> e3' in (match e' with E_aux (_,(_,annot')) -> @@ -745,6 +890,10 @@ let split_defs splits defs = re (E_let (lb', const_prop_exp substs' e)) | E_assign (le,e) -> re (E_assign (const_prop_lexp substs le, const_prop_exp substs e)) | E_exit e -> re (E_exit (const_prop_exp substs e)) + | E_throw e -> re (E_throw (const_prop_exp substs e)) + | E_try (e,cases) -> + let e' = const_prop_exp substs e in + re (E_case (e', List.map (const_prop_pexp substs) cases)) | E_return e -> re (E_return (const_prop_exp substs e)) | E_assert (e1,e2) -> re (E_assert (const_prop_exp substs e1,const_prop_exp substs e2)) | E_internal_cast (ann,e) -> re (E_internal_cast (ann,const_prop_exp substs e)) @@ -836,7 +985,7 @@ let split_defs splits defs = (* Substitute what we've learned about nvars into the term *) let nsubsts = isubst_from_list !nexp_substs in let () = nexp_substs := [] in - nexp_subst_exp nsubsts refinements exp' + (*nexp_subst_exp nsubsts refinements*) exp' in (* Split a variable pattern into every possible value *) @@ -933,6 +1082,7 @@ let split_defs splits defs = match p with | P_lit _ | P_wild + | P_var _ -> None | P_as (p',id) when id_matches id -> raise (Reporting_basic.err_general l @@ -978,29 +1128,34 @@ let split_defs splits defs = | None -> match p with | P_app (id,args) -> - (try - let (_,variants) = List.find (fun (id',_) -> Id.compare id id' = 0) refinements in - let env,_ = env_typ_expected l tannot in - let constr_out_typ = - match Env.get_val_spec id env with - | (qs,Typ_aux (Typ_fn(_,outt,_),_)) -> outt - | _ -> raise (Reporting_basic.err_general l - ("Constructor " ^ string_of_id id ^ " is not a construtor!")) - in - let varmap = build_nexp_subst l constr_out_typ tannot in - let map_inst (insts,id') = - let insts = List.map (fun (v,i) -> - ((match List.assoc (string_of_kid v) varmap with - | Nexp_aux (Nexp_var s, _) -> s - | _ -> raise (Reporting_basic.err_general l - ("Constructor parameter not a variable: " ^ string_of_kid v))), - Nexp_aux (Nexp_constant i,Generated l))) - insts in - P_aux (P_app (id',args),(Generated l,tannot)), - ksubst_from_list insts - in - ConstrSplit (List.map map_inst variants) - with Not_found -> NoSplit) + begin + let kid,kid_annot = + match args with + | [P_aux (P_var kid,ann)] -> kid,ann + | _ -> + raise (Reporting_basic.err_general l + "Pattern match not currently supported by monomorphisation") + in match List.find (fun (id',_) -> Id.compare id id' = 0) refinements with + | (_,variants) -> + let map_inst (insts,id',_) = + let insts = + match insts with [(v,Some i)] -> [(kid,Nexp_aux (Nexp_constant i, Generated l))] + | _ -> assert false + in +(* + let insts,_ = split_insts insts in + let insts = List.map (fun (v,i) -> + (??, + Nexp_aux (Nexp_constant i,Generated l))) + insts in + P_aux (P_app (id',args),(Generated l,tannot)), +*) + P_aux (P_app (id',[P_aux (P_id (id_of_kid kid),kid_annot)]),(Generated l,tannot)), + ksubst_from_list insts + in + ConstrSplit (List.map map_inst variants) + | exception Not_found -> NoSplit + end | _ -> NoSplit in @@ -1055,6 +1210,8 @@ let split_defs splits defs = | E_let (lb,e) -> re (E_let (map_letbind lb, map_exp e)) | E_assign (le,e) -> re (E_assign (map_lexp le, map_exp e)) | E_exit e -> re (E_exit (map_exp e)) + | E_throw e -> re (E_throw e) + | E_try (e,cases) -> re (E_try (map_exp e, List.concat (List.map map_pexp cases))) | E_return e -> re (E_return (map_exp e)) | E_assert (e1,e2) -> re (E_assert (map_exp e1,map_exp e2)) | E_internal_cast (ann,e) -> re (E_internal_cast (ann,map_exp e)) |
