summaryrefslogtreecommitdiff
path: root/src/rewrites.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/rewrites.ml')
-rw-r--r--src/rewrites.ml140
1 files changed, 101 insertions, 39 deletions
diff --git a/src/rewrites.ml b/src/rewrites.ml
index 3f6f95f4..b7ebd073 100644
--- a/src/rewrites.ml
+++ b/src/rewrites.ml
@@ -127,7 +127,7 @@ let rec lexp_is_local (LEXP_aux (lexp, _)) env = match lexp with
| LEXP_memory _ | LEXP_deref _ -> false
| LEXP_id id
| LEXP_cast (_, id) -> id_is_local_var id env
- | LEXP_tup lexps -> List.for_all (fun lexp -> lexp_is_local lexp env) lexps
+ | LEXP_tup lexps | LEXP_vector_concat lexps -> List.for_all (fun lexp -> lexp_is_local lexp env) lexps
| LEXP_vector (lexp,_)
| LEXP_vector_range (lexp,_,_)
| LEXP_field (lexp,_) -> lexp_is_local lexp env
@@ -136,7 +136,7 @@ let rec lexp_is_local_intro (LEXP_aux (lexp, _)) env = match lexp with
| LEXP_memory _ | LEXP_deref _ -> false
| LEXP_id id
| LEXP_cast (_, id) -> id_is_unbound id env
- | LEXP_tup lexps -> List.for_all (fun lexp -> lexp_is_local_intro lexp env) lexps
+ | LEXP_tup lexps | LEXP_vector_concat lexps -> List.for_all (fun lexp -> lexp_is_local_intro lexp env) lexps
| LEXP_vector (lexp,_)
| LEXP_vector_range (lexp,_,_)
| LEXP_field (lexp,_) -> lexp_is_local_intro lexp env
@@ -190,16 +190,18 @@ let lookup_equal_kids env =
List.fold_left add_nc KBindings.empty (Env.get_constraints env)
let lookup_constant_kid env kid =
- try
- let kids = KBindings.find kid (lookup_equal_kids env) in
- let check_nc const nc = match const, nc with
- | None, NC_aux (NC_equal (Nexp_aux (Nexp_var kid, _), Nexp_aux (Nexp_constant i, _)), _)
- when KidSet.mem kid kids ->
- Some i
- | _, _ -> const
- in
- List.fold_left check_nc None (Env.get_constraints env)
- with Not_found -> None
+ let kids =
+ match KBindings.find kid (lookup_equal_kids env) with
+ | kids -> kids
+ | exception Not_found -> KidSet.singleton kid
+ in
+ let check_nc const nc = match const, nc with
+ | None, NC_aux (NC_equal (Nexp_aux (Nexp_var kid, _), Nexp_aux (Nexp_constant i, _)), _)
+ when KidSet.mem kid kids ->
+ Some i
+ | _, _ -> const
+ in
+ List.fold_left check_nc None (Env.get_constraints env)
let rec rewrite_nexp_ids env (Nexp_aux (nexp, l) as nexp_aux) = match nexp with
| Nexp_id id -> rewrite_nexp_ids env (Env.get_num_def id env)
@@ -241,8 +243,19 @@ let rewrite_defs_nexp_ids, rewrite_typ_nexp_ids =
| (l, None) -> (l, None)
in
- rewrite_defs_base {
- rewriters_base with rewrite_exp = (fun _ -> map_exp_annot rewrite_annot)
+ let rewrite_def rewriters = function
+ | DEF_spec (VS_aux (VS_val_spec (typschm, id, exts, b), (l, Some (env, typ, eff)))) ->
+ let typschm = match typschm with
+ | TypSchm_aux (TypSchm_ts (tq, typ), l) ->
+ TypSchm_aux (TypSchm_ts (tq, rewrite_typ env typ), l)
+ in
+ let a = rewrite_annot (l, Some (env, typ, eff)) in
+ DEF_spec (VS_aux (VS_val_spec (typschm, id, exts, b), a))
+ | d -> Rewriter.rewrite_def rewriters d
+ in
+
+ rewrite_defs_base { rewriters_base with
+ rewrite_exp = (fun _ -> map_exp_annot rewrite_annot); rewrite_def = rewrite_def
},
rewrite_typ
@@ -281,7 +294,7 @@ let rewrite_trivial_sizeof, rewrite_trivial_sizeof_exp =
scope. *)
| Some size when prove env (nc_eq (nsum size (nint 1)) nexp) ->
let one_exp = infer_exp env (mk_lit_exp (L_num (Big_int.of_int 1))) in
- Some (E_aux (E_app (mk_id "add_range", [var; one_exp]), (gen_loc l, Some (env, atom_typ (nsum size (nint 1)), no_effect))))
+ Some (E_aux (E_app (mk_id "add_atom", [var; one_exp]), (gen_loc l, Some (env, atom_typ (nsum size (nint 1)), no_effect))))
| _ ->
begin
match destruct_vector env typ with
@@ -293,12 +306,12 @@ let rewrite_trivial_sizeof, rewrite_trivial_sizeof_exp =
let rec split_nexp (Nexp_aux (nexp_aux, l) as nexp) =
match nexp_aux with
| Nexp_sum (n1, n2) ->
- mk_exp (E_app (mk_id "add_range", [split_nexp n1; split_nexp n2]))
+ mk_exp (E_app (mk_id "add_atom", [split_nexp n1; split_nexp n2]))
| Nexp_minus (n1, n2) ->
- mk_exp (E_app (mk_id "sub_range", [split_nexp n1; split_nexp n2]))
+ mk_exp (E_app (mk_id "sub_atom", [split_nexp n1; split_nexp n2]))
| Nexp_times (n1, n2) ->
- mk_exp (E_app (mk_id "mult_range", [split_nexp n1; split_nexp n2]))
- | Nexp_neg nexp -> mk_exp (E_app (mk_id "negate_range", [split_nexp nexp]))
+ mk_exp (E_app (mk_id "mult_atom", [split_nexp n1; split_nexp n2]))
+ | Nexp_neg nexp -> mk_exp (E_app (mk_id "negate_atom", [split_nexp nexp]))
| _ -> mk_exp (E_sizeof nexp)
in
let rec rewrite_e_aux split_sizeof (E_aux (e_aux, (l, _)) as orig_exp) =
@@ -487,6 +500,7 @@ let rewrite_sizeof (Defs defs) =
; lEXP_tup = (fun tups -> let (tups,tups') = List.split tups in (LEXP_tup tups, LEXP_tup tups'))
; lEXP_vector = (fun ((lexp,lexp'),(e2,e2')) -> (LEXP_vector (lexp,e2), LEXP_vector (lexp',e2')))
; lEXP_vector_range = (fun ((lexp,lexp'),(e2,e2'),(e3,e3')) -> (LEXP_vector_range (lexp,e2,e3), LEXP_vector_range (lexp',e2',e3')))
+ ; lEXP_vector_concat = (fun lexps -> let (lexps,lexps') = List.split lexps in (LEXP_vector_concat lexps, LEXP_vector_concat lexps'))
; lEXP_field = (fun ((lexp,lexp'),id) -> (LEXP_field (lexp,id), LEXP_field (lexp',id)))
; lEXP_aux = (fun ((lexp,lexp'),annot) -> (LEXP_aux (lexp,annot), LEXP_aux (lexp',annot)))
; fE_Fexp = (fun (id,(e,e')) -> (FE_Fexp (id,e), FE_Fexp (id,e')))
@@ -1909,6 +1923,16 @@ let is_funcl_rec (FCL_aux (FCL_Funcl (id, pexp), _)) =
E_app_infix (e1, f, e2))) }
exp)
+
+let pat_var (P_aux (paux, a)) =
+ let env = env_of_annot a in
+ let is_var id =
+ not (Env.is_union_constructor id env) &&
+ match Env.lookup_id id env with Enum _ -> false | _ -> true
+ in match paux with
+ | (P_as (_, id) | P_id id) when is_var id -> Some id
+ | _ -> None
+
(* Split out function clauses for individual union constructor patterns
(e.g. AST nodes) into auxiliary functions. Used for the execute function. *)
let rewrite_split_fun_constr_pats fun_name (Defs defs) =
@@ -1933,10 +1957,10 @@ let rewrite_split_fun_constr_pats fun_name (Defs defs) =
Bindings.add aux_fun_id (aux_clauses @ [aux_funcl]) aux_funs
with Not_found ->
let argpats, argexps = List.split (List.mapi
- (fun idx (P_aux (paux, a)) ->
- let id = match paux with
- | P_as (_, id) | P_id id -> id
- | _ -> mk_id ("arg" ^ string_of_int idx)
+ (fun idx (P_aux (_,a) as pat) ->
+ let id = match pat_var pat with
+ | Some id -> id
+ | None -> mk_id ("arg" ^ string_of_int idx)
in
P_aux (P_id id, a), E_aux (E_id id, a))
args)
@@ -2290,11 +2314,11 @@ let rewrite_simple_types (Defs defs) =
let defs = Defs (List.map simple_def defs) in
rewrite_defs_base simple_defs defs
-let rewrite_tuple_vector_assignments defs =
+let rewrite_vector_concat_assignments defs =
let assign_tuple e_aux annot =
let env = env_of_annot annot in
match e_aux with
- | E_assign (LEXP_aux (LEXP_tup lexps, lannot), exp) ->
+ | E_assign (LEXP_aux (LEXP_vector_concat lexps, lannot), exp) ->
let typ = Env.base_typ_of env (typ_of exp) in
if is_vector_typ typ then
(* let _ = Pretty_print_common.print stderr (Pretty_print_sail.doc_exp (E_aux (e_aux, annot))) in *)
@@ -2527,8 +2551,11 @@ let rewrite_defs_letbind_effects =
| LEXP_vector_range (lexp,e1,e2) ->
n_lexp lexp (fun lexp ->
n_exp_name e1 (fun e1 ->
- n_exp_name e2 (fun e2 ->
+ n_exp_name e2 (fun e2 ->
k (fix_eff_lexp (LEXP_aux (LEXP_vector_range (lexp,e1,e2),annot))))))
+ | LEXP_vector_concat es ->
+ n_lexpL es (fun es ->
+ k (fix_eff_lexp (LEXP_aux (LEXP_vector_concat es,annot))))
| LEXP_field (lexp,id) ->
n_lexp lexp (fun lexp ->
k (fix_eff_lexp (LEXP_aux (LEXP_field (lexp,id),annot))))
@@ -2563,6 +2590,14 @@ let rewrite_defs_letbind_effects =
| E_cast (typ,exp') ->
n_exp_name exp' (fun exp' ->
k (rewrap (E_cast (typ,exp'))))
+ | E_app (op_bool, [l; r])
+ when string_of_id op_bool = "and_bool" || string_of_id op_bool = "or_bool" ->
+ (* Leave effectful operands of Boolean "and"/"or" in place to allow
+ short-circuiting. *)
+ let newreturn = effectful l || effectful r in
+ let l = n_exp_term newreturn l in
+ let r = n_exp_term newreturn r in
+ k (rewrap (E_app (op_bool, [l; r])))
| E_app (id,exps) ->
n_exp_nameL exps (fun exps ->
k (rewrap (E_app (id,exps))))
@@ -3203,20 +3238,21 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
|> mk_var_exps_pats pl env
in
let exp4 = rewrite_var_updates (add_vars overwrite exp4 vars) in
- let ord_exp, kids, constr, lower, upper =
- match destruct_range env (typ_of exp1), destruct_range env (typ_of exp2) with
+ let ord_exp, kids, constr, lower, upper, lower_exp, upper_exp =
+ match destruct_numeric env (typ_of exp1), destruct_numeric env (typ_of exp2) with
| None, _ | _, None ->
raise (Reporting_basic.err_unreachable el "Could not determine loop bounds")
- | Some (kids1, constr1, l1, u1), Some (kids2, constr2, l2, u2) ->
+ | Some (kids1, constr1, n1), Some (kids2, constr2, n2) ->
let kids = kids1 @ kids2 in
let constr = nc_and constr1 constr2 in
- let ord_exp, lower, upper =
+ let ord_exp, lower, upper, lower_exp, upper_exp =
if is_order_inc order
- then (annot_exp (E_lit (mk_lit L_true)) el env bool_typ, l1, u2)
- else (annot_exp (E_lit (mk_lit L_false)) el env bool_typ, l2, u1)
+ then (annot_exp (E_lit (mk_lit L_true)) el env bool_typ, n1, n2, exp1, exp2)
+ else (annot_exp (E_lit (mk_lit L_false)) el env bool_typ, n2, n1, exp2, exp1)
in
- ord_exp, kids, constr, lower, upper
+ ord_exp, kids, constr, lower, upper, lower_exp, upper_exp
in
+ (* Bind the loop variable in the body, annotated with constraints *)
let lvar_kid = mk_kid ("loop_" ^ string_of_id id) in
let lvar_nc = nc_and constr (nc_and (nc_lteq lower (nvar lvar_kid)) (nc_lteq (nvar lvar_kid) upper)) in
let lvar_typ = mk_typ (Typ_exist (lvar_kid :: kids, lvar_nc, atom_typ (nvar lvar_kid))) in
@@ -3225,7 +3261,33 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
TP_aux (TP_var lvar_kid, gen_loc el))) el env lvar_typ)) in
let lb = fix_eff_lb (annot_letbind (lvar_pat, exp1) el env lvar_typ) in
let body = fix_eff_exp (annot_exp (E_let (lb, exp4)) el env (typ_of exp4)) in
- let v = fix_eff_exp (annot_exp (E_app (mk_id "foreach", [exp1; exp2; exp3; ord_exp; tuple_exp vars; body])) el env (typ_of body)) in
+ (* If lower > upper, the loop body never gets executed, and the type
+ checker might not be able to prove that the initial value exp1
+ satisfies the constraints on the loop variable.
+
+ Make this explicit by guarding the loop body with lower <= upper.
+ (for type-checking; the guard is later removed again by the Lem
+ pretty-printer). This could be implemented with an assertion, but
+ that would force the loop to be effectful, so we use an if-expression
+ instead. This code assumes that the loop bounds have (possibly
+ existential) atom types, and the loop body has type unit. *)
+ let lower_kid = mk_kid ("loop_" ^ string_of_id id ^ "_lower") in
+ let lower_pat = P_var (annot_pat P_wild el env (typ_of lower_exp), mk_typ_pat (TP_app (mk_id "atom", [mk_typ_pat (TP_var lower_kid)]))) in
+ let lb_lower = annot_letbind (lower_pat, lower_exp) el env (typ_of lower_exp) in
+ let upper_kid = mk_kid ("loop_" ^ string_of_id id ^ "_upper") in
+ let upper_pat = P_var (annot_pat P_wild el env (typ_of upper_exp), mk_typ_pat (TP_app (mk_id "atom", [mk_typ_pat (TP_var upper_kid)]))) in
+ let lb_upper = annot_letbind (upper_pat, upper_exp) el env (typ_of upper_exp) in
+ let guard = annot_exp (E_constraint (nc_lteq (nvar lower_kid) (nvar upper_kid))) el env bool_typ in
+ let unit_exp = annot_exp (E_lit (mk_lit L_unit)) el env unit_typ in
+ let skip_val = tuple_exp (if overwrite then vars else unit_exp :: vars) in
+ let guarded_body =
+ fix_eff_exp (annot_exp (E_let (lb_lower,
+ fix_eff_exp (annot_exp (E_let (lb_upper,
+ fix_eff_exp (annot_exp (E_if (guard, body, skip_val))
+ el env (typ_of exp4))))
+ el env (typ_of exp4))))
+ el env (typ_of exp4)) in
+ let v = fix_eff_exp (annot_exp (E_app (mk_id "foreach", [exp1; exp2; exp3; ord_exp; tuple_exp vars; guarded_body])) el env (typ_of body)) in
Added_vars (v, tuple_pat (if overwrite then varpats else pat :: varpats))
| E_loop(loop,cond,body) ->
let vars, varpats =
@@ -3254,7 +3316,7 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
(* after rewrite_defs_letbind_effects c has no variable updates *)
let env = env_of_annot annot in
let typ = typ_of e1 in
- let eff = union_eff_exps [e1;e2] in
+ let eff = union_eff_exps [c;e1;e2] in
let v = E_aux (E_if (c,e1,e2), (gen_loc el, Some (env, typ, eff))) in
Added_vars (v, tuple_pat (if overwrite then varpats else pat :: varpats))
| E_case (e1,ps) ->
@@ -3712,7 +3774,7 @@ let recheck_defs defs = fst (check initial_env defs)
let rewrite_defs_lem = [
("realise_mappings", rewrite_defs_realise_mappings);
- ("tuple_vector_assignments", rewrite_tuple_vector_assignments);
+ ("vector_concat_assignments", rewrite_vector_concat_assignments);
("tuple_assignments", rewrite_tuple_assignments);
("simple_assignments", rewrite_simple_assignments);
("remove_vector_concat", rewrite_defs_remove_vector_concat);
@@ -3752,7 +3814,7 @@ let rewrite_defs_ocaml = [
("pat_string_append", rewrite_defs_pat_string_append);
("mapping_builtins", rewrite_defs_mapping_builtins);
("pat_lits", rewrite_defs_pat_lits);
- ("tuple_vector_assignments", rewrite_tuple_vector_assignments);
+ ("vector_concat_assignments", rewrite_vector_concat_assignments);
("tuple_assignments", rewrite_tuple_assignments);
("simple_assignments", rewrite_simple_assignments);
("remove_vector_concat", rewrite_defs_remove_vector_concat);
@@ -3774,7 +3836,7 @@ let rewrite_defs_c = [
("pat_string_append", rewrite_defs_pat_string_append);
("mapping_builtins", rewrite_defs_mapping_builtins);
("pat_lits", rewrite_defs_pat_lits);
- ("tuple_vector_assignments", rewrite_tuple_vector_assignments);
+ ("vector_concat_assignments", rewrite_vector_concat_assignments);
("tuple_assignments", rewrite_tuple_assignments);
("simple_assignments", rewrite_simple_assignments);
("remove_vector_concat", rewrite_defs_remove_vector_concat);
@@ -3793,7 +3855,7 @@ let rewrite_defs_interpreter = [
("realise_mappings", rewrite_defs_realise_mappings);
("pat_string_append", rewrite_defs_pat_string_append);
("mapping_builtins", rewrite_defs_mapping_builtins);
- ("tuple_vector_assignments", rewrite_tuple_vector_assignments);
+ ("vector_concat_assignments", rewrite_vector_concat_assignments);
("tuple_assignments", rewrite_tuple_assignments);
("simple_assignments", rewrite_simple_assignments);
("remove_vector_concat", rewrite_defs_remove_vector_concat);