summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorThomas Bauereiss2018-01-25 11:18:10 +0000
committerThomas Bauereiss2018-01-25 11:41:22 +0000
commit54d18f2d19f33aae822dca53485afa8ba9e06e81 (patch)
treed77b0a2f315ea48844e0933d6a471fdab2903f32 /src
parentd87d4ad3a6d2ec2804cb7b20128fecb6d9df4e6e (diff)
Fix more type annotations in rewriter
Use consistent nesting of tuples when adding updated local mutable variables to expressions. Add test case.
Diffstat (limited to 'src')
-rw-r--r--src/rewrites.ml134
1 files changed, 50 insertions, 84 deletions
diff --git a/src/rewrites.ml b/src/rewrites.ml
index 454fefd3..151a63ff 100644
--- a/src/rewrites.ml
+++ b/src/rewrites.ml
@@ -2545,14 +2545,6 @@ let rewrite_defs_pat_lits =
* internal let-expressions, or internal plet-expressions ended by a term that does not
* access memory or registers and does not update variables *)
-let dedup eq =
- List.fold_left (fun acc e -> if List.exists (eq e) acc then acc else e :: acc) []
-
-let eqidtyp (id1,_) (id2,_) =
- let name1 = match id1 with Id_aux ((Id name | DeIid name),_) -> name in
- let name2 = match id2 with Id_aux ((Id name | DeIid name),_) -> name in
- name1 = name2
-
let find_introduced_vars exp =
let lEXP_aux ((ids, lexp), annot) =
let ids = match lexp with
@@ -2569,11 +2561,11 @@ let find_updated_vars exp =
let ids = match lexp with
| LEXP_id id | LEXP_cast (_, id)
when id_is_local_var id (env_of_annot annot) && not (IdSet.mem id intros) ->
- (id, annot) :: ids
+ IdSet.add id ids
| _ -> ids in
(ids, LEXP_aux (lexp, annot)) in
- dedup eqidtyp (fst (fold_exp
- { (compute_exp_alg [] (@)) with lEXP_aux = lEXP_aux } exp))
+ fst (fold_exp
+ { (compute_exp_alg IdSet.empty IdSet.union) with lEXP_aux = lEXP_aux } exp)
let swaptyp typ (l,tannot) = match tannot with
| Some (env, typ', eff) -> (l, Some (env, typ, eff))
@@ -2587,6 +2579,20 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
let env = env_of exp in
+ let tuple_exp = function
+ | [] -> annot_exp (E_lit (mk_lit L_unit)) l env unit_typ
+ | [e] -> e
+ | es -> annot_exp (E_tuple es) l env (tuple_typ (List.map typ_of es)) in
+
+ let tuple_pat = function
+ | [] -> annot_pat P_wild l env unit_typ
+ | [pat] ->
+ let typ = pat_typ_of pat in
+ annot_pat (P_typ (typ, pat)) l env typ
+ | pats ->
+ let typ = tuple_typ (List.map pat_typ_of pats) in
+ annot_pat (P_typ (typ, annot_pat (P_tup pats) l env typ)) l env typ in
+
let rec add_vars overwrite ((E_aux (expaux,annot)) as exp) vars =
match expaux with
| E_let (lb,exp) ->
@@ -2607,48 +2613,19 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
and where it is used. *)
if overwrite then
match typ_of exp with
- | Typ_aux (Typ_id (Id_aux (Id "unit", _)), _) -> vars
+ | Typ_aux (Typ_id (Id_aux (Id "unit", _)), _) -> tuple_exp vars
| _ -> raise (Reporting_basic.err_unreachable l
"add_vars: trying to overwrite a non-unit expression in tail-position")
- else
- let typ' = Typ_aux (Typ_tup [typ_of exp;typ_of vars], gen_loc l) in
- E_aux (E_tuple [exp;vars],swaptyp typ' annot) in
-
- let mk_varstup l env es =
- let exp_to_pat (E_aux (eaux, annot) as exp) = match eaux with
- | E_lit lit ->
- P_aux (P_lit lit, annot)
- | E_id id ->
- annot_pat (P_id id) l (env_of exp) (typ_of exp)
- | _ -> raise (Reporting_basic.err_unreachable l
- ("Failed to extract pattern from expression " ^ string_of_exp exp)) in
- match es with
- | [] ->
- annot_exp (E_lit (mk_lit L_unit)) (gen_loc l) Env.empty unit_typ, [], []
- | [e] ->
- let e = infer_exp env (strip_exp e) in
- let typ = typ_of e in
- e, [annot_pat (P_typ (typ, exp_to_pat e)) l env typ], [typ_of e]
- | e :: _ ->
- let infer_e e = infer_exp env (strip_exp e) in
- let es = List.map infer_e es in
- let pats = List.map exp_to_pat es in
- let typ = tuple_typ (List.map typ_of es) in
- annot_exp (E_tuple es) l env typ, pats, List.map typ_of es in
-
- let add_vars_pat overwrite l env pat vartyps varpats =
- let typ, pat = match pat with
- | P_aux (P_typ (typ, pat), _) -> typ, pat
- | pat -> pat_typ_of pat, pat in
- let typs, pats =
- if overwrite then vartyps, varpats
- else typ :: vartyps, pat :: varpats in
- match typs, pats with
- | [], [] -> annot_pat P_wild l env unit_typ
- | [typ], [pat] -> annot_pat (P_typ (typ, pat)) l env typ
- | _, _ ->
- let tup_typ = tuple_typ typs in
- annot_pat (P_typ (tup_typ, annot_pat (P_tup pats) l env typ)) l env tup_typ in
+ else tuple_exp (exp :: vars) in
+
+ let mk_var_exps_pats l env ids =
+ ids
+ |> IdSet.elements
+ |> List.map
+ (fun id ->
+ let (E_aux (_, a) as exp) = infer_exp env (E_aux (E_id id, (l, ()))) in
+ exp, P_aux (P_id id, a))
+ |> List.split in
let rewrite (E_aux (expaux,((el,_) as annot)) as full_exp) (P_aux (_,(pl,pannot)) as pat) =
let env = env_of_annot annot in
@@ -2666,10 +2643,8 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
function as an expression followed by the list of variables it
expects. In (Lem) pretty-printing, this turned into an anonymous
function and passed to foreach*. *)
- let vars = List.map (fun (var,(l,t)) -> E_aux (E_id var,(l,t))) (find_updated_vars exp4) in
- let varstuple, varpats, vartyps = mk_varstup el env vars in
- let varstyp = typ_of varstuple in
- let exp4 = rewrite_var_updates (add_vars overwrite exp4 varstuple) in
+ let vars, varpats = mk_var_exps_pats pl env (find_updated_vars exp4) in
+ let exp4 = rewrite_var_updates (add_vars overwrite exp4 vars) in
let ord_exp, lower, upper = match destruct_range (typ_of exp1), destruct_range (typ_of exp2) with
| None, _ | _, None ->
raise (Reporting_basic.err_unreachable el "Could not determine loop bounds")
@@ -2685,47 +2660,41 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
lvar_kid)) el env lvar_typ) in
let lb = annot_letbind (lvar_pat, exp1) el env lvar_typ in
let body = annot_exp (E_let (lb, exp4)) el env (typ_of exp4) in
- let v = annot_exp (E_app (mk_id "foreach", [exp1; exp2; exp3; ord_exp; varstuple; body])) el env (typ_of body) in
- let pat = add_vars_pat overwrite pl env pat vartyps varpats in
- Added_vars (v,pat)
+ let v = annot_exp (E_app (mk_id "foreach", [exp1; exp2; exp3; ord_exp; tuple_exp vars; body])) el env (typ_of body) in
+ Added_vars (v, tuple_pat (if overwrite then varpats else pat :: varpats))
| E_loop(loop,cond,body) ->
- let vars = List.map (fun (var,(l,t)) -> E_aux (E_id var,(l,t))) (find_updated_vars body) in
- let varstuple, varpats, vartyps = mk_varstup el env vars in
- let varstyp = typ_of varstuple in
- (* let cond = rewrite_var_updates (add_vars false cond varstuple) in *)
- let body = rewrite_var_updates (add_vars overwrite body varstuple) in
+ let vars, varpats = mk_var_exps_pats pl env (find_updated_vars body) in
+ let body = rewrite_var_updates (add_vars overwrite body vars) in
let (E_aux (_,(_,bannot))) = body in
let fname = match loop with
| While -> "while"
| Until -> "until" in
let funcl = Id_aux (Id fname,gen_loc el) in
- let v = E_aux (E_app (funcl,[cond;varstuple;body]), (gen_loc el, bannot)) in
- let pat = add_vars_pat overwrite pl env pat vartyps varpats in
- Added_vars (v,pat)
+ let v = E_aux (E_app (funcl,[cond;tuple_exp vars;body]), (gen_loc el, bannot)) in
+ Added_vars (v, tuple_pat (if overwrite then varpats else pat :: varpats))
| E_if (c,e1,e2) ->
- let vars = List.map (fun (var,(l,t)) -> E_aux (E_id var,(l,t)))
- (dedup eqidtyp (find_updated_vars e1 @ find_updated_vars e2)) in
+ let vars, varpats =
+ IdSet.union (find_updated_vars e1) (find_updated_vars e2)
+ |> mk_var_exps_pats pl env in
if vars = [] then
(Same_vars (E_aux (E_if (c,rewrite_var_updates e1,rewrite_var_updates e2),annot)))
else
- let varstuple, varpats, vartyps = mk_varstup el env vars in
- let varstyp = typ_of varstuple in
- let e1 = rewrite_var_updates (add_vars overwrite e1 varstuple) in
- let e2 = rewrite_var_updates (add_vars overwrite e2 varstuple) in
+ let e1 = rewrite_var_updates (add_vars overwrite e1 vars) in
+ let e2 = rewrite_var_updates (add_vars overwrite e2 vars) in
(* after rewrite_defs_letbind_effects c has no variable updates *)
let env = env_of_annot annot in
let typ = typ_of e1 in
let eff = union_eff_exps [e1;e2] in
let v = E_aux (E_if (c,e1,e2), (gen_loc el, Some (env, typ, eff))) in
- let pat = add_vars_pat overwrite pl env pat vartyps varpats in
- Added_vars (v,pat)
+ Added_vars (v, tuple_pat (if overwrite then varpats else pat :: varpats))
| E_case (e1,ps) ->
(* after rewrite_defs_letbind_effects e1 needs no rewriting *)
- let vars =
- let f acc (Pat_aux ((Pat_exp (_,e)|Pat_when (_,_,e)),_)) =
- acc @ find_updated_vars e in
- List.map (fun (var,(l,t)) -> E_aux (E_id var,(l,t)))
- (dedup eqidtyp (List.fold_left f [] ps)) in
+ let vars, varpats =
+ ps
+ |> List.map (fun (Pat_aux ((Pat_exp (_,e)|Pat_when (_,_,e)),_)) -> e)
+ |> List.map find_updated_vars
+ |> List.fold_left IdSet.union IdSet.empty
+ |> mk_var_exps_pats pl env in
if vars = [] then
let ps = List.map (function
| Pat_aux (Pat_exp (p,e),a) ->
@@ -2734,11 +2703,9 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
Pat_aux (Pat_when (p,g,rewrite_var_updates e),a)) ps in
Same_vars (E_aux (E_case (e1,ps),annot))
else
- let varstuple, varpats, vartyps = mk_varstup el env vars in
- let varstyp = typ_of varstuple in
let rewrite_pexp (Pat_aux (pexp, (l, _))) = match pexp with
| Pat_exp (pat, exp) ->
- let exp = rewrite_var_updates (add_vars overwrite exp varstuple) in
+ let exp = rewrite_var_updates (add_vars overwrite exp vars) in
let pannot = (l, Some (env_of exp, typ_of exp, effect_of exp)) in
Pat_aux (Pat_exp (pat, exp), pannot)
| Pat_when _ ->
@@ -2748,8 +2715,7 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
| Pat_aux ((Pat_exp (_,first)|Pat_when (_,_,first)),_) :: _ -> typ_of first
| _ -> unit_typ in
let v = propagate_exp_effect (annot_exp (E_case (e1, List.map rewrite_pexp ps)) pl env typ) in
- let pat = add_vars_pat overwrite pl env pat vartyps varpats in
- Added_vars (v,pat)
+ Added_vars (v, tuple_pat (if overwrite then varpats else pat :: varpats))
| E_assign (lexp,vexp) ->
let mk_id_pat id = match Env.lookup_id id env with
| Local (_, typ) ->