summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorBrian Campbell2017-09-28 11:09:24 +0100
committerBrian Campbell2017-09-28 11:09:24 +0100
commitb5969ea7ca7de19ea2b96c48b1765e2c51e5d2af (patch)
tree64fdca7e2bedaef9a78e63f5bfbc23a82c4d451b /src
parent8bbb538494a3fc17c2c6a2fda2106a3c07ac0ed9 (diff)
Refine constructors during monomorphisation
Diffstat (limited to 'src')
-rw-r--r--src/monomorphise.ml139
-rw-r--r--src/type_check.mli2
2 files changed, 73 insertions, 68 deletions
diff --git a/src/monomorphise.ml b/src/monomorphise.ml
index f5847db8..27b88ea7 100644
--- a/src/monomorphise.ml
+++ b/src/monomorphise.ml
@@ -320,8 +320,12 @@ let split_src_type id ty (TypQ_aux (q,ql)) =
let insts = cross'' insts in
let ty_and_inst (inst0,ty) inst =
let kids, ty = apply_kid_insts inst ty in
- let ty = Typ_aux (Typ_exist (kids, nc', ty),l) in
- inst@inst0, ty
+ let ty =
+ (* Typ_exist is not allowed an empty list of kids *)
+ match kids with
+ | [] -> ty
+ | _ -> Typ_aux (Typ_exist (kids, nc', ty),l)
+ in inst@inst0, ty
in
let tys = List.concat (List.map (fun instty -> List.map (ty_and_inst instty) insts) tys) in
let free = List.fold_left (fun vars k -> KidSet.remove k vars) vars kids in
@@ -335,6 +339,13 @@ let split_src_type id ty (TypQ_aux (q,ql)) =
match snd (size_nvars_ty typ) with
| [] -> []
| tys ->
+ (* One level of tuple type is stripped off by the type checker, so
+ add another here *)
+ let tys =
+ List.map (fun (x,ty) ->
+ x, match ty with
+ | Typ_aux (Typ_tup _,_) -> Typ_aux (Typ_tup [ty],Unknown)
+ | _ -> ty) tys in
if contains_exist t then
raise (Reporting_basic.err_general l
"Only prenex types in unions are supported by monomorphisation")
@@ -380,48 +391,57 @@ let reduce_nexp subst ne =
string_of_nexp nexp ^ " into concrete value"))
in eval ne
+
+let typ_of_args args =
+ match args with
+ | [E_aux (_,(l,annot))] ->
+ snd (env_typ_expected l annot)
+ | _ ->
+ let tys = List.map (fun (E_aux (_,(l,annot))) -> snd (env_typ_expected l annot)) args in
+ Typ_aux (Typ_tup tys,Unknown)
+
(* Check to see if we need to monomorphise a use of a constructor. Currently
assumes that bitvector sizes are always given as a variable; don't yet handle
more general cases (e.g., 8 * var) *)
-(* TODO: use type checker's instantiation instead *)
-let refine_constructor refinements id substs (E_aux (_,(l,_)) as arg) t =
- let rec derive_vars (Typ_aux (t,_)) (E_aux (e,(l,tannot))) =
- match t with
- | Typ_app (Id_aux (Id "vector",_), [_;Typ_arg_aux (Typ_arg_nexp (Nexp_aux (Nexp_var v,_)),_);_;Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)]) ->
- (match tannot with
- | Some (_,Typ_aux (Typ_app (Id_aux (Id "vector",_), [_;Typ_arg_aux (Typ_arg_nexp ne,_);_;Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)]),_),_) ->
- [(v,reduce_nexp substs ne)]
- | _ -> [])
- | Typ_wild
- | Typ_var _
- | Typ_id _
- | Typ_fn _
- | Typ_app _
- -> []
- | Typ_tup ts ->
- match e with
- | E_tuple es -> List.concat (List.map2 derive_vars ts es)
- | _ -> [] (* TODO? *)
- in
- try
- let (_,irefinements) = List.find (fun (id',_) -> Id.compare id id' = 0) refinements in
- let vars = List.sort_uniq (fun x y -> Kid.compare (fst x) (fst y)) (derive_vars t arg) in
- try
- Some (List.assoc vars irefinements)
- with Not_found ->
- (Reporting_basic.print_err false true l "Monomorphisation"
- ("Failed to find a monomorphic constructor for " ^ string_of_id id ^ " instance " ^
- match vars with [] -> "<empty>"
- | _ -> String.concat "," (List.map (fun (x,y) -> string_of_kid x ^ "=" ^ string_of_int y) vars)); None)
- with Not_found -> None
+let refine_constructor refinements l env id args =
+ match List.find (fun (id',_) -> Id.compare id id' = 0) refinements with
+ | (_,irefinements) -> begin
+ let (_,constr_ty) = Env.get_val_spec id env in
+ match constr_ty with
+ | Typ_aux (Typ_fn (constr_ty,_,_),_) -> begin
+ let arg_ty = typ_of_args args in
+ match Type_check.destruct_exist env constr_ty with
+ | None -> None
+ | Some (kids,nc,constr_ty) ->
+ let (bindings,_,_) = Type_check.unify l env constr_ty arg_ty in
+ let find_kid kid = try Some (KBindings.find kid bindings) with Not_found -> None in
+ let bindings = List.map find_kid kids in
+ let matches_refinement (mapping,_,_) =
+ List.for_all2
+ (fun v (_,w) ->
+ match v,w with
+ | _,None -> true
+ | Some (U_nexp (Nexp_aux (Nexp_constant n, _))),Some m -> n = m
+ | _,_ -> false) bindings mapping
+ in
+ match List.find matches_refinement irefinements with
+ | (_,new_id,_) -> Some (E_app (new_id,args))
+ | exception Not_found ->
+ (Reporting_basic.print_err false true l "Monomorphisation"
+ ("Unable to refine constructor " ^ string_of_id id);
+ None)
+ end
+ | _ -> None
+ end
+ | exception Not_found -> None
(* Substitute found nexps for variables in an expression, and rename constructors to reflect
specialisation *)
(* TODO: kid shadowing *)
-let nexp_subst_fns substs refinements =
+let nexp_subst_fns substs =
let s_t t = subst_src_typ substs t in
(* let s_typschm (TypSchm_aux (TypSchm_ts (q,t),l)) = TypSchm_aux (TypSchm_ts (q,s_t t),l) in
@@ -464,25 +484,7 @@ let nexp_subst_fns substs refinements =
| E_internal_exp_user ((l1,annot1),(l2,annot2)) ->
re (E_internal_exp_user ((l1, s_tannot annot1),(l2, s_tannot annot2)))
| E_cast (t,e') -> re (E_cast (t, s_exp e'))
- | E_app (id,es) ->
- let es' = List.map s_exp es in
- let arg =
- match es' with
- | [] -> E_aux (E_lit (L_aux (L_unit,Unknown)),(l,None))
- | [e] -> e
- | _ -> E_aux (E_tuple es',(l,None))
- in
- let id' =
- let env,_ = env_typ_expected l annot in
- if Env.is_union_constructor id env then
- let (qs,ty) = Env.get_val_spec id env in
- match ty with (Typ_aux (Typ_fn(inty,outty,_),_)) ->
- (match refine_constructor refinements id substs arg inty with
- | None -> id
- | Some id' -> id')
- | _ -> id
- else id
- in re (E_app (id',es'))
+ | E_app (id,es) -> re (E_app (id, List.map s_exp es))
| E_app_infix (e1,id,e2) -> re (E_app_infix (s_exp e1,id,s_exp e2))
| E_tuple es -> re (E_tuple (List.map s_exp es))
| E_if (e1,e2,e3) -> re (E_if (s_exp e1, s_exp e2, s_exp e3))
@@ -542,8 +544,8 @@ let nexp_subst_fns substs refinements =
| LEXP_vector_range (le,e1,e2) -> re (LEXP_vector_range (s_lexp le, s_exp e1, s_exp e2))
| LEXP_field (le,id) -> re (LEXP_field (s_lexp le, id))
in ((fun x -> x (*s_pat*)),s_exp)
-let nexp_subst_pat substs refinements = fst (nexp_subst_fns substs refinements)
-let nexp_subst_exp substs refinements = snd (nexp_subst_fns substs refinements)
+let nexp_subst_pat substs = fst (nexp_subst_fns substs)
+let nexp_subst_exp substs = snd (nexp_subst_fns substs)
let bindings_from_pat p =
let rec aux_pat (P_aux (p,(l,annot))) =
@@ -884,7 +886,11 @@ let split_defs splits defs =
(match try_app (l,annot) (id,es') with
| None ->
(match const_prop_try_fn (id,es') with
- | None -> re (E_app (id,es')) assigns
+ | None ->
+ (let env,_ = env_typ_expected l annot in
+ match Env.is_union_constructor id env, refine_constructor refinements l env id es' with
+ | true, Some exp -> re exp assigns
+ | _,_ -> re (E_app (id,es')) assigns)
| Some r -> r,assigns)
| Some r -> r,assigns)
| E_app_infix (e1,id,e2) ->
@@ -968,7 +974,7 @@ let split_defs splits defs =
let assigns' = isubst_minus_set assigns assigned_in in
re (E_case (e', List.map (const_prop_pexp substs assigns) cases)) assigns'
| Some (E_aux (_,(_,annot')) as exp,newbindings,kbindings) ->
- let exp = nexp_subst_exp (ksubst_from_list kbindings) [] (*???*) exp in
+ let exp = nexp_subst_exp (ksubst_from_list kbindings) exp in
let newbindings_env = isubst_from_list newbindings in
let substs' = isubst_union substs newbindings_env in
const_prop_exp substs' assigns exp)
@@ -991,7 +997,7 @@ let split_defs splits defs =
match can_match e' [Pat_aux (Pat_exp (p,e2),(Unknown,None))] with
| None -> plain ()
| Some (e'',bindings,kbindings) ->
- let e'' = nexp_subst_exp (ksubst_from_list kbindings) [] (*???*) e'' in
+ let e'' = nexp_subst_exp (ksubst_from_list kbindings) e'' in
let bindings = isubst_from_list bindings in
let substs'' = isubst_union substs' bindings in
const_prop_exp substs'' assigns e''
@@ -1372,9 +1378,8 @@ let split_defs splits defs =
patsubsts
| ConstrSplit patnsubsts ->
List.map (fun (pat',nsubst) ->
- (* Leave refinements to later *)
- let pat' = nexp_subst_pat nsubst [] pat' in
- let exp' = nexp_subst_exp nsubst [] e in
+ let pat' = nexp_subst_pat nsubst pat' in
+ let exp' = nexp_subst_exp nsubst e in
Pat_aux (Pat_exp (pat', map_exp exp'),l)
) patnsubsts)
| Pat_aux (Pat_when (p,e1,e2),l) ->
@@ -1388,10 +1393,9 @@ let split_defs splits defs =
patsubsts
| ConstrSplit patnsubsts ->
List.map (fun (pat',nsubst) ->
- (* Leave refinements to later *)
- let pat' = nexp_subst_pat nsubst [] pat' in
- let exp1' = nexp_subst_exp nsubst [] e1 in
- let exp2' = nexp_subst_exp nsubst [] e2 in
+ let pat' = nexp_subst_pat nsubst pat' in
+ let exp1' = nexp_subst_exp nsubst e1 in
+ let exp2' = nexp_subst_exp nsubst e2 in
Pat_aux (Pat_when (pat', map_exp exp1', map_exp exp2'),l)
) patnsubsts)
and map_letbind (LB_aux (lb,annot)) =
@@ -1421,9 +1425,8 @@ let split_defs splits defs =
patsubsts
| ConstrSplit patnsubsts ->
List.map (fun (pat',nsubst) ->
- (* Leave refinements to later *)
- let pat' = nexp_subst_pat nsubst [] pat' in
- let exp' = nexp_subst_exp nsubst [] exp in
+ let pat' = nexp_subst_pat nsubst pat' in
+ let exp' = nexp_subst_exp nsubst exp in
FCL_aux (FCL_Funcl (id, pat', map_exp exp'), annot)
) patnsubsts
in
diff --git a/src/type_check.mli b/src/type_check.mli
index ca2fb90c..e5279067 100644
--- a/src/type_check.mli
+++ b/src/type_check.mli
@@ -225,6 +225,8 @@ type uvar =
val string_of_uvar : uvar -> string
+val unify : l -> Env.t -> typ -> typ -> uvar KBindings.t * kid list * n_constraint option
+
(* Throws Invalid_argument if the argument is not a E_app expression *)
val instantiation_of : tannot exp -> uvar KBindings.t