summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorBrian Campbell2017-07-07 11:05:02 +0100
committerBrian Campbell2017-07-07 14:07:53 +0100
commit10caa78f7d11bae716c714587e059d18cee51476 (patch)
treef5b0e200b4e1f53d38e2eded87dd3ab6541f7152 /src
parent9cb879efde58abfd5cc4ae8b2d0344902c983cde (diff)
Implement basic monomorphisation of constructors
Diffstat (limited to 'src')
-rw-r--r--src/monomorphise.ml319
1 files changed, 302 insertions, 17 deletions
diff --git a/src/monomorphise.ml b/src/monomorphise.ml
index 14ea30ba..2e78b6b5 100644
--- a/src/monomorphise.ml
+++ b/src/monomorphise.ml
@@ -2,6 +2,8 @@ open Parse_ast
open Ast
open Type_internal
+let disable_const_propagation = ref false
+
(* TODO: put this somewhere common *)
let id_to_string (Id_aux(id,l)) =
@@ -16,7 +18,67 @@ let optmap v f =
| None -> None
| Some v -> Some (f v)
-let disable_const_propagation = ref false
+(* We need to be able to substitute in source types to monomorphise
+ datatype constructors *)
+
+let rec nexp_to_snexp ne =
+ let mkne n = Nexp_aux (n,Unknown) in
+ match ne.nexp with
+ | Nvar s -> mkne (Nexp_var (Kid_aux (Var s,Unknown)))
+ | Nid (s,ne) -> mkne (Nexp_id (Id_aux (Id s,Unknown)))
+ | Nconst i -> mkne (Nexp_constant (Big_int.int_of_big_int i))
+ | Npos_inf
+ | Nneg_inf ->
+ raise (Reporting_basic.err_general Unknown ("Can't translate inf nexps back into source nexps"))
+ | Nadd (n1,n2) -> mkne (Nexp_sum (nexp_to_snexp n1, nexp_to_snexp n2))
+ | Nsub (n1,n2) -> mkne (Nexp_minus (nexp_to_snexp n1, nexp_to_snexp n2))
+ | Nmult (n1,n2) -> mkne (Nexp_times (nexp_to_snexp n1, nexp_to_snexp n2))
+ | N2n (ne,_) -> mkne (Nexp_exp (nexp_to_snexp ne))
+ | Npow _ ->
+ raise (Reporting_basic.err_general Unknown ("Can't translate pow nexps back into source nexps"))
+ | Nneg ne -> mkne (Nexp_neg (nexp_to_snexp ne))
+ | Ninexact ->
+ raise (Reporting_basic.err_general Unknown ("Can't translate inexact nexps back into source nexps"))
+ | Nuvar _ ->
+ raise (Reporting_basic.err_general Unknown ("Can't translate uvar nexps back into source nexps"))
+
+
+let subst_src_typ substs t =
+ let rec s_snexp (Nexp_aux (ne,l) as nexp) =
+ let re ne = Nexp_aux (ne,l) in
+ match ne with
+ | Nexp_var (Kid_aux (Var i,l)) ->
+ (match Envmap.apply substs i with
+ | Some (TA_nexp ne) -> nexp_to_snexp ne
+ | _ -> re (Nexp_var (Kid_aux(Var i,Generated l))))
+ | Nexp_id _
+ | Nexp_constant _ -> nexp
+ | Nexp_times (n1,n2) -> re (Nexp_times (s_snexp n1, s_snexp n2))
+ | Nexp_sum (n1,n2) -> re (Nexp_sum (s_snexp n1, s_snexp n2))
+ | Nexp_minus (n1,n2) -> re (Nexp_minus (s_snexp n1, s_snexp n2))
+ | Nexp_exp ne -> re (Nexp_exp (s_snexp ne))
+ | Nexp_neg ne -> re (Nexp_neg (s_snexp ne))
+ in
+ let rec s_styp ((Typ_aux (t,l)) as ty) =
+ let re t = Typ_aux (t,l) in
+ match t with
+ | Typ_wild
+ | Typ_id _
+ | Typ_var _
+ -> ty
+ | Typ_fn (t1,t2,e) -> re (Typ_fn (s_styp t1, s_styp t2,e))
+ | Typ_tup ts -> re (Typ_tup (List.map s_styp ts))
+ | Typ_app (id,tas) -> re (Typ_app (id,List.map s_starg tas))
+ and s_starg (Typ_arg_aux (ta,l) as targ) =
+ match ta with
+ | Typ_arg_nexp ne -> Typ_arg_aux (Typ_arg_nexp (s_snexp ne),l)
+ | Typ_arg_typ t -> Typ_arg_aux (Typ_arg_typ (s_styp t),l)
+ | Typ_arg_order _
+ | Typ_arg_effect _ -> targ
+ in s_styp t
+
+
+
(* Based on current type checker's behaviour *)
let pat_id_is_variable t_env id =
@@ -25,9 +87,145 @@ let pat_id_is_variable t_env id =
| Some (Base(_,Enum _,_,_,_,_))
-> false
| _ -> true
-
-let nexp_subst substs exp =
+let rec list_extract f = function
+ | [] -> None
+ | h::t -> match f h with None -> list_extract f t | Some v -> Some v
+
+let rec cross = function
+ | [] -> failwith "cross"
+ | [(x,l)] -> List.map (fun y -> [(x,y)]) l
+ | (x,l)::t ->
+ let t' = cross t in
+ List.concat (List.map (fun y -> List.map (fun l' -> (x,y)::l') t') l)
+
+(* Given a type for a constructor, work out which refinements we ought to produce *)
+(* TODO collision avoidance *)
+let split_src_type i ty (TypQ_aux (q,ql)) =
+ let rec size_nvars_nexp (Nexp_aux (ne,_)) =
+ match ne with
+ | Nexp_var (Kid_aux (Var v,_)) -> [v]
+ | Nexp_id _
+ | Nexp_constant _
+ -> []
+ | Nexp_times (n1,n2)
+ | Nexp_sum (n1,n2)
+ | Nexp_minus (n1,n2)
+ -> size_nvars_nexp n1 @ size_nvars_nexp n2
+ | Nexp_exp n
+ | Nexp_neg n
+ -> size_nvars_nexp n
+ in
+ let rec size_nvars_ty (Typ_aux (ty,l)) =
+ match ty with
+ | Typ_wild
+ | Typ_id _
+ | Typ_var _
+ -> []
+ | Typ_fn _ ->
+ raise (Reporting_basic.err_general l ("Function type in constructor " ^ i))
+ | Typ_tup ts -> List.concat (List.map size_nvars_ty ts)
+ | Typ_app (Id_aux (Id "vector",_),
+ [_;Typ_arg_aux (Typ_arg_nexp sz,_);
+ _;Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)]) ->
+ size_nvars_nexp sz
+ | Typ_app (_, tas) ->
+ [] (* We only support sizes for bitvectors mentioned explicitly, not any buried
+ inside another type *)
+ in
+ let nvars = List.sort_uniq (String.compare) (size_nvars_ty ty) in
+ match nvars with
+ | [] -> None
+ | sample::__ ->
+ (* Only check for constraints if we found a size to constrain *)
+ let qs =
+ match q with
+ | TypQ_no_forall ->
+ raise (Reporting_basic.err_general ql
+ ("No set constraint for variable " ^ sample ^ " in constructor " ^ i))
+ | TypQ_tq qs -> qs
+ in
+ let find_set nvar =
+ match list_extract (function
+ | QI_aux (QI_const (NC_aux (NC_nat_set_bounded (Kid_aux (Var nvar',_),vals),_)),_)
+ -> if nvar = nvar' then Some vals else None
+ | _ -> None) qs with
+ | None ->
+ raise (Reporting_basic.err_general ql
+ ("No set constraint for variable " ^ nvar ^ " in constructor " ^ i))
+ | Some vals -> (nvar,vals)
+ in
+ let nvar_sets = List.map find_set nvars in
+ let total_variants = List.fold_left ( * ) 1 (List.map (fun (_,l) -> List.length l) nvar_sets) in
+ let limit = 8 in
+ let () = if total_variants > limit then
+ raise (Reporting_basic.err_general ql
+ (string_of_int total_variants ^ "variants for constructor " ^ i ^
+ "bigger than limit " ^ string_of_int limit)) else ()
+ in
+ let variants = cross nvar_sets in
+ let name l = String.concat "_" (i::(List.map (fun (v,i) -> v ^ string_of_int i) l)) in
+ Some (List.map (fun l -> (l, name l)) variants)
+
+(* TODO: maybe fold this into subst_src_typ? *)
+let inst_src_type insts ty =
+ let insts = List.map (fun (v,i) -> (v,TA_nexp {nexp=Nconst (Big_int.big_int_of_int i); imp_param=false})) insts in
+ let subst = Envmap.from_list insts in
+ subst_src_typ subst ty
+
+let reduce_nexp subst ne =
+ let ne = subst_n_with_env subst ne in
+ let rec eval ne =
+ match ne.nexp with
+ | Nconst i -> i
+ | Nadd (n1,n2) -> Big_int.add_big_int (eval n1) (eval n2)
+ | Nsub (n1,n2) -> Big_int.sub_big_int (eval n1) (eval n2)
+ | Nmult (n1,n2) -> Big_int.mult_big_int (eval n1) (eval n2)
+ | N2n (n,_) -> Big_int.power_int_positive_big_int 2 (eval n)
+ | Npow (n,i) -> Big_int.power_big_int_positive_int (eval n) i
+ | Nneg n -> Big_int.minus_big_int (eval n)
+ | _ ->
+ raise (Reporting_basic.err_general Unknown ("Couldn't turn nexp " ^
+ n_to_string ne ^ " into concrete value"))
+ in Big_int.int_of_big_int (eval ne)
+
+(* Check to see if we need to monomorphise a use of a constructor. Currently
+ assumes that bitvector sizes are always given as a variable; don't yet handle
+ more general cases (e.g., 8 * var) *)
+
+let refine_constructor refinements i substs arg t =
+ let rec derive_vars t (E_aux (e,(l,tannot)) as exp) =
+ match t.t with
+ | Tapp ("vector", [_;TA_nexp {nexp = Nvar v};_;TA_typ {t=Tid "bit"}]) ->
+ (match tannot with
+ | Base ((_,{t=Tapp ("vector", [_;TA_nexp ne;_;TA_typ {t=Tid "bit"}])}),_,_,_,_,_) ->
+ [(v,reduce_nexp substs ne)]
+ | _ -> [])
+ | Tvar _
+ | Tid _
+ | Tfn _
+ | Tapp _
+ | Tuvar _
+ -> []
+ | Tabbrev (_,t)
+ | Toptions (t,_)
+ -> derive_vars t exp
+ | Ttup ts ->
+ match e with
+ | E_tuple es -> List.concat (List.map2 derive_vars ts es)
+ | _ -> [] (* TODO? *)
+ in
+ try
+ let irefinements = List.assoc i refinements in
+ let vars = List.sort_uniq (fun x y -> String.compare (fst x) (fst y)) (derive_vars t arg) in
+ Some (List.assoc vars irefinements)
+ with Not_found -> None
+
+
+(* Substitute found nexps for variables in an expression, and rename constructors to reflect
+ specialisation *)
+
+let nexp_subst_fns t_env substs refinements =
let s_t t = typ_subst substs true t in
(* let s_typschm (TypSchm_aux (TypSchm_ts (q,t),l)) = TypSchm_aux (TypSchm_ts (q,s_t t),l) in
hopefully don't need this anyway *)
@@ -68,7 +266,23 @@ let nexp_subst substs exp =
| E_internal_exp_user ((l1,annot1),(l2,annot2)) ->
re (E_internal_exp_user ((l1, s_tannot annot1),(l2, s_tannot annot2)))
| E_cast (t,e') -> re (E_cast (t, s_exp e'))
- | E_app (id,es) -> re (E_app (id,List.map s_exp es))
+ | E_app (id,es) ->
+ let es' = List.map s_exp es in
+ let arg =
+ match es' with
+ | [] -> E_aux (E_lit (L_aux (L_unit,Unknown)),(Unknown,simple_annot unit_t))
+ | [e] -> e
+ | _ -> E_aux (E_tuple es',(Unknown,NoTyp))
+ in
+ let i = id_to_string id in
+ let id' =
+ match Envmap.apply t_env i with
+ | Some (Base((t_params,{t=Tfn(inty,outty,_,_)}),Constructor _,_,_,_,_)) ->
+ (match refine_constructor refinements i substs arg inty with
+ | None -> id
+ | Some i -> Id_aux (Id i,Generated l))
+ | _ -> id
+ in re (E_app (id',es'))
| E_app_infix (e1,id,e2) -> re (E_app_infix (s_exp e1,id,s_exp e2))
| E_tuple es -> re (E_tuple (List.map s_exp es))
| E_if (e1,e2,e3) -> re (E_if (s_exp e1, s_exp e2, s_exp e3))
@@ -123,7 +337,9 @@ let nexp_subst substs exp =
| LEXP_vector (le,e) -> re (LEXP_vector (s_lexp le, s_exp e))
| LEXP_vector_range (le,e1,e2) -> re (LEXP_vector_range (s_lexp le, s_exp e1, s_exp e2))
| LEXP_field (le,id) -> re (LEXP_field (s_lexp le, id))
- in s_exp exp
+ in (s_pat,s_exp)
+let nexp_subst_pat t_env substs refinements = fst (nexp_subst_fns t_env substs refinements)
+let nexp_subst_exp t_env substs refinements = snd (nexp_subst_fns t_env substs refinements)
let bindings_from_pat t_env p =
let rec aux_pat (P_aux (p,annot)) =
@@ -151,10 +367,47 @@ let remove_bound t_env env pat =
let bound = bindings_from_pat t_env pat in
List.fold_left (fun sub v -> Envmap.remove env v) env bound
+(* We may need to split up a pattern match if (1) we've been told to case split
+ on a variable by the user, or (2) we monomorphised a constructor that's used
+ in the pattern. *)
+type split =
+ | NoSplit
+ | VarSplit of (tannot pat * (string * tannot Ast.exp)) list
+ | ConstrSplit of (tannot pat * t_arg Envmap.t) list
let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs =
+ let split_constructors (Defs defs) =
+ let sc_type_union q (Tu_aux (tu,l) as tua) =
+ match tu with
+ | Tu_id id -> [],[tua]
+ | Tu_ty_id (ty,id) ->
+ let i = id_to_string id in
+ (match split_src_type i ty q with
+ | None -> ([],[Tu_aux (Tu_ty_id (ty,id),l)])
+ | Some variants ->
+ ([(i,variants)],
+ List.map (fun (insts, i') -> Tu_aux (Tu_ty_id (inst_src_type insts ty,Id_aux (Id i',Generated l)),Generated l)) variants))
+ in
+ let sc_type_def ((TD_aux (tda,annot)) as td) =
+ match tda with
+ | TD_variant (id,nscm,quant,tus,flag) ->
+ let (refinements, tus') = List.split (List.map (sc_type_union quant) tus) in
+ (List.concat refinements, TD_aux (TD_variant (id,nscm,quant,List.concat tus',flag),annot))
+ | _ -> ([],td)
+ in
+ let sc_def d =
+ match d with
+ | DEF_type td -> let (refinements,td') = sc_type_def td in (refinements, DEF_type td')
+ | _ -> ([], d)
+ in
+ let (refinements, defs') = List.split (List.map sc_def defs)
+ in (List.concat refinements, Defs defs')
+ in
+
+ let (refinements, defs') = split_constructors defs in
- let can_match (E_aux (e,(l,annot)) as exp) cases =
+ (* Attempt simple pattern matches *)
+ let can_match (E_aux (e,(l,annot))) cases =
match e with
| E_id id ->
let i = id_to_string id in
@@ -192,6 +445,7 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs =
| _ -> None
in
+ (* Similarly, simple conditionals *)
(* TODO: doublecheck *)
let lit_eq (L_aux (l1,_)) (L_aux (l2,_)) =
match l1,l2 with
@@ -215,6 +469,7 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs =
| _ -> None
in
+ (* Extract nvar substitution by comparing two types *)
let build_nexp_subst l t1 t2 =
let rec from_types t1 t2 =
let t1 = match t1.t with Tabbrev(_,t) -> t | _ -> t1 in
@@ -374,7 +629,7 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs =
(* Substitute what we've learned about nvars into the term *)
let nsubsts = Envmap.from_list (List.map (fun (id,ne) -> (id,TA_nexp ne)) !nexp_substs) in
let () = nexp_substs := [] in
- nexp_subst nsubsts exp'
+ nexp_subst_exp t_env nsubsts refinements exp'
in
@@ -428,7 +683,7 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs =
Filename.basename p.Lexing.pos_fname = filename &&
p.Lexing.pos_lnum <= line && line <= q.Lexing.pos_lnum) ls
in
-
+
let split_pat var p =
let rec list f = function
| [] -> None
@@ -485,14 +740,30 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs =
relist spl (fun ps -> P_list ps) ps
in spl p
in
-
- let map_pat (P_aux (_,(l,_)) as p) =
+
+ let map_pat_by_loc (P_aux (p,(l,_)) as pat) =
match match_l l with
| [] -> None
- | [(_,var)] -> split_pat var p
+ | [(_,var)] -> split_pat var pat
| lvs -> raise (Reporting_basic.err_general l
("Multiple variables to split on: " ^ String.concat ", " (List.map snd lvs)))
in
+ let map_pat (P_aux (p,(l,tannot)) as pat) =
+ match map_pat_by_loc pat with
+ | Some l -> VarSplit l
+ | None ->
+ match p with
+ | P_app (id,args) ->
+ (try
+ let variants = List.assoc (id_to_string id) refinements in
+ let map_inst (insts,i') =
+ P_aux (P_app (Id_aux (Id i',Generated l),args),(Generated l,tannot)),
+ Envmap.from_list (List.map (fun (v,i) -> (v,TA_nexp {nexp=Nconst (Big_int.big_int_of_int i);imp_param=false})) insts)
+ in
+ ConstrSplit (List.map map_inst variants)
+ with Not_found -> NoSplit)
+ | _ -> NoSplit
+ in
let check_single_pat (P_aux (_,(l,_)) as p) =
match match_l l with
@@ -560,12 +831,19 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs =
FE_aux (FE_Fexp (id,map_exp e),annot)
and map_pexp (Pat_aux (Pat_exp (p,e),l)) =
match map_pat p with
- | None -> [Pat_aux (Pat_exp (p,map_exp e),l)]
- | Some patsubsts ->
+ | NoSplit -> [Pat_aux (Pat_exp (p,map_exp e),l)]
+ | VarSplit patsubsts ->
List.map (fun (pat',subst) ->
let exp' = subst_exp subst e in
Pat_aux (Pat_exp (pat', map_exp exp'),l))
patsubsts
+ | ConstrSplit patnsubsts ->
+ List.map (fun (pat',nsubst) ->
+ (* Leave refinements to later *)
+ let pat' = nexp_subst_pat t_env nsubst [] pat' in
+ let exp' = nexp_subst_exp t_env nsubst [] e in
+ Pat_aux (Pat_exp (pat', map_exp exp'),l)
+ ) patnsubsts
and map_letbind (LB_aux (lb,annot)) =
match lb with
| LB_val_explicit (tysch,p,e) -> LB_aux (LB_val_explicit (tysch,check_single_pat p,map_exp e), annot)
@@ -585,12 +863,19 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs =
let map_funcl (FCL_aux (FCL_Funcl (id,pat,exp),annot)) =
match map_pat pat with
- | None -> [FCL_aux (FCL_Funcl (id, pat, map_exp exp), annot)]
- | Some patsubsts ->
+ | NoSplit -> [FCL_aux (FCL_Funcl (id, pat, map_exp exp), annot)]
+ | VarSplit patsubsts ->
List.map (fun (pat',subst) ->
let exp' = subst_exp subst exp in
FCL_aux (FCL_Funcl (id, pat', map_exp exp'), annot))
patsubsts
+ | ConstrSplit patnsubsts ->
+ List.map (fun (pat',nsubst) ->
+ (* Leave refinements to later *)
+ let pat' = nexp_subst_pat t_env nsubst [] pat' in
+ let exp' = nexp_subst_exp t_env nsubst [] exp in
+ FCL_aux (FCL_Funcl (id, pat', map_exp exp'), annot)
+ ) patnsubsts
in
let map_fundef (FD_aux (FD_function (r,t,e,fcls),annot)) =
@@ -617,5 +902,5 @@ let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs =
in
Defs (List.concat (List.map map_def defs))
-
- in map_locs splits defs
+ in
+ map_locs splits defs'