diff options
| author | Thomas Bauereiss | 2018-01-25 11:18:10 +0000 |
|---|---|---|
| committer | Thomas Bauereiss | 2018-01-25 11:41:22 +0000 |
| commit | 54d18f2d19f33aae822dca53485afa8ba9e06e81 (patch) | |
| tree | d77b0a2f315ea48844e0933d6a471fdab2903f32 /src | |
| parent | d87d4ad3a6d2ec2804cb7b20128fecb6d9df4e6e (diff) | |
Fix more type annotations in rewriter
Use consistent nesting of tuples when adding updated local mutable variables to
expressions. Add test case.
Diffstat (limited to 'src')
| -rw-r--r-- | src/rewrites.ml | 134 |
1 files changed, 50 insertions, 84 deletions
diff --git a/src/rewrites.ml b/src/rewrites.ml index 454fefd3..151a63ff 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -2545,14 +2545,6 @@ let rewrite_defs_pat_lits = * internal let-expressions, or internal plet-expressions ended by a term that does not * access memory or registers and does not update variables *) -let dedup eq = - List.fold_left (fun acc e -> if List.exists (eq e) acc then acc else e :: acc) [] - -let eqidtyp (id1,_) (id2,_) = - let name1 = match id1 with Id_aux ((Id name | DeIid name),_) -> name in - let name2 = match id2 with Id_aux ((Id name | DeIid name),_) -> name in - name1 = name2 - let find_introduced_vars exp = let lEXP_aux ((ids, lexp), annot) = let ids = match lexp with @@ -2569,11 +2561,11 @@ let find_updated_vars exp = let ids = match lexp with | LEXP_id id | LEXP_cast (_, id) when id_is_local_var id (env_of_annot annot) && not (IdSet.mem id intros) -> - (id, annot) :: ids + IdSet.add id ids | _ -> ids in (ids, LEXP_aux (lexp, annot)) in - dedup eqidtyp (fst (fold_exp - { (compute_exp_alg [] (@)) with lEXP_aux = lEXP_aux } exp)) + fst (fold_exp + { (compute_exp_alg IdSet.empty IdSet.union) with lEXP_aux = lEXP_aux } exp) let swaptyp typ (l,tannot) = match tannot with | Some (env, typ', eff) -> (l, Some (env, typ, eff)) @@ -2587,6 +2579,20 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = let env = env_of exp in + let tuple_exp = function + | [] -> annot_exp (E_lit (mk_lit L_unit)) l env unit_typ + | [e] -> e + | es -> annot_exp (E_tuple es) l env (tuple_typ (List.map typ_of es)) in + + let tuple_pat = function + | [] -> annot_pat P_wild l env unit_typ + | [pat] -> + let typ = pat_typ_of pat in + annot_pat (P_typ (typ, pat)) l env typ + | pats -> + let typ = tuple_typ (List.map pat_typ_of pats) in + annot_pat (P_typ (typ, annot_pat (P_tup pats) l env typ)) l env typ in + let rec add_vars overwrite ((E_aux (expaux,annot)) as exp) vars = match expaux with | E_let (lb,exp) -> @@ -2607,48 +2613,19 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = and where it is used. *) if overwrite then match typ_of exp with - | Typ_aux (Typ_id (Id_aux (Id "unit", _)), _) -> vars + | Typ_aux (Typ_id (Id_aux (Id "unit", _)), _) -> tuple_exp vars | _ -> raise (Reporting_basic.err_unreachable l "add_vars: trying to overwrite a non-unit expression in tail-position") - else - let typ' = Typ_aux (Typ_tup [typ_of exp;typ_of vars], gen_loc l) in - E_aux (E_tuple [exp;vars],swaptyp typ' annot) in - - let mk_varstup l env es = - let exp_to_pat (E_aux (eaux, annot) as exp) = match eaux with - | E_lit lit -> - P_aux (P_lit lit, annot) - | E_id id -> - annot_pat (P_id id) l (env_of exp) (typ_of exp) - | _ -> raise (Reporting_basic.err_unreachable l - ("Failed to extract pattern from expression " ^ string_of_exp exp)) in - match es with - | [] -> - annot_exp (E_lit (mk_lit L_unit)) (gen_loc l) Env.empty unit_typ, [], [] - | [e] -> - let e = infer_exp env (strip_exp e) in - let typ = typ_of e in - e, [annot_pat (P_typ (typ, exp_to_pat e)) l env typ], [typ_of e] - | e :: _ -> - let infer_e e = infer_exp env (strip_exp e) in - let es = List.map infer_e es in - let pats = List.map exp_to_pat es in - let typ = tuple_typ (List.map typ_of es) in - annot_exp (E_tuple es) l env typ, pats, List.map typ_of es in - - let add_vars_pat overwrite l env pat vartyps varpats = - let typ, pat = match pat with - | P_aux (P_typ (typ, pat), _) -> typ, pat - | pat -> pat_typ_of pat, pat in - let typs, pats = - if overwrite then vartyps, varpats - else typ :: vartyps, pat :: varpats in - match typs, pats with - | [], [] -> annot_pat P_wild l env unit_typ - | [typ], [pat] -> annot_pat (P_typ (typ, pat)) l env typ - | _, _ -> - let tup_typ = tuple_typ typs in - annot_pat (P_typ (tup_typ, annot_pat (P_tup pats) l env typ)) l env tup_typ in + else tuple_exp (exp :: vars) in + + let mk_var_exps_pats l env ids = + ids + |> IdSet.elements + |> List.map + (fun id -> + let (E_aux (_, a) as exp) = infer_exp env (E_aux (E_id id, (l, ()))) in + exp, P_aux (P_id id, a)) + |> List.split in let rewrite (E_aux (expaux,((el,_) as annot)) as full_exp) (P_aux (_,(pl,pannot)) as pat) = let env = env_of_annot annot in @@ -2666,10 +2643,8 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = function as an expression followed by the list of variables it expects. In (Lem) pretty-printing, this turned into an anonymous function and passed to foreach*. *) - let vars = List.map (fun (var,(l,t)) -> E_aux (E_id var,(l,t))) (find_updated_vars exp4) in - let varstuple, varpats, vartyps = mk_varstup el env vars in - let varstyp = typ_of varstuple in - let exp4 = rewrite_var_updates (add_vars overwrite exp4 varstuple) in + let vars, varpats = mk_var_exps_pats pl env (find_updated_vars exp4) in + let exp4 = rewrite_var_updates (add_vars overwrite exp4 vars) in let ord_exp, lower, upper = match destruct_range (typ_of exp1), destruct_range (typ_of exp2) with | None, _ | _, None -> raise (Reporting_basic.err_unreachable el "Could not determine loop bounds") @@ -2685,47 +2660,41 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = lvar_kid)) el env lvar_typ) in let lb = annot_letbind (lvar_pat, exp1) el env lvar_typ in let body = annot_exp (E_let (lb, exp4)) el env (typ_of exp4) in - let v = annot_exp (E_app (mk_id "foreach", [exp1; exp2; exp3; ord_exp; varstuple; body])) el env (typ_of body) in - let pat = add_vars_pat overwrite pl env pat vartyps varpats in - Added_vars (v,pat) + let v = annot_exp (E_app (mk_id "foreach", [exp1; exp2; exp3; ord_exp; tuple_exp vars; body])) el env (typ_of body) in + Added_vars (v, tuple_pat (if overwrite then varpats else pat :: varpats)) | E_loop(loop,cond,body) -> - let vars = List.map (fun (var,(l,t)) -> E_aux (E_id var,(l,t))) (find_updated_vars body) in - let varstuple, varpats, vartyps = mk_varstup el env vars in - let varstyp = typ_of varstuple in - (* let cond = rewrite_var_updates (add_vars false cond varstuple) in *) - let body = rewrite_var_updates (add_vars overwrite body varstuple) in + let vars, varpats = mk_var_exps_pats pl env (find_updated_vars body) in + let body = rewrite_var_updates (add_vars overwrite body vars) in let (E_aux (_,(_,bannot))) = body in let fname = match loop with | While -> "while" | Until -> "until" in let funcl = Id_aux (Id fname,gen_loc el) in - let v = E_aux (E_app (funcl,[cond;varstuple;body]), (gen_loc el, bannot)) in - let pat = add_vars_pat overwrite pl env pat vartyps varpats in - Added_vars (v,pat) + let v = E_aux (E_app (funcl,[cond;tuple_exp vars;body]), (gen_loc el, bannot)) in + Added_vars (v, tuple_pat (if overwrite then varpats else pat :: varpats)) | E_if (c,e1,e2) -> - let vars = List.map (fun (var,(l,t)) -> E_aux (E_id var,(l,t))) - (dedup eqidtyp (find_updated_vars e1 @ find_updated_vars e2)) in + let vars, varpats = + IdSet.union (find_updated_vars e1) (find_updated_vars e2) + |> mk_var_exps_pats pl env in if vars = [] then (Same_vars (E_aux (E_if (c,rewrite_var_updates e1,rewrite_var_updates e2),annot))) else - let varstuple, varpats, vartyps = mk_varstup el env vars in - let varstyp = typ_of varstuple in - let e1 = rewrite_var_updates (add_vars overwrite e1 varstuple) in - let e2 = rewrite_var_updates (add_vars overwrite e2 varstuple) in + let e1 = rewrite_var_updates (add_vars overwrite e1 vars) in + let e2 = rewrite_var_updates (add_vars overwrite e2 vars) in (* after rewrite_defs_letbind_effects c has no variable updates *) let env = env_of_annot annot in let typ = typ_of e1 in let eff = union_eff_exps [e1;e2] in let v = E_aux (E_if (c,e1,e2), (gen_loc el, Some (env, typ, eff))) in - let pat = add_vars_pat overwrite pl env pat vartyps varpats in - Added_vars (v,pat) + Added_vars (v, tuple_pat (if overwrite then varpats else pat :: varpats)) | E_case (e1,ps) -> (* after rewrite_defs_letbind_effects e1 needs no rewriting *) - let vars = - let f acc (Pat_aux ((Pat_exp (_,e)|Pat_when (_,_,e)),_)) = - acc @ find_updated_vars e in - List.map (fun (var,(l,t)) -> E_aux (E_id var,(l,t))) - (dedup eqidtyp (List.fold_left f [] ps)) in + let vars, varpats = + ps + |> List.map (fun (Pat_aux ((Pat_exp (_,e)|Pat_when (_,_,e)),_)) -> e) + |> List.map find_updated_vars + |> List.fold_left IdSet.union IdSet.empty + |> mk_var_exps_pats pl env in if vars = [] then let ps = List.map (function | Pat_aux (Pat_exp (p,e),a) -> @@ -2734,11 +2703,9 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = Pat_aux (Pat_when (p,g,rewrite_var_updates e),a)) ps in Same_vars (E_aux (E_case (e1,ps),annot)) else - let varstuple, varpats, vartyps = mk_varstup el env vars in - let varstyp = typ_of varstuple in let rewrite_pexp (Pat_aux (pexp, (l, _))) = match pexp with | Pat_exp (pat, exp) -> - let exp = rewrite_var_updates (add_vars overwrite exp varstuple) in + let exp = rewrite_var_updates (add_vars overwrite exp vars) in let pannot = (l, Some (env_of exp, typ_of exp, effect_of exp)) in Pat_aux (Pat_exp (pat, exp), pannot) | Pat_when _ -> @@ -2748,8 +2715,7 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = | Pat_aux ((Pat_exp (_,first)|Pat_when (_,_,first)),_) :: _ -> typ_of first | _ -> unit_typ in let v = propagate_exp_effect (annot_exp (E_case (e1, List.map rewrite_pexp ps)) pl env typ) in - let pat = add_vars_pat overwrite pl env pat vartyps varpats in - Added_vars (v,pat) + Added_vars (v, tuple_pat (if overwrite then varpats else pat :: varpats)) | E_assign (lexp,vexp) -> let mk_id_pat id = match Env.lookup_id id env with | Local (_, typ) -> |
