summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/ast_util.ml1
-rw-r--r--src/ast_util.mli4
-rw-r--r--src/monomorphise.ml208
3 files changed, 114 insertions, 99 deletions
diff --git a/src/ast_util.ml b/src/ast_util.ml
index 27ae93e8..909fa392 100644
--- a/src/ast_util.ml
+++ b/src/ast_util.ml
@@ -182,6 +182,7 @@ module IdSet = Set.Make(Id)
module KBindings = Map.Make(Kid)
module KidSet = Set.Make(Kid)
module NexpSet = Set.Make(Nexp)
+module NexpMap = Map.Make(Nexp)
let rec nexp_identical nexp1 nexp2 = (Nexp.compare nexp1 nexp2 = 0)
diff --git a/src/ast_util.mli b/src/ast_util.mli
index bbbde27f..4671ee36 100644
--- a/src/ast_util.mli
+++ b/src/ast_util.mli
@@ -248,6 +248,10 @@ module NexpSet : sig
include Set.S with type elt = nexp
end
+module NexpMap : sig
+ include Map.S with type key = nexp
+end
+
module BESet : sig
include Set.S with type elt = base_effect
end
diff --git a/src/monomorphise.ml b/src/monomorphise.ml
index 71efcb22..9e9e377e 100644
--- a/src/monomorphise.ml
+++ b/src/monomorphise.ml
@@ -1917,27 +1917,27 @@ let findi f =
let mapat f is xs =
let rec aux n = function
- | _, [] -> []
- | (i,_)::is, h::t when i = n ->
+ | [] -> []
+ | h::t when Util.IntSet.mem n is ->
let h' = f h in
- let t' = aux (n+1) (is, t) in
+ let t' = aux (n+1) t in
h'::t'
- | is, h::t ->
- let t' = aux (n+1) (is, t) in
+ | h::t ->
+ let t' = aux (n+1) t in
h::t'
- in aux 0 (is, xs)
+ in aux 0 xs
let mapat_extra f is xs =
let rec aux n = function
- | _, [] -> [], []
- | (i,v)::is, h::t when i = n ->
- let h',x = f v h in
- let t',xs = aux (n+1) (is, t) in
+ | [] -> [], []
+ | h::t when Util.IntSet.mem n is ->
+ let h',x = f h in
+ let t',xs = aux (n+1) t in
h'::t',x::xs
- | is, h::t ->
- let t',xs = aux (n+1) (is, t) in
+ | h::t ->
+ let t',xs = aux (n+1) t in
h::t',xs
- in aux 0 (is, xs)
+ in aux 0 xs
let tyvars_bound_in_pat pat =
let open Rewriter in
@@ -1975,34 +1975,45 @@ let sizes_of_annot = function
| _,None -> KidSet.empty
| _,Some (env,typ,_) -> sizes_of_typ (Env.base_typ_of env typ)
-let change_parameter_pat kid = function
- | P_aux (P_id var, (l,_))
- | P_aux (P_typ (_,P_aux (P_id var, (l,_))),_)
- -> P_aux (P_id var, (l,None)), (var,kid)
+let change_parameter_pat = function
+ | P_aux (P_id var, (l,Some (env,typ,_)))
+ | P_aux (P_typ (_,P_aux (P_id var, (l,Some (env,typ,_)))),_) ->
+ P_aux (P_id var, (l,None)), var
| P_aux (_,(l,_)) -> raise (Reporting_basic.err_unreachable l
"Expected variable pattern")
(* We add code to change the itself('n) parameter into the corresponding
integer. *)
-let add_var_rebind exp (var,kid) =
+let add_var_rebind exp var =
let l = Generated Unknown in
let annot = (l,None) in
E_aux (E_let (LB_aux (LB_val (P_aux (P_id var,annot),
E_aux (E_app (mk_id "size_itself_int",[E_aux (E_id var,annot)]),annot)),annot),exp),annot)
(* atom('n) arguments to function calls need to be rewritten *)
-let replace_with_the_value (E_aux (_,(l,_)) as exp) =
+let replace_with_the_value bound_nexps (E_aux (_,(l,_)) as exp) =
let env = env_of exp in
let typ, wrap = match typ_of exp with
| Typ_aux (Typ_exist (kids,nc,typ),l) -> typ, fun t -> Typ_aux (Typ_exist (kids,nc,t),l)
| typ -> typ, fun x -> x
in
let typ = Env.expand_synonyms env typ in
+ let replace_size size =
+ (* TODO: pick simpler nexp when there's a choice (also in pretty printer) *)
+ let is_equal nexp =
+ prove env (NC_aux (NC_equal (size,nexp), Parse_ast.Unknown))
+ in
+ if is_nexp_constant size then size else
+ match List.find is_equal bound_nexps with
+ | nexp -> nexp
+ | exception Not_found -> size
+ in
let mk_exp nexp l l' =
- E_aux (E_cast (wrap (Typ_aux (Typ_app (Id_aux (Id "itself",Generated Unknown),
- [Typ_arg_aux (Typ_arg_nexp nexp,l')]),Generated Unknown)),
- E_aux (E_app (Id_aux (Id "make_the_value",Generated Unknown),[exp]),(Generated l,None))),
- (Generated l,None))
+ let nexp = replace_size nexp in
+ E_aux (E_cast (wrap (Typ_aux (Typ_app (Id_aux (Id "itself",Generated Unknown),
+ [Typ_arg_aux (Typ_arg_nexp nexp,l')]),Generated Unknown)),
+ E_aux (E_app (Id_aux (Id "make_the_value",Generated Unknown),[exp]),(Generated l,None))),
+ (Generated l,None))
in
match typ with
| Typ_aux (Typ_app (Id_aux (Id "range",_),
@@ -2032,91 +2043,79 @@ let replace_type env typ =
let rewrite_size_parameters env (Defs defs) =
let open Rewriter in
- let size_vars pexp =
- fst (fold_pexp
- { (compute_exp_alg KidSet.empty KidSet.union) with
- e_aux = (fun ((s,e),annot) -> KidSet.union s (sizes_of_annot annot), E_aux (e,annot));
- e_let = (fun ((sl,lb),(s2,e2)) -> KidSet.union sl (KidSet.diff s2 (tyvars_bound_in_lb lb)), E_let (lb,e2));
- e_for = (fun (id,(s1,e1),(s2,e2),(s3,e3),ord,(s4,e4)) ->
- let kid = mk_kid ("loop_" ^ string_of_id id) in
- KidSet.union s1 (KidSet.union s2 (KidSet.union s3 (KidSet.remove kid s4))),
- E_for (id,e1,e2,e3,ord,e4));
- pat_exp = (fun ((sp,pat),(s,e)) -> KidSet.diff s (tyvars_bound_in_pat pat), Pat_exp (pat,e))}
- pexp)
- in
- let exposed_sizes_funcl fnsizes (FCL_aux (FCL_Funcl (id,pexp),(l,_))) =
- let sizes = size_vars pexp in
- let pat,guard,exp,pannot = destruct_pexp pexp in
- let visible_tyvars =
- KidSet.union
- (Pretty_print_lem.lem_tyvars_of_typ (pat_typ_of pat))
- (Pretty_print_lem.lem_tyvars_of_typ (typ_of exp))
- in
- let expose_tyvars = KidSet.diff sizes visible_tyvars in
- KidSet.union fnsizes expose_tyvars
- in
- let sizes_funcl expose_tyvars fsizes (FCL_aux (FCL_Funcl (id,pexp),(l,_))) =
+ let open Util in
+
+ let sizes_funcl fsizes (FCL_aux (FCL_Funcl (id,pexp),(l,_))) =
let pat,guard,exp,pannot = destruct_pexp pexp in
let parameters = match pat with
| P_aux (P_tup ps,_) -> ps
| _ -> [pat]
in
- let to_change = Util.map_filter
- (fun kid ->
- let check (P_aux (_,(_,Some (env,typ,_)))) =
- match Env.expand_synonyms env typ with
- Typ_aux (Typ_app(Id_aux (Id "range",_),
- [Typ_arg_aux (Typ_arg_nexp (Nexp_aux (Nexp_var kid',_)),_);
- Typ_arg_aux (Typ_arg_nexp (Nexp_aux (Nexp_var kid'',_)),_)]),_) ->
- if Kid.compare kid kid' = 0 && Kid.compare kid kid'' = 0 then Some kid else None
- | Typ_aux (Typ_app(Id_aux (Id "atom", _),
- [Typ_arg_aux (Typ_arg_nexp (Nexp_aux (Nexp_var kid',_)),_)]), _) ->
- if Kid.compare kid kid' = 0 then Some kid else None
- | _ -> None
- in match findi check parameters with
- | None -> (Reporting_basic.print_error (Reporting_basic.Err_general (l,
- ("Unable to find an argument for " ^ string_of_kid kid)));
- None)
- | Some i -> Some i)
- (KidSet.elements expose_tyvars)
+ let add_parameter (i,nmap) (P_aux (_,(_,Some (env,typ,_)))) =
+ let nmap =
+ match Env.base_typ_of env typ with
+ Typ_aux (Typ_app(Id_aux (Id "range",_),
+ [Typ_arg_aux (Typ_arg_nexp nexp,_);
+ Typ_arg_aux (Typ_arg_nexp nexp',_)]),_)
+ when Nexp.compare nexp nexp' = 0 && not (NexpMap.mem nexp nmap) ->
+ NexpMap.add nexp i nmap
+ | Typ_aux (Typ_app(Id_aux (Id "atom", _),
+ [Typ_arg_aux (Typ_arg_nexp nexp,_)]), _)
+ when not (NexpMap.mem nexp nmap) ->
+ NexpMap.add nexp i nmap
+ | _ -> nmap
+ in (i+1,nmap)
in
- let ik_compare (i,k) (i',k') =
- match compare (i : int) i' with
- | 0 -> Kid.compare k k'
- | x -> x
+ let (_,nexp_map) = List.fold_left add_parameter (0,NexpMap.empty) parameters in
+ let nexp_list = NexpMap.bindings nexp_map in
+ let parameters_for = function
+ | Some (env,typ,_) ->
+ begin match Env.base_typ_of env typ with
+ | Typ_aux (Typ_app (Id_aux (Id "vector",_), [Typ_arg_aux (Typ_arg_nexp size,_);_;_]),_)
+ when not (is_nexp_constant size) ->
+ begin
+ match NexpMap.find size nexp_map with
+ | i -> IntSet.singleton i
+ | exception Not_found ->
+ (* Look for equivalent nexps, but only in consistent type env *)
+ if prove env (NC_aux (NC_false,Unknown)) then IntSet.empty else
+ match List.find (fun (nexp,i) ->
+ prove env (NC_aux (NC_equal (nexp,size),Unknown))) nexp_list with
+ | _, i -> IntSet.singleton i
+ | exception Not_found -> IntSet.empty
+ end
+ | _ -> IntSet.empty
+ end
+ | None -> IntSet.empty
in
- let to_change = List.sort ik_compare to_change in
+ let parameters_to_rewrite =
+ fst (fold_pexp
+ { (compute_exp_alg IntSet.empty IntSet.union) with
+ e_aux = (fun ((s,e),(l,annot)) -> IntSet.union s (parameters_for annot),E_aux (e,(l,annot)))
+ } pexp)
+ in
+ let new_nexps = NexpSet.of_list (List.map fst
+ (List.filter (fun (nexp,i) -> IntSet.mem i parameters_to_rewrite) nexp_list)) in
+let _ = print_endline ("Fn " ^ string_of_id id ^ " rewrite " ^
+ String.concat "," (List.map string_of_int (IntSet.elements parameters_to_rewrite))) in
match Bindings.find id fsizes with
- | old -> if List.for_all2 (fun x y -> ik_compare x y = 0) old to_change then fsizes else
- let str l = String.concat "," (List.map (fun (i,k) -> string_of_int i ^ "." ^ string_of_kid k) l) in
- raise (Reporting_basic.err_general l
- ("Different size type variables in different clauses of " ^ string_of_id id ^
- " old: " ^ str old ^ " new: " ^ str to_change))
- | exception Not_found -> Bindings.add id to_change fsizes
+ | old,old_nexps -> Bindings.add id (IntSet.union old parameters_to_rewrite,
+ NexpSet.union old_nexps new_nexps) fsizes
+ | exception Not_found -> Bindings.add id (parameters_to_rewrite, new_nexps) fsizes
in
let sizes_def fsizes = function
| DEF_fundef (FD_aux (FD_function (_,_,_,funcls),_)) ->
- let expose_tyvars = List.fold_left exposed_sizes_funcl KidSet.empty funcls in
- List.fold_left (sizes_funcl expose_tyvars) fsizes funcls
+ List.fold_left sizes_funcl fsizes funcls
| _ -> fsizes
in
let fn_sizes = List.fold_left sizes_def Bindings.empty defs in
- let rewrite_e_app (id,args) =
- match Bindings.find id fn_sizes with
- | [] -> E_app (id,args)
- | to_change ->
- let args' = mapat replace_with_the_value to_change args in
- E_app (id,args')
- | exception Not_found -> E_app (id,args)
- in
let rewrite_funcl (FCL_aux (FCL_Funcl (id,pexp),(l,annot))) =
let pat,guard,body,(pl,_) = destruct_pexp pexp in
- let pat,guard,body =
+ let pat,guard,body, nexps =
(* Update pattern and add itself -> nat wrapper to body *)
match Bindings.find id fn_sizes with
- | [] -> pat,guard,body
- | to_change ->
+ | to_change,nexps ->
let pat, vars =
match pat with
P_aux (P_tup pats,(l,_)) ->
@@ -2124,13 +2123,10 @@ let rewrite_size_parameters env (Defs defs) =
P_aux (P_tup pats,(l,None)), vars
| P_aux (_,(l,_)) ->
begin
- match to_change with
- | [0,kid] ->
- let pat, var = change_parameter_pat kid pat in
+ if IntSet.is_empty to_change then pat, []
+ else
+ let pat, var = change_parameter_pat pat in
pat, [var]
- | _ ->
- raise (Reporting_basic.err_unreachable l
- "Expected multiple parameters at single parameter")
end
in
(* TODO: only add bindings that are necessary (esp for guards) *)
@@ -2139,10 +2135,24 @@ let rewrite_size_parameters env (Defs defs) =
| None -> None
| Some exp -> Some (List.fold_left add_var_rebind exp vars)
in
- pat,guard,body
- | exception Not_found -> pat,guard,body
+ pat,guard,body,nexps
+ | exception Not_found -> pat,guard,body,NexpSet.empty
in
(* Update function applications *)
+ let funcl_typ = typ_of_annot (l,annot) in
+ let already_visible_nexps =
+ NexpSet.union
+ (Pretty_print_lem.lem_nexps_of_typ funcl_typ)
+ (Pretty_print_lem.typeclass_nexps funcl_typ)
+ in
+ let bound_nexps = NexpSet.elements (NexpSet.union nexps already_visible_nexps) in
+ let rewrite_e_app (id,args) =
+ match Bindings.find id fn_sizes with
+ | to_change,_ ->
+ let args' = mapat (replace_with_the_value bound_nexps) to_change args in
+ E_app (id,args')
+ | exception Not_found -> E_app (id,args)
+ in
let body = fold_exp { id_exp_alg with e_app = rewrite_e_app } body in
let guard = match guard with
| None -> None
@@ -2156,8 +2166,7 @@ let rewrite_size_parameters env (Defs defs) =
| DEF_spec (VS_aux (VS_val_spec (typschm,id,extern,cast),(l,annot))) as spec ->
begin
match Bindings.find id fn_sizes with
- | [] -> spec
- | to_change ->
+ | to_change,_ when not (IntSet.is_empty to_change) ->
let typschm = match typschm with
| TypSchm_aux (TypSchm_ts (tq,typ),l) ->
let typ = match typ with
@@ -2169,6 +2178,7 @@ let rewrite_size_parameters env (Defs defs) =
in TypSchm_aux (TypSchm_ts (tq,typ),l)
in
DEF_spec (VS_aux (VS_val_spec (typschm,id,extern,cast),(l,None)))
+ | _ -> spec
| exception Not_found -> spec
end
| def -> def