diff options
Diffstat (limited to 'src/rewrites.ml')
| -rw-r--r-- | src/rewrites.ml | 140 |
1 files changed, 101 insertions, 39 deletions
diff --git a/src/rewrites.ml b/src/rewrites.ml index 3f6f95f4..b7ebd073 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -127,7 +127,7 @@ let rec lexp_is_local (LEXP_aux (lexp, _)) env = match lexp with | LEXP_memory _ | LEXP_deref _ -> false | LEXP_id id | LEXP_cast (_, id) -> id_is_local_var id env - | LEXP_tup lexps -> List.for_all (fun lexp -> lexp_is_local lexp env) lexps + | LEXP_tup lexps | LEXP_vector_concat lexps -> List.for_all (fun lexp -> lexp_is_local lexp env) lexps | LEXP_vector (lexp,_) | LEXP_vector_range (lexp,_,_) | LEXP_field (lexp,_) -> lexp_is_local lexp env @@ -136,7 +136,7 @@ let rec lexp_is_local_intro (LEXP_aux (lexp, _)) env = match lexp with | LEXP_memory _ | LEXP_deref _ -> false | LEXP_id id | LEXP_cast (_, id) -> id_is_unbound id env - | LEXP_tup lexps -> List.for_all (fun lexp -> lexp_is_local_intro lexp env) lexps + | LEXP_tup lexps | LEXP_vector_concat lexps -> List.for_all (fun lexp -> lexp_is_local_intro lexp env) lexps | LEXP_vector (lexp,_) | LEXP_vector_range (lexp,_,_) | LEXP_field (lexp,_) -> lexp_is_local_intro lexp env @@ -190,16 +190,18 @@ let lookup_equal_kids env = List.fold_left add_nc KBindings.empty (Env.get_constraints env) let lookup_constant_kid env kid = - try - let kids = KBindings.find kid (lookup_equal_kids env) in - let check_nc const nc = match const, nc with - | None, NC_aux (NC_equal (Nexp_aux (Nexp_var kid, _), Nexp_aux (Nexp_constant i, _)), _) - when KidSet.mem kid kids -> - Some i - | _, _ -> const - in - List.fold_left check_nc None (Env.get_constraints env) - with Not_found -> None + let kids = + match KBindings.find kid (lookup_equal_kids env) with + | kids -> kids + | exception Not_found -> KidSet.singleton kid + in + let check_nc const nc = match const, nc with + | None, NC_aux (NC_equal (Nexp_aux (Nexp_var kid, _), Nexp_aux (Nexp_constant i, _)), _) + when KidSet.mem kid kids -> + Some i + | _, _ -> const + in + List.fold_left check_nc None (Env.get_constraints env) let rec rewrite_nexp_ids env (Nexp_aux (nexp, l) as nexp_aux) = match nexp with | Nexp_id id -> rewrite_nexp_ids env (Env.get_num_def id env) @@ -241,8 +243,19 @@ let rewrite_defs_nexp_ids, rewrite_typ_nexp_ids = | (l, None) -> (l, None) in - rewrite_defs_base { - rewriters_base with rewrite_exp = (fun _ -> map_exp_annot rewrite_annot) + let rewrite_def rewriters = function + | DEF_spec (VS_aux (VS_val_spec (typschm, id, exts, b), (l, Some (env, typ, eff)))) -> + let typschm = match typschm with + | TypSchm_aux (TypSchm_ts (tq, typ), l) -> + TypSchm_aux (TypSchm_ts (tq, rewrite_typ env typ), l) + in + let a = rewrite_annot (l, Some (env, typ, eff)) in + DEF_spec (VS_aux (VS_val_spec (typschm, id, exts, b), a)) + | d -> Rewriter.rewrite_def rewriters d + in + + rewrite_defs_base { rewriters_base with + rewrite_exp = (fun _ -> map_exp_annot rewrite_annot); rewrite_def = rewrite_def }, rewrite_typ @@ -281,7 +294,7 @@ let rewrite_trivial_sizeof, rewrite_trivial_sizeof_exp = scope. *) | Some size when prove env (nc_eq (nsum size (nint 1)) nexp) -> let one_exp = infer_exp env (mk_lit_exp (L_num (Big_int.of_int 1))) in - Some (E_aux (E_app (mk_id "add_range", [var; one_exp]), (gen_loc l, Some (env, atom_typ (nsum size (nint 1)), no_effect)))) + Some (E_aux (E_app (mk_id "add_atom", [var; one_exp]), (gen_loc l, Some (env, atom_typ (nsum size (nint 1)), no_effect)))) | _ -> begin match destruct_vector env typ with @@ -293,12 +306,12 @@ let rewrite_trivial_sizeof, rewrite_trivial_sizeof_exp = let rec split_nexp (Nexp_aux (nexp_aux, l) as nexp) = match nexp_aux with | Nexp_sum (n1, n2) -> - mk_exp (E_app (mk_id "add_range", [split_nexp n1; split_nexp n2])) + mk_exp (E_app (mk_id "add_atom", [split_nexp n1; split_nexp n2])) | Nexp_minus (n1, n2) -> - mk_exp (E_app (mk_id "sub_range", [split_nexp n1; split_nexp n2])) + mk_exp (E_app (mk_id "sub_atom", [split_nexp n1; split_nexp n2])) | Nexp_times (n1, n2) -> - mk_exp (E_app (mk_id "mult_range", [split_nexp n1; split_nexp n2])) - | Nexp_neg nexp -> mk_exp (E_app (mk_id "negate_range", [split_nexp nexp])) + mk_exp (E_app (mk_id "mult_atom", [split_nexp n1; split_nexp n2])) + | Nexp_neg nexp -> mk_exp (E_app (mk_id "negate_atom", [split_nexp nexp])) | _ -> mk_exp (E_sizeof nexp) in let rec rewrite_e_aux split_sizeof (E_aux (e_aux, (l, _)) as orig_exp) = @@ -487,6 +500,7 @@ let rewrite_sizeof (Defs defs) = ; lEXP_tup = (fun tups -> let (tups,tups') = List.split tups in (LEXP_tup tups, LEXP_tup tups')) ; lEXP_vector = (fun ((lexp,lexp'),(e2,e2')) -> (LEXP_vector (lexp,e2), LEXP_vector (lexp',e2'))) ; lEXP_vector_range = (fun ((lexp,lexp'),(e2,e2'),(e3,e3')) -> (LEXP_vector_range (lexp,e2,e3), LEXP_vector_range (lexp',e2',e3'))) + ; lEXP_vector_concat = (fun lexps -> let (lexps,lexps') = List.split lexps in (LEXP_vector_concat lexps, LEXP_vector_concat lexps')) ; lEXP_field = (fun ((lexp,lexp'),id) -> (LEXP_field (lexp,id), LEXP_field (lexp',id))) ; lEXP_aux = (fun ((lexp,lexp'),annot) -> (LEXP_aux (lexp,annot), LEXP_aux (lexp',annot))) ; fE_Fexp = (fun (id,(e,e')) -> (FE_Fexp (id,e), FE_Fexp (id,e'))) @@ -1909,6 +1923,16 @@ let is_funcl_rec (FCL_aux (FCL_Funcl (id, pexp), _)) = E_app_infix (e1, f, e2))) } exp) + +let pat_var (P_aux (paux, a)) = + let env = env_of_annot a in + let is_var id = + not (Env.is_union_constructor id env) && + match Env.lookup_id id env with Enum _ -> false | _ -> true + in match paux with + | (P_as (_, id) | P_id id) when is_var id -> Some id + | _ -> None + (* Split out function clauses for individual union constructor patterns (e.g. AST nodes) into auxiliary functions. Used for the execute function. *) let rewrite_split_fun_constr_pats fun_name (Defs defs) = @@ -1933,10 +1957,10 @@ let rewrite_split_fun_constr_pats fun_name (Defs defs) = Bindings.add aux_fun_id (aux_clauses @ [aux_funcl]) aux_funs with Not_found -> let argpats, argexps = List.split (List.mapi - (fun idx (P_aux (paux, a)) -> - let id = match paux with - | P_as (_, id) | P_id id -> id - | _ -> mk_id ("arg" ^ string_of_int idx) + (fun idx (P_aux (_,a) as pat) -> + let id = match pat_var pat with + | Some id -> id + | None -> mk_id ("arg" ^ string_of_int idx) in P_aux (P_id id, a), E_aux (E_id id, a)) args) @@ -2290,11 +2314,11 @@ let rewrite_simple_types (Defs defs) = let defs = Defs (List.map simple_def defs) in rewrite_defs_base simple_defs defs -let rewrite_tuple_vector_assignments defs = +let rewrite_vector_concat_assignments defs = let assign_tuple e_aux annot = let env = env_of_annot annot in match e_aux with - | E_assign (LEXP_aux (LEXP_tup lexps, lannot), exp) -> + | E_assign (LEXP_aux (LEXP_vector_concat lexps, lannot), exp) -> let typ = Env.base_typ_of env (typ_of exp) in if is_vector_typ typ then (* let _ = Pretty_print_common.print stderr (Pretty_print_sail.doc_exp (E_aux (e_aux, annot))) in *) @@ -2527,8 +2551,11 @@ let rewrite_defs_letbind_effects = | LEXP_vector_range (lexp,e1,e2) -> n_lexp lexp (fun lexp -> n_exp_name e1 (fun e1 -> - n_exp_name e2 (fun e2 -> + n_exp_name e2 (fun e2 -> k (fix_eff_lexp (LEXP_aux (LEXP_vector_range (lexp,e1,e2),annot)))))) + | LEXP_vector_concat es -> + n_lexpL es (fun es -> + k (fix_eff_lexp (LEXP_aux (LEXP_vector_concat es,annot)))) | LEXP_field (lexp,id) -> n_lexp lexp (fun lexp -> k (fix_eff_lexp (LEXP_aux (LEXP_field (lexp,id),annot)))) @@ -2563,6 +2590,14 @@ let rewrite_defs_letbind_effects = | E_cast (typ,exp') -> n_exp_name exp' (fun exp' -> k (rewrap (E_cast (typ,exp')))) + | E_app (op_bool, [l; r]) + when string_of_id op_bool = "and_bool" || string_of_id op_bool = "or_bool" -> + (* Leave effectful operands of Boolean "and"/"or" in place to allow + short-circuiting. *) + let newreturn = effectful l || effectful r in + let l = n_exp_term newreturn l in + let r = n_exp_term newreturn r in + k (rewrap (E_app (op_bool, [l; r]))) | E_app (id,exps) -> n_exp_nameL exps (fun exps -> k (rewrap (E_app (id,exps)))) @@ -3203,20 +3238,21 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = |> mk_var_exps_pats pl env in let exp4 = rewrite_var_updates (add_vars overwrite exp4 vars) in - let ord_exp, kids, constr, lower, upper = - match destruct_range env (typ_of exp1), destruct_range env (typ_of exp2) with + let ord_exp, kids, constr, lower, upper, lower_exp, upper_exp = + match destruct_numeric env (typ_of exp1), destruct_numeric env (typ_of exp2) with | None, _ | _, None -> raise (Reporting_basic.err_unreachable el "Could not determine loop bounds") - | Some (kids1, constr1, l1, u1), Some (kids2, constr2, l2, u2) -> + | Some (kids1, constr1, n1), Some (kids2, constr2, n2) -> let kids = kids1 @ kids2 in let constr = nc_and constr1 constr2 in - let ord_exp, lower, upper = + let ord_exp, lower, upper, lower_exp, upper_exp = if is_order_inc order - then (annot_exp (E_lit (mk_lit L_true)) el env bool_typ, l1, u2) - else (annot_exp (E_lit (mk_lit L_false)) el env bool_typ, l2, u1) + then (annot_exp (E_lit (mk_lit L_true)) el env bool_typ, n1, n2, exp1, exp2) + else (annot_exp (E_lit (mk_lit L_false)) el env bool_typ, n2, n1, exp2, exp1) in - ord_exp, kids, constr, lower, upper + ord_exp, kids, constr, lower, upper, lower_exp, upper_exp in + (* Bind the loop variable in the body, annotated with constraints *) let lvar_kid = mk_kid ("loop_" ^ string_of_id id) in let lvar_nc = nc_and constr (nc_and (nc_lteq lower (nvar lvar_kid)) (nc_lteq (nvar lvar_kid) upper)) in let lvar_typ = mk_typ (Typ_exist (lvar_kid :: kids, lvar_nc, atom_typ (nvar lvar_kid))) in @@ -3225,7 +3261,33 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = TP_aux (TP_var lvar_kid, gen_loc el))) el env lvar_typ)) in let lb = fix_eff_lb (annot_letbind (lvar_pat, exp1) el env lvar_typ) in let body = fix_eff_exp (annot_exp (E_let (lb, exp4)) el env (typ_of exp4)) in - let v = fix_eff_exp (annot_exp (E_app (mk_id "foreach", [exp1; exp2; exp3; ord_exp; tuple_exp vars; body])) el env (typ_of body)) in + (* If lower > upper, the loop body never gets executed, and the type + checker might not be able to prove that the initial value exp1 + satisfies the constraints on the loop variable. + + Make this explicit by guarding the loop body with lower <= upper. + (for type-checking; the guard is later removed again by the Lem + pretty-printer). This could be implemented with an assertion, but + that would force the loop to be effectful, so we use an if-expression + instead. This code assumes that the loop bounds have (possibly + existential) atom types, and the loop body has type unit. *) + let lower_kid = mk_kid ("loop_" ^ string_of_id id ^ "_lower") in + let lower_pat = P_var (annot_pat P_wild el env (typ_of lower_exp), mk_typ_pat (TP_app (mk_id "atom", [mk_typ_pat (TP_var lower_kid)]))) in + let lb_lower = annot_letbind (lower_pat, lower_exp) el env (typ_of lower_exp) in + let upper_kid = mk_kid ("loop_" ^ string_of_id id ^ "_upper") in + let upper_pat = P_var (annot_pat P_wild el env (typ_of upper_exp), mk_typ_pat (TP_app (mk_id "atom", [mk_typ_pat (TP_var upper_kid)]))) in + let lb_upper = annot_letbind (upper_pat, upper_exp) el env (typ_of upper_exp) in + let guard = annot_exp (E_constraint (nc_lteq (nvar lower_kid) (nvar upper_kid))) el env bool_typ in + let unit_exp = annot_exp (E_lit (mk_lit L_unit)) el env unit_typ in + let skip_val = tuple_exp (if overwrite then vars else unit_exp :: vars) in + let guarded_body = + fix_eff_exp (annot_exp (E_let (lb_lower, + fix_eff_exp (annot_exp (E_let (lb_upper, + fix_eff_exp (annot_exp (E_if (guard, body, skip_val)) + el env (typ_of exp4)))) + el env (typ_of exp4)))) + el env (typ_of exp4)) in + let v = fix_eff_exp (annot_exp (E_app (mk_id "foreach", [exp1; exp2; exp3; ord_exp; tuple_exp vars; guarded_body])) el env (typ_of body)) in Added_vars (v, tuple_pat (if overwrite then varpats else pat :: varpats)) | E_loop(loop,cond,body) -> let vars, varpats = @@ -3254,7 +3316,7 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = (* after rewrite_defs_letbind_effects c has no variable updates *) let env = env_of_annot annot in let typ = typ_of e1 in - let eff = union_eff_exps [e1;e2] in + let eff = union_eff_exps [c;e1;e2] in let v = E_aux (E_if (c,e1,e2), (gen_loc el, Some (env, typ, eff))) in Added_vars (v, tuple_pat (if overwrite then varpats else pat :: varpats)) | E_case (e1,ps) -> @@ -3712,7 +3774,7 @@ let recheck_defs defs = fst (check initial_env defs) let rewrite_defs_lem = [ ("realise_mappings", rewrite_defs_realise_mappings); - ("tuple_vector_assignments", rewrite_tuple_vector_assignments); + ("vector_concat_assignments", rewrite_vector_concat_assignments); ("tuple_assignments", rewrite_tuple_assignments); ("simple_assignments", rewrite_simple_assignments); ("remove_vector_concat", rewrite_defs_remove_vector_concat); @@ -3752,7 +3814,7 @@ let rewrite_defs_ocaml = [ ("pat_string_append", rewrite_defs_pat_string_append); ("mapping_builtins", rewrite_defs_mapping_builtins); ("pat_lits", rewrite_defs_pat_lits); - ("tuple_vector_assignments", rewrite_tuple_vector_assignments); + ("vector_concat_assignments", rewrite_vector_concat_assignments); ("tuple_assignments", rewrite_tuple_assignments); ("simple_assignments", rewrite_simple_assignments); ("remove_vector_concat", rewrite_defs_remove_vector_concat); @@ -3774,7 +3836,7 @@ let rewrite_defs_c = [ ("pat_string_append", rewrite_defs_pat_string_append); ("mapping_builtins", rewrite_defs_mapping_builtins); ("pat_lits", rewrite_defs_pat_lits); - ("tuple_vector_assignments", rewrite_tuple_vector_assignments); + ("vector_concat_assignments", rewrite_vector_concat_assignments); ("tuple_assignments", rewrite_tuple_assignments); ("simple_assignments", rewrite_simple_assignments); ("remove_vector_concat", rewrite_defs_remove_vector_concat); @@ -3793,7 +3855,7 @@ let rewrite_defs_interpreter = [ ("realise_mappings", rewrite_defs_realise_mappings); ("pat_string_append", rewrite_defs_pat_string_append); ("mapping_builtins", rewrite_defs_mapping_builtins); - ("tuple_vector_assignments", rewrite_tuple_vector_assignments); + ("vector_concat_assignments", rewrite_vector_concat_assignments); ("tuple_assignments", rewrite_tuple_assignments); ("simple_assignments", rewrite_simple_assignments); ("remove_vector_concat", rewrite_defs_remove_vector_concat); |
