From fb8e8ce65d5ea13b92ec731820ed6c7a9a89f6a0 Mon Sep 17 00:00:00 2001 From: Brian Campbell Date: Mon, 15 Jan 2018 14:57:52 +0000 Subject: Support non-trivial literal patterns Previously we only did top-level literal pattern to guard conversion, this does it throughout any pattern --- src/rewrites.ml | 6 +- src/type_check.ml | 166 +++++++++++++++++++++++++++-------------------------- src/type_check.mli | 6 +- 3 files changed, 94 insertions(+), 84 deletions(-) (limited to 'src') diff --git a/src/rewrites.ml b/src/rewrites.ml index 32ffe54a..6158422e 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -1084,7 +1084,7 @@ let remove_bitvector_pat (P_aux (_, (l, _)) as pat) = ; fP_aux = (fun (fpat,annot) -> FP_aux (fpat,annot)) ; fP_Fpat = (fun (id,p) -> FP_Fpat (id,p false)) } in - let pat, env = bind_pat env + let pat, env = bind_pat_no_guard env (strip_pat ((fold_pat name_bitvector_roots pat) false)) (pat_typ_of pat) in @@ -1588,11 +1588,11 @@ let rewrite_register_ref_writes (Defs defs) = | BF_aux (BF_range (i, j), _) -> (i, j) | _ -> raise (Reporting_basic.err_unreachable l "unsupported field type") in let mk_num_exp i = mk_lit_exp (L_num i) in - let reg_pat, reg_env = bind_pat env (mk_pat (P_typ (rtyp, mk_pat (P_id (mk_id "reg"))))) rtyp in + let reg_pat, reg_env = bind_pat_no_guard env (mk_pat (P_typ (rtyp, mk_pat (P_id (mk_id "reg"))))) rtyp in let inferred_get = infer_exp reg_env (mk_exp (E_vector_subrange (mk_exp (E_id (mk_id "reg")), mk_num_exp i, mk_num_exp j))) in let ftyp = typ_of inferred_get in - let v_pat, v_env = bind_pat reg_env (mk_pat (P_typ (ftyp, mk_pat (P_id (mk_id "v"))))) ftyp in + let v_pat, v_env = bind_pat_no_guard reg_env (mk_pat (P_typ (ftyp, mk_pat (P_id (mk_id "v"))))) ftyp in let inferred_set = infer_exp v_env (mk_exp (E_vector_update_subrange (mk_exp (E_id (mk_id "reg")), mk_num_exp i, mk_num_exp j, mk_exp (E_id (mk_id "v"))))) in let set_args = P_aux (P_tup [reg_pat; v_pat], (l, Some (env, tuple_typ [rtyp; ftyp], no_effect))) in diff --git a/src/type_check.ml b/src/type_check.ml index a4b4bb39..9d235cb4 100644 --- a/src/type_check.ml +++ b/src/type_check.ml @@ -2054,16 +2054,7 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ annot_exp (E_case (inferred_exp, List.map (fun case -> check_case env inferred_typ case typ) cases)) typ | E_try (exp, cases), _ -> let checked_exp = crule check_exp env exp typ in - let check_case pat typ = match pat with - | Pat_aux (Pat_exp (pat, case), (l, _)) -> - let tpat, env = bind_pat env pat exc_typ in - Pat_aux (Pat_exp (tpat, crule check_exp env case typ), (l, None)) - | Pat_aux (Pat_when (pat, guard, case), (l, _)) -> - let tpat, env = bind_pat env pat exc_typ in - let checked_guard = check_exp env guard bool_typ in - Pat_aux (Pat_when (tpat, checked_guard, crule check_exp env case typ), (l, None)) - in - annot_exp (E_try (checked_exp, List.map (fun case -> check_case case typ) cases)) typ + annot_exp (E_try (checked_exp, List.map (fun case -> check_case env exc_typ case typ) cases)) typ | E_cons (x, xs), _ -> begin match is_list (Env.expand_synonyms env typ) with @@ -2118,11 +2109,11 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ | LB_val (P_aux (P_typ (ptyp, _), _) as pat, bind) -> Env.wf_typ env ptyp; let checked_bind = crule check_exp env bind ptyp in - let tpat, env = bind_pat env pat ptyp in + let tpat, env = bind_pat_no_guard env pat ptyp in annot_exp (E_let (LB_aux (LB_val (tpat, checked_bind), (let_loc, None)), crule check_exp env exp typ)) typ | LB_val (pat, bind) -> let inferred_bind = irule infer_exp env bind in - let tpat, env = bind_pat env pat (typ_of inferred_bind) in + let tpat, env = bind_pat_no_guard env pat (typ_of inferred_bind) in annot_exp (E_let (LB_aux (LB_val (tpat, inferred_bind), (let_loc, None)), crule check_exp env exp typ)) typ end | E_app_infix (x, op, y), _ -> @@ -2181,7 +2172,7 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ | _ -> let inferred_bind = irule infer_exp env bind in inferred_bind, typ_of inferred_bind in - let tpat, env = bind_pat env pat ptyp in + let tpat, env = bind_pat_no_guard env pat ptyp in (* Propagate constraint assertions on the lhs of monadic binds to the rhs *) let env = match bind_exp with | E_aux (E_assert (E_aux (E_constraint nc, _), _), _) -> @@ -2219,7 +2210,17 @@ let rec check_exp env (E_aux (exp_aux, (l, ())) as exp : unit exp) (Typ_aux (typ and check_case env pat_typ pexp typ = let pat,guard,case,((l,_) as annot) = destruct_pexp pexp in match bind_pat env pat pat_typ with - | tpat, env -> + | tpat, env, guards -> + let guard = match guard, guards with + | None, h::t -> Some (h,t) + | Some x, l -> Some (x,l) + | None, [] -> None + in + let guard = match guard with + | Some (h,t) -> + Some (List.fold_left (fun acc guard -> mk_exp (E_app_infix (acc, mk_id "&", guard))) h t) + | None -> None + in let checked_guard, env' = match guard with | None -> None, env | Some guard -> @@ -2229,16 +2230,6 @@ and check_case env pat_typ pexp typ = in let checked_case = crule check_exp env' case typ in construct_pexp (tpat, checked_guard, checked_case, (l, None)) - | exception (Type_error _ as typ_exn) -> - match pat with - | P_aux (P_lit lit, _) -> - let guard' = mk_exp (E_app_infix (mk_exp (E_id (mk_id "p#")), mk_id "==", mk_exp (E_lit lit))) in - let guard = match guard with - | None -> guard' - | Some guard -> mk_exp (E_app_infix (guard, mk_id "&", guard')) - in - check_case env pat_typ (Pat_aux (Pat_when (mk_pat (P_id (mk_id "p#")), guard, case), annot)) typ - | _ -> raise typ_exn (* type_coercion env exp typ takes a fully annoted (i.e. already type checked) expression exp, and attempts to cast (coerce) it to the @@ -2304,26 +2295,31 @@ and type_coercion_unify env (E_aux (_, (l, _)) as annotated_exp) typ = try_casts casts end +and bind_pat_no_guard env (P_aux (_,(l,_)) as pat) typ = + match bind_pat env pat typ with + | _, _, _::_ -> typ_error l "Literal patterns not supported here" + | tpat, env, [] -> tpat, env + and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ) = typ_print ("Binding " ^ string_of_pat pat ^ " to " ^ string_of_typ typ); let annot_pat pat typ = P_aux (pat, (l, Some (env, typ, no_effect))) in let switch_typ (P_aux (pat_aux, (l, Some (env, _, eff)))) typ = P_aux (pat_aux, (l, Some (env, typ, eff))) in - let bind_tuple_pat (tpats, env) pat typ = - let tpat, env = bind_pat env pat typ in tpat :: tpats, env + let bind_tuple_pat (tpats, env, guards) pat typ = + let tpat, env, guards' = bind_pat env pat typ in tpat :: tpats, env, guards' @ guards in match pat_aux with | P_id v -> begin match Env.lookup_id v env with - | Local (Immutable, _) | Unbound -> annot_pat (P_id v) typ, Env.add_local v (Immutable, typ) env + | Local (Immutable, _) | Unbound -> annot_pat (P_id v) typ, Env.add_local v (Immutable, typ) env, [] | Local (Mutable, _) | Register _ -> typ_error l ("Cannot shadow mutable local or register in switch statement pattern " ^ string_of_pat pat) - | Enum enum -> subtyp l env enum typ; annot_pat (P_id v) typ, env + | Enum enum -> subtyp l env enum typ; annot_pat (P_id v) typ, env, [] | Union (typq, ctor_typ) -> begin try let _ = unify l env ctor_typ typ in - annot_pat (P_id v) typ, env + annot_pat (P_id v) typ, env, [] with | Unification_error (l, m) -> typ_error l ("Unification error when pattern matching against union constructor: " ^ m) end @@ -2336,34 +2332,34 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ) let env = Env.add_typ_var kid BK_nat env in let ex_typ = typ_subst_nexp kid' (Nexp_var kid) ex_typ in let env = Env.add_constraint (nc_subst_nexp kid' (Nexp_var kid) nc) env in - let typed_pat, env = bind_pat env pat ex_typ in - annot_pat (P_var (typed_pat, kid)) typ, env + let typed_pat, env, guards = bind_pat env pat ex_typ in + annot_pat (P_var (typed_pat, kid)) typ, env, guards | Some _, _ -> typ_error l ("Cannot bind type variable pattern against multiple argument existential") | None, Typ_aux (Typ_id id, _) when Id.compare id (mk_id "int") == 0 -> let env = Env.add_typ_var kid BK_nat env in - let typed_pat, env = bind_pat env pat (atom_typ (nvar kid)) in - annot_pat (P_var (typed_pat, kid)) typ, env + let typed_pat, env, guards = bind_pat env pat (atom_typ (nvar kid)) in + annot_pat (P_var (typed_pat, kid)) typ, env, guards | None, Typ_aux (Typ_id id, _) when Id.compare id (mk_id "nat") == 0 -> let env = Env.add_typ_var kid BK_nat env in let env = Env.add_constraint (nc_gt (nvar kid) (nint 0)) env in - let typed_pat, env = bind_pat env pat (atom_typ (nvar kid)) in - annot_pat (P_var (typed_pat, kid)) typ, env + let typed_pat, env, guards = bind_pat env pat (atom_typ (nvar kid)) in + annot_pat (P_var (typed_pat, kid)) typ, env, guards | None, Typ_aux (Typ_app (id, [Typ_arg_aux (Typ_arg_nexp lo, _); Typ_arg_aux (Typ_arg_nexp hi, _)]), _) when Id.compare id (mk_id "range") == 0 -> let env = Env.add_typ_var kid BK_nat env in let env = Env.add_constraint (nc_and (nc_lteq lo (nvar kid)) (nc_lteq (nvar kid) hi)) env in - let typed_pat, env = bind_pat env pat (atom_typ (nvar kid)) in - annot_pat (P_var (typed_pat, kid)) typ, env + let typed_pat, env, guards = bind_pat env pat (atom_typ (nvar kid)) in + annot_pat (P_var (typed_pat, kid)) typ, env, guards | None, _ -> typ_error l ("Cannot bind type variable against non existential or numeric type") end - | P_wild -> annot_pat P_wild typ, env + | P_wild -> annot_pat P_wild typ, env, [] | P_cons (hd_pat, tl_pat) -> begin match Env.expand_synonyms env typ with | Typ_aux (Typ_app (f, [Typ_arg_aux (Typ_arg_typ ltyp, _)]), _) when Id.compare f (mk_id "list") = 0 -> - let hd_pat, env = bind_pat env hd_pat ltyp in - let tl_pat, env = bind_pat env tl_pat typ in - annot_pat (P_cons (hd_pat, tl_pat)) typ, env + let hd_pat, env, hd_guards = bind_pat env hd_pat ltyp in + let tl_pat, env, tl_guards = bind_pat env tl_pat typ in + annot_pat (P_cons (hd_pat, tl_pat)) typ, env, hd_guards @ tl_guards | _ -> typ_error l "Cannot match cons pattern against non-list type" end | P_list pats -> @@ -2371,32 +2367,32 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ) match Env.expand_synonyms env typ with | Typ_aux (Typ_app (f, [Typ_arg_aux (Typ_arg_typ ltyp, _)]), _) when Id.compare f (mk_id "list") = 0 -> let rec process_pats env = function - | [] -> [], env + | [] -> [], env, [] | (pat :: pats) -> - let pat', env = bind_pat env pat ltyp in - let pats', env = process_pats env pats in - pat' :: pats', env + let pat', env, guards = bind_pat env pat ltyp in + let pats', env, guards' = process_pats env pats in + pat' :: pats', env, guards @ guards' in - let pats, env = process_pats env pats in - annot_pat (P_list pats) typ, env + let pats, env, guards = process_pats env pats in + annot_pat (P_list pats) typ, env, guards | _ -> typ_error l ("Cannot match list pattern " ^ string_of_pat pat ^ " against non-list type " ^ string_of_typ typ) end | P_tup [] -> begin match Env.expand_synonyms env typ with | Typ_aux (Typ_id typ_id, _) when string_of_id typ_id = "unit" -> - annot_pat (P_tup []) typ, env + annot_pat (P_tup []) typ, env, [] | _ -> typ_error l "Cannot match unit pattern against non-unit type" end | P_tup pats -> begin match Env.expand_synonyms env typ with | Typ_aux (Typ_tup typs, _) -> - let tpats, env = - try List.fold_left2 bind_tuple_pat ([], env) pats typs with + let tpats, env, guards = + try List.fold_left2 bind_tuple_pat ([], env, []) pats typs with | Invalid_argument _ -> typ_error l "Tuple pattern and tuple type have different length" in - annot_pat (P_tup (List.rev tpats)) typ, env + annot_pat (P_tup (List.rev tpats)) typ, env, guards | _ -> typ_error l "Cannot bind tuple pattern against non tuple type" end | P_app (f, pats) when Env.is_union_constructor f env -> @@ -2420,11 +2416,11 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ) then typ_error l ("Quantifiers " ^ string_of_list ", " string_of_quant_item quants' ^ " not resolved in pattern " ^ string_of_pat pat) else (); let ret_typ' = subst_unifiers unifiers ret_typ in - let tpats, env = - try List.fold_left2 bind_tuple_pat ([], env) pats (untuple arg_typ') with + let tpats, env, guards = + try List.fold_left2 bind_tuple_pat ([], env, []) pats (untuple arg_typ') with | Invalid_argument _ -> typ_error l "Union constructor pattern arguments have incorrect length" in - annot_pat (P_app (f, List.rev tpats)) typ, env + annot_pat (P_app (f, List.rev tpats)) typ, env, guards with | Unification_error (l, m) -> typ_error l ("Unification error when pattern matching against union constructor: " ^ m) end @@ -2433,12 +2429,19 @@ and bind_pat env (P_aux (pat_aux, (l, ())) as pat) (Typ_aux (typ_aux, _) as typ) | P_app (f, _) when not (Env.is_union_constructor f env) -> typ_error l (string_of_id f ^ " is not a union constructor in pattern " ^ string_of_pat pat) | P_as (pat, id) -> - let (typed_pat, env) = bind_pat env pat typ in - annot_pat (P_as (typed_pat, id)) (pat_typ_of typed_pat), Env.add_local id (Immutable, pat_typ_of typed_pat) env + let (typed_pat, env, guards) = bind_pat env pat typ in + annot_pat (P_as (typed_pat, id)) (pat_typ_of typed_pat), Env.add_local id (Immutable, pat_typ_of typed_pat) env, guards | _ -> - let (inferred_pat, env) = infer_pat env pat in - subtyp l env (pat_typ_of inferred_pat) typ; - switch_typ inferred_pat typ, env + let (inferred_pat, env, guards) = infer_pat env pat in + match subtyp l env (pat_typ_of inferred_pat) typ with + | () -> switch_typ inferred_pat typ, env, guards + | exception (Type_error _ as typ_exn) -> + match pat_aux with + | P_lit lit -> + let guard = mk_exp (E_app_infix (mk_exp (E_id (mk_id "p#")), mk_id "==", mk_exp (E_lit lit))) in + let (typed_pat, env, guards) = bind_pat env (mk_pat (P_id (mk_id "p#"))) typ in + typed_pat, env, guard::guards + | _ -> raise typ_exn and infer_pat env (P_aux (pat_aux, (l, ())) as pat) = let annot_pat pat typ = P_aux (pat, (l, Some (env, typ, no_effect))) in @@ -2450,31 +2453,32 @@ and infer_pat env (P_aux (pat_aux, (l, ())) as pat) = typ_error l ("Cannot infer identifier in pattern " ^ string_of_pat pat ^ " - try adding a type annotation") | Local (Mutable, _) | Register _ -> typ_error l ("Cannot shadow mutable local or register in switch statement pattern " ^ string_of_pat pat) - | Enum enum -> annot_pat (P_id v) enum, env + | Enum enum -> annot_pat (P_id v) enum, env, [] end | P_typ (typ_annot, pat) -> Env.wf_typ env typ_annot; - let (typed_pat, env) = bind_pat env pat typ_annot in - annot_pat (P_typ (typ_annot, typed_pat)) typ_annot, env + let (typed_pat, env, guards) = bind_pat env pat typ_annot in + annot_pat (P_typ (typ_annot, typed_pat)) typ_annot, env, guards | P_lit lit -> - annot_pat (P_lit lit) (infer_lit env lit), env + annot_pat (P_lit lit) (infer_lit env lit), env, [] | P_vector (pat :: pats) -> - let fold_pats (pats, env) pat = - let typed_pat, env = bind_pat env pat bit_typ in - pats @ [typed_pat], env + let fold_pats (pats, env, guards) pat = + let typed_pat, env, guards' = bind_pat env pat bit_typ in + pats @ [typed_pat], env, guards' @ guards in - let ((typed_pat :: typed_pats) as pats), env = - List.fold_left fold_pats ([], env) (pat :: pats) in + let ((typed_pat :: typed_pats) as pats), env, guards = + List.fold_left fold_pats ([], env, []) (pat :: pats) in let len = nexp_simp (nint (List.length pats)) in let etyp = pat_typ_of typed_pat in List.map (fun pat -> typ_equality l env etyp (pat_typ_of pat)) pats; - annot_pat (P_vector pats) (lvector_typ env len etyp), env + annot_pat (P_vector pats) (lvector_typ env len etyp), env, guards | P_vector_concat (pat :: pats) -> - let fold_pats (pats, env) pat = - let inferred_pat, env = infer_pat env pat in - pats @ [inferred_pat], env + let fold_pats (pats, env, guards) pat = + let inferred_pat, env, guards' = infer_pat env pat in + pats @ [inferred_pat], env, guards' @ guards in - let (inferred_pat :: inferred_pats), env = List.fold_left fold_pats ([], env) (pat :: pats) in + let (inferred_pat :: inferred_pats), env, guards = + List.fold_left fold_pats ([], env, []) (pat :: pats) in let (_, len, _, vtyp) = destruct_vec_typ l env (pat_typ_of inferred_pat) in let fold_len len pat = let (_, len', _, vtyp') = destruct_vec_typ l env (pat_typ_of pat) in @@ -2482,10 +2486,12 @@ and infer_pat env (P_aux (pat_aux, (l, ())) as pat) = nsum len len' in let len = nexp_simp (List.fold_left fold_len len inferred_pats) in - annot_pat (P_vector_concat (inferred_pat :: inferred_pats)) (lvector_typ env len vtyp), env + annot_pat (P_vector_concat (inferred_pat :: inferred_pats)) (lvector_typ env len vtyp), env, guards | P_as (pat, id) -> - let (typed_pat, env) = infer_pat env pat in - annot_pat (P_as (typed_pat, id)) (pat_typ_of typed_pat), Env.add_local id (Immutable, pat_typ_of typed_pat) env + let (typed_pat, env, guards) = infer_pat env pat in + annot_pat (P_as (typed_pat, id)) (pat_typ_of typed_pat), + Env.add_local id (Immutable, pat_typ_of typed_pat) env, + guards | _ -> typ_error l ("Couldn't infer type of pattern " ^ string_of_pat pat) and bind_assignment env (LEXP_aux (lexp_aux, _) as lexp) (E_aux (_, (l, ())) as exp) = @@ -2905,7 +2911,7 @@ and infer_exp env (E_aux (exp_aux, (l, ())) as exp) = | _ -> let inferred_bind = irule infer_exp env bind in inferred_bind, typ_of inferred_bind in - let tpat, env = bind_pat env pat ptyp in + let tpat, env = bind_pat_no_guard env pat ptyp in (* Propagate constraint assertions on the lhs of monadic binds to the rhs *) let env = match bind_exp with | E_aux (E_assert (E_aux (E_constraint nc, _), _), _) -> @@ -2923,7 +2929,7 @@ and infer_exp env (E_aux (exp_aux, (l, ())) as exp) = | LB_val (pat, bind) -> let inferred_bind = irule infer_exp env bind in inferred_bind, pat, typ_of inferred_bind in - let tpat, env = bind_pat env pat ptyp in + let tpat, env = bind_pat_no_guard env pat ptyp in let inferred_exp = irule infer_exp env exp in annot_exp (E_let (LB_aux (LB_val (tpat, bind_exp), (let_loc, None)), inferred_exp)) (typ_of inferred_exp) | _ -> typ_error l ("Cannot infer type of: " ^ string_of_exp exp) @@ -3368,11 +3374,11 @@ let check_letdef env (LB_aux (letbind, (l, _))) = match letbind with | LB_val (P_aux (P_typ (typ_annot, pat), _), bind) -> let checked_bind = crule check_exp env (strip_exp bind) typ_annot in - let tpat, env = bind_pat env (strip_pat pat) typ_annot in + let tpat, env = bind_pat_no_guard env (strip_pat pat) typ_annot in [DEF_val (LB_aux (LB_val (P_aux (P_typ (typ_annot, tpat), (l, Some (env, typ_annot, no_effect))), checked_bind), (l, None)))], env | LB_val (pat, bind) -> let inferred_bind = irule infer_exp env (strip_exp bind) in - let tpat, env = bind_pat env (strip_pat pat) (typ_of inferred_bind) in + let tpat, env = bind_pat_no_guard env (strip_pat pat) (typ_of inferred_bind) in [DEF_val (LB_aux (LB_val (tpat, inferred_bind), (l, None)))], env end diff --git a/src/type_check.mli b/src/type_check.mli index 5066553e..d531a2a8 100644 --- a/src/type_check.mli +++ b/src/type_check.mli @@ -206,7 +206,11 @@ val prove : Env.t -> n_constraint -> bool val subtype_check : Env.t -> typ -> typ -> bool -val bind_pat : Env.t -> unit pat -> typ -> tannot pat * Env.t +val bind_pat : Env.t -> unit pat -> typ -> tannot pat * Env.t * unit Ast.exp list +(* Variant that doesn't introduce new guards for literal patterns, but raises + a type error instead. This should always be safe to use on patterns that + have previously been type checked. *) +val bind_pat_no_guard : Env.t -> unit pat -> typ -> tannot pat * Env.t (* Partial functions: The expressions and patterns passed to these functions must be guaranteed to have tannots of the form Some (env, -- cgit v1.2.3 From b4c367435b335f6a7160ed379408425c66c39ae1 Mon Sep 17 00:00:00 2001 From: Brian Campbell Date: Mon, 15 Jan 2018 18:14:58 +0000 Subject: Check monomorphisation case split size once for each pattern (rather than for each argument separately) --- src/ast_util.ml | 3 +++ src/ast_util.mli | 1 + src/monomorphise.ml | 56 +++++++++++++++++++++++++++++++++-------------------- 3 files changed, 39 insertions(+), 21 deletions(-) (limited to 'src') diff --git a/src/ast_util.ml b/src/ast_util.ml index 4ceb3e7f..e0a7de68 100644 --- a/src/ast_util.ml +++ b/src/ast_util.ml @@ -413,6 +413,9 @@ let id_loc = function let kid_loc = function | Kid_aux (_, l) -> l +let pat_loc = function + | P_aux (_, (l, _)) -> l + let def_loc = function | DEF_kind (KD_aux (_, (l, _))) | DEF_type (TD_aux (_, (l, _))) diff --git a/src/ast_util.mli b/src/ast_util.mli index 68955387..dff122be 100644 --- a/src/ast_util.mli +++ b/src/ast_util.mli @@ -161,6 +161,7 @@ val map_letbind_annot : ('a annot -> 'b annot) -> 'a letbind -> 'b letbind val id_loc : id -> Parse_ast.l val kid_loc : kid -> Parse_ast.l +val pat_loc : 'a pat -> Parse_ast.l val def_loc : 'a def -> Parse_ast.l (* For debugging and error messages only: Not guaranteed to produce diff --git a/src/monomorphise.ml b/src/monomorphise.ml index dd0edd64..b77b49cf 100644 --- a/src/monomorphise.ml +++ b/src/monomorphise.ml @@ -54,8 +54,7 @@ open Ast_util open Big_int open Type_check -let size_set_limit = 8 -let vector_split_limit = 4 +let size_set_limit = 16 let optmap v f = match v with @@ -1316,14 +1315,10 @@ let split_defs continue_anyway splits defs = | Typ_app (Id_aux (Id "vector",_), [_;Typ_arg_aux (Typ_arg_nexp len,_);_;Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)]) -> (match len with | Nexp_aux (Nexp_constant sz,_) -> - if int_of_big_int sz <= vector_split_limit then - let lits = make_vectors (int_of_big_int sz) in - List.map (fun lit -> - P_aux (P_lit lit,(l,annot)), - [var,E_aux (E_lit lit,(new_l,annot))]) lits - else - cannot ("Refusing to split vector type of length " ^ string_of_big_int sz ^ - " (above limit " ^ string_of_int vector_split_limit ^ ")") + let lits = make_vectors (int_of_big_int sz) in + List.map (fun lit -> + P_aux (P_lit lit,(l,annot)), + [var,E_aux (E_lit lit,(new_l,annot))]) lits | _ -> cannot ("length not constant, " ^ string_of_nexp len) ) @@ -1494,6 +1489,19 @@ let split_defs continue_anyway splits defs = in p in + let check_split_size lst l = + let size = List.length lst in + if size > size_set_limit then + let open Reporting_basic in + let error = + Err_general (l, "Case split is too large (" ^ string_of_int size ^ + " > limit " ^ string_of_int size_set_limit ^ ")") + in if continue_anyway + then (print_error error; false) + else raise (Fatal_error error) + else true + in + let rec map_exp ((E_aux (e,annot)) as ea) = let re e = E_aux (e,annot) in match e with @@ -1556,13 +1564,16 @@ let split_defs continue_anyway splits defs = FE_aux (FE_Fexp (id,map_exp e),annot) and map_pexp = function | Pat_aux (Pat_exp (p,e),l) -> + let nosplit = [Pat_aux (Pat_exp (p,map_exp e),l)] in (match map_pat p with - | NoSplit -> [Pat_aux (Pat_exp (p,map_exp e),l)] + | NoSplit -> nosplit | VarSplit patsubsts -> - List.map (fun (pat',substs) -> - let exp' = subst_exp substs e in - Pat_aux (Pat_exp (pat', map_exp exp'),l)) - patsubsts + if check_split_size patsubsts (pat_loc p) then + List.map (fun (pat',substs) -> + let exp' = subst_exp substs e in + Pat_aux (Pat_exp (pat', map_exp exp'),l)) + patsubsts + else nosplit | ConstrSplit patnsubsts -> List.map (fun (pat',nsubst) -> let pat' = nexp_subst_pat nsubst pat' in @@ -1570,14 +1581,17 @@ let split_defs continue_anyway splits defs = Pat_aux (Pat_exp (pat', map_exp exp'),l) ) patnsubsts) | Pat_aux (Pat_when (p,e1,e2),l) -> + let nosplit = [Pat_aux (Pat_when (p,map_exp e1,map_exp e2),l)] in (match map_pat p with - | NoSplit -> [Pat_aux (Pat_when (p,map_exp e1,map_exp e2),l)] + | NoSplit -> nosplit | VarSplit patsubsts -> - List.map (fun (pat',substs) -> - let exp1' = subst_exp substs e1 in - let exp2' = subst_exp substs e2 in - Pat_aux (Pat_when (pat', map_exp exp1', map_exp exp2'),l)) - patsubsts + if check_split_size patsubsts (pat_loc p) then + List.map (fun (pat',substs) -> + let exp1' = subst_exp substs e1 in + let exp2' = subst_exp substs e2 in + Pat_aux (Pat_when (pat', map_exp exp1', map_exp exp2'),l)) + patsubsts + else nosplit | ConstrSplit patnsubsts -> List.map (fun (pat',nsubst) -> let pat' = nexp_subst_pat nsubst pat' in -- cgit v1.2.3 From 4d0162a5dbeea6286fbeecdc3cec3b4e55fada8c Mon Sep 17 00:00:00 2001 From: Brian Campbell Date: Tue, 16 Jan 2018 11:32:50 +0000 Subject: Another useful monomorphisation rewrite --- src/monomorphise.ml | 9 +++++++++ 1 file changed, 9 insertions(+) (limited to 'src') diff --git a/src/monomorphise.ml b/src/monomorphise.ml index b77b49cf..9f67e93f 100644 --- a/src/monomorphise.ml +++ b/src/monomorphise.ml @@ -2778,6 +2778,15 @@ let rewrite_app env typ (id,args) = | _ -> E_app (id,args) + else if is_id env (Id "UInt") id then + let is_slice = is_id env (Id "slice") in + match args with + | [E_aux (E_app (slice1, [vector1; start1; length1]),_)] + when is_slice slice1 && not (is_constant length1) -> + E_app (mk_id "UInt_slice", [vector1; start1; length1]) + + | _ -> E_app (id,args) + else E_app (id,args) let rewrite_aux = function -- cgit v1.2.3 From 4b989cc6b82dac5008860a37f5d5b5396e9fbc89 Mon Sep 17 00:00:00 2001 From: Brian Campbell Date: Tue, 16 Jan 2018 15:56:35 +0000 Subject: Handle for loops correctly when rewriting size parameters --- src/monomorphise.ml | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) (limited to 'src') diff --git a/src/monomorphise.ml b/src/monomorphise.ml index 9f67e93f..1e2a5cf4 100644 --- a/src/monomorphise.ml +++ b/src/monomorphise.ml @@ -1784,6 +1784,10 @@ let rewrite_size_parameters env (Defs defs) = { (compute_exp_alg KidSet.empty KidSet.union) with e_aux = (fun ((s,e),annot) -> KidSet.union s (sizes_of_annot annot), E_aux (e,annot)); e_let = (fun ((sl,lb),(s2,e2)) -> KidSet.union sl (KidSet.diff s2 (tyvars_bound_in_lb lb)), E_let (lb,e2)); + e_for = (fun (id,(s1,e1),(s2,e2),(s3,e3),ord,(s4,e4)) -> + let kid = mk_kid ("loop_" ^ string_of_id id) in + KidSet.union s1 (KidSet.union s2 (KidSet.union s3 (KidSet.remove kid s4))), + E_for (id,e1,e2,e3,ord,e4)); pat_exp = (fun ((sp,pat),(s,e)) -> KidSet.diff s (tyvars_bound_in_pat pat), Pat_exp (pat,e))} pexp) in @@ -1800,7 +1804,7 @@ let rewrite_size_parameters env (Defs defs) = | P_aux (P_tup ps,_) -> ps | _ -> [pat] in - let to_change = List.map + let to_change = Util.map_filter (fun kid -> let check (P_aux (_,(_,Some (env,typ,_)))) = match Env.expand_synonyms env typ with @@ -1813,9 +1817,10 @@ let rewrite_size_parameters env (Defs defs) = if Kid.compare kid kid' = 0 then Some kid else None | _ -> None in match findi check parameters with - | None -> raise (Reporting_basic.err_general l - ("Unable to find an argument for " ^ string_of_kid kid)) - | Some i -> i) + | None -> (Reporting_basic.print_error (Reporting_basic.Err_general (l, + ("Unable to find an argument for " ^ string_of_kid kid))); + None) + | Some i -> Some i) (KidSet.elements expose_tyvars) in let ik_compare (i,k) (i',k') = -- cgit v1.2.3 From c1cea9e24e2722b0e7376fbe1339564f96d29961 Mon Sep 17 00:00:00 2001 From: Thomas Bauereiss Date: Tue, 16 Jan 2018 16:42:54 +0000 Subject: Output more type annotations in Lem backend Keep track of which type variables have been bound in the function declaration, and allow those to be pretty-printed --- src/pretty_print_lem.ml | 184 +++++++++++++++++++++++------------------------- src/rewrites.ml | 1 + 2 files changed, 91 insertions(+), 94 deletions(-) (limited to 'src') diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml index 6a3d1293..3d55ef04 100644 --- a/src/pretty_print_lem.ml +++ b/src/pretty_print_lem.ml @@ -63,6 +63,12 @@ open Pretty_print_common let opt_sequential = ref false let opt_mwords = ref false +type context = { + early_ret : bool; + bound_nexps : NexpSet.t; +} +let empty_ctxt = { early_ret = false; bound_nexps = NexpSet.empty } + let print_to_from_interp_value = ref false let langlebar = string "<|" let ranglebar = string "|>" @@ -331,12 +337,12 @@ let doc_typ_lem, doc_atomic_typ_lem = length argument are checked for variables, and the latter only if it is a bitvector; for other types of vectors, the length is not pretty-printed in the type, and the start index is never pretty-printed in vector types. *) -let rec contains_t_pp_var (Typ_aux (t,a) as typ) = match t with +let rec contains_t_pp_var ctxt (Typ_aux (t,a) as typ) = match t with | Typ_id _ -> false | Typ_var _ -> true | Typ_exist _ -> true - | Typ_fn (t1,t2,_) -> contains_t_pp_var t1 || contains_t_pp_var t2 - | Typ_tup ts -> List.exists contains_t_pp_var ts + | Typ_fn (t1,t2,_) -> contains_t_pp_var ctxt t1 || contains_t_pp_var ctxt t2 + | Typ_tup ts -> List.exists (contains_t_pp_var ctxt) ts | Typ_app (c,targs) -> if Ast_util.is_number typ then false else if is_bitvector_typ typ then @@ -345,23 +351,22 @@ let rec contains_t_pp_var (Typ_aux (t,a) as typ) = match t with not (is_nexp_constant length || (!opt_mwords && match length with Nexp_aux (Nexp_var _,_) -> true | _ -> false)) - else List.exists contains_t_arg_pp_var targs -and contains_t_arg_pp_var (Typ_arg_aux (targ, _)) = match targ with - | Typ_arg_typ t -> contains_t_pp_var t - | Typ_arg_nexp nexp -> not (is_nexp_constant (nexp_simp nexp)) + else List.exists (contains_t_arg_pp_var ctxt) targs +and contains_t_arg_pp_var ctxt (Typ_arg_aux (targ, _)) = match targ with + | Typ_arg_typ t -> contains_t_pp_var ctxt t + | Typ_arg_nexp nexp -> + let nexp = nexp_simp nexp in + not (is_nexp_constant nexp || NexpSet.mem nexp ctxt.bound_nexps) | _ -> false -let doc_tannot_lem eff typ = - if contains_t_pp_var typ then empty +let doc_tannot_lem ctxt eff typ = + if contains_t_pp_var ctxt typ then empty else let ta = doc_typ_lem typ in if eff then string " : M " ^^ parens ta else string " : " ^^ ta -(* doc_lit_lem gets as an additional parameter the type information from the - * expression around it: that's a hack, but how else can we distinguish between - * undefined values of different types ? *) -let doc_lit_lem in_pat (L_aux(lit,l)) a = +let doc_lit_lem (L_aux(lit,l)) = match lit with | L_unit -> utf8string "()" | L_zero -> utf8string "B0" @@ -371,24 +376,12 @@ let doc_lit_lem in_pat (L_aux(lit,l)) a = | L_num i -> let ipp = string_of_big_int i in utf8string ( - if in_pat then "("^ipp^":nn)" - else if lt_big_int i zero_big_int then "((0"^ipp^"):ii)" + if lt_big_int i zero_big_int then "((0"^ipp^"):ii)" else "("^ipp^":ii)") | L_hex n -> failwith "Shouldn't happen" (*"(num_to_vec " ^ ("0x" ^ n) ^ ")" (*shouldn't happen*)*) | L_bin n -> failwith "Shouldn't happen" (*"(num_to_vec " ^ ("0b" ^ n) ^ ")" (*shouldn't happen*)*) | L_undef -> - (match a with - | Some (_, (Typ_aux (t,_) as typ), _) -> - (match t with - | Typ_id (Id_aux (Id "bit", _)) - | Typ_app (Id_aux (Id "register", _),_) -> utf8string "UndefinedRegister 0" - | Typ_id (Id_aux (Id "string", _)) -> utf8string "\"\"" - | _ -> - let ta = if contains_t_pp_var typ then empty - else doc_tannot_lem false typ in - parens - ((utf8string "(failwith \"undefined value of unsupported type\")") ^^ ta)) - | _ -> utf8string "(failwith \"undefined value of unsupported type\")") + utf8string "(failwith \"undefined value of unsupported type\")" | L_string s -> utf8string ("\"" ^ s ^ "\"") | L_real s -> (* Lem does not support decimal syntax, so we translate a string @@ -475,44 +468,44 @@ let is_ctor env id = match Env.lookup_id id env with (*Note: vector concatenation, literal vectors, indexed vectors, and record should be removed prior to pp. The latter two have never yet been seen *) -let rec doc_pat_lem apat_needed (P_aux (p,(l,annot)) as pa) = match p with +let rec doc_pat_lem ctxt apat_needed (P_aux (p,(l,annot)) as pa) = match p with | P_app(id, ((_ :: _) as pats)) -> let ppp = doc_unop (doc_id_lem_ctor id) - (parens (separate_map comma (doc_pat_lem true) pats)) in + (parens (separate_map comma (doc_pat_lem ctxt true) pats)) in if apat_needed then parens ppp else ppp | P_app(id,[]) -> doc_id_lem_ctor id - | P_lit lit -> doc_lit_lem true lit annot + | P_lit lit -> doc_lit_lem lit | P_wild -> underscore | P_id id -> begin match id with | Id_aux (Id "None",_) -> string "Nothing" (* workaround temporary issue *) | _ -> doc_id_lem id end - | P_var(p,kid) -> doc_pat_lem true p - | P_as(p,id) -> parens (separate space [doc_pat_lem true p; string "as"; doc_id_lem id]) + | P_var(p,kid) -> doc_pat_lem ctxt true p + | P_as(p,id) -> parens (separate space [doc_pat_lem ctxt true p; string "as"; doc_id_lem id]) | P_typ(Typ_aux (Typ_tup typs, _), P_aux (P_tup pats, _)) -> (* Isabelle does not seem to like type-annotated tuple patterns; it gives a syntax error. Avoid this by annotating the tuple elements instead *) let doc_elem typ (P_aux (_, annot) as pat) = - doc_pat_lem true (P_aux (P_typ (typ, pat), annot)) in + doc_pat_lem ctxt true (P_aux (P_typ (typ, pat), annot)) in parens (separate comma_sp (List.map2 doc_elem typs pats)) | P_typ(typ,p) -> - let doc_p = doc_pat_lem true p in - if contains_t_pp_var typ then doc_p + let doc_p = doc_pat_lem ctxt true p in + if contains_t_pp_var ctxt typ then doc_p else parens (doc_op colon doc_p (doc_typ_lem typ)) | P_vector pats -> let ppp = (separate space) - [string "Vector";brackets (separate_map semi (doc_pat_lem true) pats);underscore;underscore] in + [string "Vector";brackets (separate_map semi (doc_pat_lem ctxt true) pats);underscore;underscore] in if apat_needed then parens ppp else ppp | P_vector_concat pats -> raise (Reporting_basic.err_unreachable l "vector concatenation patterns should have been removed before pretty-printing") | P_tup pats -> (match pats with - | [p] -> doc_pat_lem apat_needed p - | _ -> parens (separate_map comma_sp (doc_pat_lem false) pats)) - | P_list pats -> brackets (separate_map semi (doc_pat_lem false) pats) (*Never seen but easy in lem*) - | P_cons (p,p') -> doc_op (string "::") (doc_pat_lem true p) (doc_pat_lem true p') + | [p] -> doc_pat_lem ctxt apat_needed p + | _ -> parens (separate_map comma_sp (doc_pat_lem ctxt false) pats)) + | P_list pats -> brackets (separate_map semi (doc_pat_lem ctxt false) pats) (*Never seen but easy in lem*) + | P_cons (p,p') -> doc_op (string "::") (doc_pat_lem ctxt true p) (doc_pat_lem ctxt true p') | P_record (_,_) -> empty (* TODO *) let rec typ_needs_printed (Typ_aux (t,_) as typ) = match t with @@ -544,13 +537,13 @@ let typ_id_of (Typ_aux (typ, l)) = match typ with let prefix_recordtype = true let report = Reporting_basic.err_unreachable let doc_exp_lem, doc_let_lem = - let rec top_exp (early_ret : bool) (aexp_needed : bool) + let rec top_exp (ctxt : context) (aexp_needed : bool) (E_aux (e, (l,annot)) as full_exp) = - let expY = top_exp early_ret true in - let expN = top_exp early_ret false in - let expV = top_exp early_ret in + let expY = top_exp ctxt true in + let expN = top_exp ctxt false in + let expV = top_exp ctxt in let liftR doc = - if early_ret && effectful (effect_of full_exp) + if ctxt.early_ret && effectful (effect_of full_exp) then separate space [string "liftR"; parens (doc)] else doc in match e with @@ -570,10 +563,10 @@ let doc_exp_lem, doc_let_lem = doc_id_lem id in liftR ((prefix 2 1) (string "write_reg_field_range") - (align (doc_lexp_deref_lem early_ret le ^/^ + (align (doc_lexp_deref_lem ctxt le ^/^ field_ref ^/^ expY e2 ^/^ expY e3 ^/^ expY e))) | _ -> - let deref = doc_lexp_deref_lem early_ret le in + let deref = doc_lexp_deref_lem ctxt le in liftR ((prefix 2 1) (string "write_reg_range") (align (deref ^/^ expY e2 ^/^ expY e3) ^/^ expY e))) @@ -590,10 +583,10 @@ let doc_exp_lem, doc_let_lem = let call = if is_bitvector_typ (Env.base_typ_of (env_of full_exp) (typ_of_annot fannot)) then "write_reg_field_bit" else "write_reg_field_pos" in liftR ((prefix 2 1) (string call) - (align (doc_lexp_deref_lem early_ret le ^/^ + (align (doc_lexp_deref_lem ctxt le ^/^ field_ref ^/^ expY e2 ^/^ expY e))) | LEXP_aux (_, lannot) -> - let deref = doc_lexp_deref_lem early_ret le in + let deref = doc_lexp_deref_lem ctxt le in let call = if is_bitvector_typ (Env.base_typ_of (env_of full_exp) (typ_of_annot lannot)) then "write_reg_bit" else "write_reg_pos" in liftR ((prefix 2 1) (string call) (deref ^/^ expY e2 ^/^ expY e)) @@ -607,10 +600,10 @@ let doc_exp_lem, doc_let_lem = string "set_field"*) in liftR ((prefix 2 1) (string "write_reg_field") - (doc_lexp_deref_lem early_ret le ^^ space ^^ + (doc_lexp_deref_lem ctxt le ^^ space ^^ field_ref ^/^ expY e)) | _ -> - liftR ((prefix 2 1) (string "write_reg") (doc_lexp_deref_lem early_ret le ^/^ expY e))) + liftR ((prefix 2 1) (string "write_reg") (doc_lexp_deref_lem ctxt le ^/^ expY e))) | E_vector_append(le,re) -> raise (Reporting_basic.err_unreachable l "E_vector_append should have been rewritten before pretty-printing") @@ -626,7 +619,7 @@ let doc_exp_lem, doc_let_lem = | E_for(id,exp1,exp2,exp3,(Ord_aux(order,_)),exp4) -> raise (report l "E_for should have been removed till now") | E_let(leb,e) -> - let epp = let_exp early_ret leb ^^ space ^^ string "in" ^^ hardline ^^ expN e in + let epp = let_exp ctxt leb ^^ space ^^ string "in" ^^ hardline ^^ expN e in if aexp_needed then parens epp else epp | E_app(f,args) -> begin match f with @@ -681,8 +674,8 @@ let doc_exp_lem, doc_let_lem = | [exp] -> let epp = separate space [string "early_return"; expY exp] in let aexp_needed, tepp = - if contains_t_pp_var (typ_of exp) || - contains_t_pp_var (typ_of full_exp) then + if contains_t_pp_var ctxt (typ_of exp) || + contains_t_pp_var ctxt (typ_of full_exp) then aexp_needed, epp else let tannot = separate space [string "MR"; @@ -721,7 +714,7 @@ let doc_exp_lem, doc_let_lem = let t = (*Env.base_typ_of (env_of full_exp)*) (typ_of full_exp) in let eff = effect_of full_exp in if typ_needs_printed (Env.base_typ_of (env_of full_exp) t) - then (align epp ^^ (doc_tannot_lem (effectful eff) t), true) + then (align epp ^^ (doc_tannot_lem ctxt (effectful eff) t), true) else (epp, aexp_needed) in liftR (if aexp_needed then parens (align taepp) else taepp) end @@ -743,7 +736,7 @@ let doc_exp_lem, doc_let_lem = let field_f = doc_id_lem tid ^^ underscore ^^ doc_id_lem id ^^ dot ^^ string "get_field" in let (ta,aexp_needed) = if typ_needs_printed t - then (doc_tannot_lem (effectful eff) t, true) + then (doc_tannot_lem ctxt (effectful eff) t, true) else (empty, aexp_needed) in let epp = field_f ^^ space ^^ (expY fexp) in if aexp_needed then parens (align epp ^^ ta) else (epp ^^ ta) @@ -766,11 +759,11 @@ let doc_exp_lem, doc_let_lem = if has_effect eff BE_rreg then let epp = separate space [string "read_reg";doc_id_lem id] in if is_bitvector_typ base_typ - then liftR (parens (epp ^^ doc_tannot_lem true base_typ)) + then liftR (parens (epp ^^ doc_tannot_lem ctxt true base_typ)) else liftR epp else if is_ctor env id then doc_id_lem_ctor id else doc_id_lem id - | E_lit lit -> doc_lit_lem false lit annot + | E_lit lit -> doc_lit_lem lit | E_cast(typ,e) -> expV aexp_needed e | E_tuple exps -> @@ -784,7 +777,7 @@ let doc_exp_lem, doc_let_lem = | _ -> raise (report l ("cannot get record type from annot " ^ string_of_annot annot ^ " of exp " ^ string_of_exp full_exp)) in let epp = anglebars (space ^^ (align (separate_map (semi_sp ^^ break 1) - (doc_fexp early_ret recordtyp) fexps)) ^^ space) in + (doc_fexp ctxt recordtyp) fexps)) ^^ space) in if aexp_needed then parens epp else epp | E_record_update(e,(FES_aux(FES_Fexps(fexps,_),_))) -> let recordtyp = match annot with @@ -793,7 +786,7 @@ let doc_exp_lem, doc_let_lem = when Env.is_record tid env -> tid | _ -> raise (report l ("cannot get record type from annot " ^ string_of_annot annot ^ " of exp " ^ string_of_exp full_exp)) in - anglebars (doc_op (string "with") (expY e) (separate_map semi_sp (doc_fexp early_ret recordtyp) fexps)) + anglebars (doc_op (string "with") (expY e) (separate_map semi_sp (doc_fexp ctxt recordtyp) fexps)) | E_vector exps -> let t = Env.base_typ_of (env_of full_exp) (typ_of full_exp) in let (start, len, order, etyp) = @@ -821,7 +814,7 @@ let doc_exp_lem, doc_let_lem = let (epp,aexp_needed) = if is_bit_typ etyp && !opt_mwords then let bepp = string "vec_to_bvec" ^^ space ^^ parens (align epp) in - (bepp ^^ doc_tannot_lem false t, true) + (bepp ^^ doc_tannot_lem ctxt false t, true) else (epp,aexp_needed) in if aexp_needed then parens (align epp) else epp | E_vector_update(v,e1,e2) -> @@ -852,15 +845,15 @@ let doc_exp_lem, doc_let_lem = let epp = group ((separate space [string "match"; only_integers e; string "with"]) ^/^ - (separate_map (break 1) (doc_case early_ret) pexps) ^/^ + (separate_map (break 1) (doc_case ctxt) pexps) ^/^ (string "end")) in if aexp_needed then parens (align epp) else align epp | E_try (e, pexps) -> if effectful (effect_of e) then - let try_catch = if early_ret then "try_catchR" else "try_catch" in + let try_catch = if ctxt.early_ret then "try_catchR" else "try_catch" in let epp = group ((separate space [string try_catch; expY e; string "(function "]) ^/^ - (separate_map (break 1) (doc_case early_ret) pexps) ^/^ + (separate_map (break 1) (doc_case ctxt) pexps) ^/^ (string "end)")) in if aexp_needed then parens (align epp) else align epp else @@ -885,20 +878,20 @@ let doc_exp_lem, doc_let_lem = (separate space [expV b e1; string ">>"]) ^^ hardline ^^ expN e2 | _ -> (separate space [expV b e1; string ">>= fun"; - doc_pat_lem true pat;arrow]) ^^ hardline ^^ expN e2 in + doc_pat_lem ctxt true pat;arrow]) ^^ hardline ^^ expN e2 in if aexp_needed then parens (align epp) else epp | E_internal_return (e1) -> separate space [string "return"; expY e1] | E_sizeof nexp -> (match nexp_simp nexp with - | Nexp_aux (Nexp_constant i, _) -> doc_lit_lem false (L_aux (L_num i, l)) annot + | Nexp_aux (Nexp_constant i, _) -> doc_lit_lem (L_aux (L_num i, l)) | _ -> raise (Reporting_basic.err_unreachable l "pretty-printing non-constant sizeof expressions to Lem not supported")) | E_return r -> let ret_monad = if !opt_sequential then " : MR regstate" else " : MR" in let ta = - if contains_t_pp_var (typ_of full_exp) || contains_t_pp_var (typ_of r) + if contains_t_pp_var ctxt (typ_of full_exp) || contains_t_pp_var ctxt (typ_of r) then empty else separate space [string ret_monad; @@ -910,33 +903,33 @@ let doc_exp_lem, doc_let_lem = | E_internal_cast _ | E_internal_exp _ | E_sizeof_internal _ | E_internal_exp_user _ -> raise (Reporting_basic.err_unreachable l "unsupported internal expression encountered while pretty-printing") - and let_exp early_ret (LB_aux(lb,_)) = match lb with + and let_exp ctxt (LB_aux(lb,_)) = match lb with | LB_val(pat,e) -> prefix 2 1 - (separate space [string "let"; doc_pat_lem true pat; equals]) - (top_exp early_ret false e) + (separate space [string "let"; doc_pat_lem ctxt true pat; equals]) + (top_exp ctxt false e) - and doc_fexp early_ret recordtyp (FE_aux(FE_Fexp(id,e),_)) = + and doc_fexp ctxt recordtyp (FE_aux(FE_Fexp(id,e),_)) = let fname = if prefix_recordtype then (string (string_of_id recordtyp ^ "_")) ^^ doc_id_lem id else doc_id_lem id in - group (doc_op equals fname (top_exp early_ret true e)) + group (doc_op equals fname (top_exp ctxt true e)) - and doc_case early_ret = function + and doc_case ctxt = function | Pat_aux(Pat_exp(pat,e),_) -> - group (prefix 3 1 (separate space [pipe; doc_pat_lem false pat;arrow]) - (group (top_exp early_ret false e))) + group (prefix 3 1 (separate space [pipe; doc_pat_lem ctxt false pat;arrow]) + (group (top_exp ctxt false e))) | Pat_aux(Pat_when(_,_,_),(l,_)) -> raise (Reporting_basic.err_unreachable l "guarded pattern expression should have been rewritten before pretty-printing") - and doc_lexp_deref_lem early_ret ((LEXP_aux(lexp,(l,annot))) as le) = match lexp with + and doc_lexp_deref_lem ctxt ((LEXP_aux(lexp,(l,annot))) as le) = match lexp with | LEXP_field (le,id) -> - parens (separate empty [doc_lexp_deref_lem early_ret le;dot;doc_id_lem id]) + parens (separate empty [doc_lexp_deref_lem ctxt le;dot;doc_id_lem id]) | LEXP_id id -> doc_id_lem id | LEXP_cast (typ,id) -> doc_id_lem id - | LEXP_tup lexps -> parens (separate_map comma_sp (doc_lexp_deref_lem early_ret) lexps) + | LEXP_tup lexps -> parens (separate_map comma_sp (doc_lexp_deref_lem ctxt) lexps) | _ -> raise (Reporting_basic.err_unreachable l ("doc_lexp_deref_lem: Unsupported lexp")) (* expose doc_exp_lem and doc_let *) @@ -980,7 +973,7 @@ let doc_typdef_lem (TD_aux(td, (l, annot))) = match td with mk_typ (Typ_app (Id_aux (Id "field_ref", Parse_ast.Unknown), [mk_typ_arg (Typ_arg_typ rectyp); mk_typ_arg (Typ_arg_typ ftyp)])) in - let rfannot = doc_tannot_lem false reftyp in + let rfannot = doc_tannot_lem empty_ctxt false reftyp in let get, set = string "rec_val" ^^ dot ^^ fname fid, anglebars (space ^^ string "rec_val with " ^^ @@ -1200,7 +1193,7 @@ let doc_typdef_lem (TD_aux(td, (l, annot))) = match td with let ord = Ord_aux ((if dir_b then Ord_inc else Ord_dec), Parse_ast.Unknown) in let size = if dir_b then add_big_int (sub_big_int i2 i1) unit_big_int else add_big_int (sub_big_int i1 i2) unit_big_int in let vtyp = vector_typ (nconstant i1) (nconstant size) ord bit_typ in - let tannot = doc_tannot_lem false vtyp in + let tannot = doc_tannot_lem empty_ctxt false vtyp in let doc_rid (r,id) = parens (separate comma_sp [string_lit (doc_id_lem id); doc_range_lem r;]) in let doc_rids = group (separate_map (semi ^^ (break 1)) doc_rid rs) in @@ -1262,17 +1255,20 @@ let doc_rec_lem (Rec_aux(r,_)) = match r with let doc_tannot_opt_lem (Typ_annot_opt_aux(t,_)) = match t with | Typ_annot_opt_some(tq,typ) -> (*doc_typquant_lem tq*) (doc_typ_lem typ) -let doc_fun_body_lem exp = - let early_ret =contains_early_return exp in - let doc_exp = doc_exp_lem early_ret false exp in - if early_ret +let doc_fun_body_lem ctxt exp = + let doc_exp = doc_exp_lem ctxt false exp in + if ctxt.early_ret then align (string "catch_early_return" ^//^ parens (doc_exp)) else doc_exp -let doc_funcl_lem (FCL_aux(FCL_Funcl(id,pexp),_)) = +let doc_funcl_lem (FCL_aux(FCL_Funcl(id, pexp), annot)) = + let typ = typ_of_annot annot in let pat,guard,exp,(l,_) = destruct_pexp pexp in + let ctxt = + { early_ret = contains_early_return exp; + bound_nexps = NexpSet.union (lem_nexps_of_typ typ) (typeclass_nexps typ) } in let pats, bind = untuple_args_pat pat in - let patspp = separate_map space (doc_pat_lem true) pats in + let patspp = separate_map space (doc_pat_lem ctxt true) pats in let _ = match guard with | None -> () | _ -> @@ -1280,7 +1276,7 @@ let doc_funcl_lem (FCL_aux(FCL_Funcl(id,pexp),_)) = "guarded pattern expression should have been rewritten before pretty-printing") in group (prefix 3 1 (separate space [doc_id_lem id; patspp; equals]) - (doc_fun_body_lem (bind exp))) + (doc_fun_body_lem ctxt (bind exp))) let get_id = function | [] -> failwith "FD_function with empty list" @@ -1294,8 +1290,8 @@ let doc_fundef_rhs_lem (FD_aux(FD_function(r, typa, efa, funcls),fannot) as fd) let doc_mutrec_lem = function | [] -> failwith "DEF_internal_mutrec with empty function list" | fundefs -> - string "let rec " ^^ - separate_map (hardline ^^ string "and ") doc_fundef_rhs_lem fundefs + string "let rec " ^^ + separate_map (hardline ^^ string "and ") doc_fundef_rhs_lem fundefs let rec doc_fundef_lem (FD_aux(FD_function(r, typa, efa, fcls),fannot) as fd) = match fcls with @@ -1348,13 +1344,13 @@ let rec doc_fundef_lem (FD_aux(FD_function(r, typa, efa, fcls),fannot) as fd) = let named_args = if argspat = [] then [unit_pat] else named_argspat in let doc_arg idx (P_aux (p,(l,a))) = match p with | P_as (pat,id) -> doc_id_lem id - | P_lit lit -> doc_lit_lem false lit a + | P_lit lit -> doc_lit_lem lit | P_id id -> doc_id_lem id | _ -> string ("arg" ^ string_of_int idx) in let clauses = clauses ^^ (break 1) ^^ (separate space - [pipe;doc_pat_lem false named_pat;arrow; + [pipe;doc_pat_lem empty_ctxt false named_pat;arrow; string aux_fname; separate space (List.mapi doc_arg named_args)]) in (already_used_fnames,auxiliary_functions,clauses) @@ -1452,7 +1448,7 @@ let doc_regtype_fields (tname, (n1, n2, fields)) = mk_typ (Typ_app (Id_aux (Id "field_ref", Parse_ast.Unknown), [mk_typ_arg (Typ_arg_typ (mk_id_typ (mk_id tname))); mk_typ_arg (Typ_arg_typ ftyp)])) in - let rfannot = doc_tannot_lem false reftyp in + let rfannot = doc_tannot_lem empty_ctxt false reftyp in doc_op equals (concat [string "let "; parens (concat [string tname; underscore; doc_id_lem fid; rfannot])]) (concat [ @@ -1479,7 +1475,7 @@ let rec doc_def_lem regtypes def = if is_field_accessor regtypes fdef then (doc_fdef, empty) else (empty, doc_fdef) | DEF_internal_mutrec fundefs -> (empty, doc_mutrec_lem fundefs ^/^ hardline) - | DEF_val lbind -> (empty,group (doc_let_lem false lbind) ^/^ hardline) + | DEF_val lbind -> (empty,group (doc_let_lem empty_ctxt lbind) ^/^ hardline) | DEF_scattered sdef -> failwith "doc_def_lem: shoulnd't have DEF_scattered at this point" | DEF_kind _ -> (empty,empty) @@ -1544,7 +1540,7 @@ let doc_regstate_lem registers = E_record (FES_aux (FES_Fexps (List.map initreg registers, false), annot)), (l, Some (Env.empty, mk_id_typ (mk_id "regstate"), no_effect))) in - doc_op equals (string "let initial_regstate") (doc_exp_lem false false exp) + doc_op equals (string "let initial_regstate") (doc_exp_lem empty_ctxt false exp) else empty in doc_typdef_lem (TD_aux (regstate, annot)), diff --git a/src/rewrites.ml b/src/rewrites.ml index 6158422e..950013ff 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -2281,6 +2281,7 @@ let rewrite_defs_letbind_effects = let exp = if newreturn then (* let typ = try typ_of exp with _ -> unit_typ in *) + let exp = annot_exp (E_cast (typ_of exp, exp)) l (env_of exp) (typ_of exp) in annot_exp (E_internal_return exp) l (env_of exp) (typ_of exp) else exp in -- cgit v1.2.3 From 3b252c7e6b37f0d8be7fbeba75331f7299072b1d Mon Sep 17 00:00:00 2001 From: Thomas Bauereiss Date: Tue, 16 Jan 2018 16:56:46 +0000 Subject: Fix problem with let-bindings in pattern guards Monomorphisation sometimes produces pattern guard with let-bindings, e.g. | ... if (let regsize = size_itself(regsize) in eq(regsize, 32)) -> ... Previously, the rewriting pass for let-bindings (and pattern guards) pulled these out of the guard condition and into the same scope as the case expression, which potentially clashed with let-bindings for the same variables in that case expression. The rewriter now leaves let-bindings in place within pure if-conditions, solving this problem. --- src/monomorphise.ml | 2 +- src/rewrites.ml | 16 +++++++++------- 2 files changed, 10 insertions(+), 8 deletions(-) (limited to 'src') diff --git a/src/monomorphise.ml b/src/monomorphise.ml index 1e2a5cf4..dd1b4aa3 100644 --- a/src/monomorphise.ml +++ b/src/monomorphise.ml @@ -1877,7 +1877,7 @@ let rewrite_size_parameters env (Defs defs) = let body = List.fold_left add_var_rebind body vars in let guard = match guard with | None -> None - | Some exp -> Some (List.fold_left add_var_rebind body vars) + | Some exp -> Some (List.fold_left add_var_rebind exp vars) in pat,guard,body | exception Not_found -> pat,guard,body diff --git a/src/rewrites.ml b/src/rewrites.ml index 950013ff..7e852092 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -2313,13 +2313,15 @@ let rewrite_defs_letbind_effects = n_exp_nameL exps (fun exps -> k (rewrap (E_tuple exps))) | E_if (exp1,exp2,exp3) -> - n_exp_name exp1 (fun exp1 -> - let (E_aux (_,annot2)) = exp2 in - let (E_aux (_,annot3)) = exp3 in - let newreturn = effectful exp2 || effectful exp3 in - let exp2 = n_exp_term newreturn exp2 in - let exp3 = n_exp_term newreturn exp3 in - k (rewrap (E_if (exp1,exp2,exp3)))) + let e_if exp1 = + let (E_aux (_,annot2)) = exp2 in + let (E_aux (_,annot3)) = exp3 in + let newreturn = effectful exp2 || effectful exp3 in + let exp2 = n_exp_term newreturn exp2 in + let exp3 = n_exp_term newreturn exp3 in + k (rewrap (E_if (exp1,exp2,exp3))) + in + if value exp1 then e_if (n_exp_term false exp1) else n_exp_name exp1 e_if | E_for (id,start,stop,by,dir,body) -> n_exp_name start (fun start -> n_exp_name stop (fun stop -> -- cgit v1.2.3 From 659151f8c5000885764a7a4153affe84a450ab1d Mon Sep 17 00:00:00 2001 From: Thomas Bauereiss Date: Wed, 17 Jan 2018 14:27:28 +0000 Subject: Fix use of nexps in type annotations when not using machine words --- src/pretty_print_lem.ml | 43 ++++++++++++++++++++++--------------------- 1 file changed, 22 insertions(+), 21 deletions(-) (limited to 'src') diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml index 3d55ef04..2054a8d7 100644 --- a/src/pretty_print_lem.ml +++ b/src/pretty_print_lem.ml @@ -425,29 +425,30 @@ let doc_typquant_lem (TypQ_aux(tq,_)) vars_included typ = match tq with machine words. Often these will be unnecessary, but this simple approach will do for now. *) -let rec typeclass_nexps (Typ_aux(t,_)) = match t with -| Typ_id _ -| Typ_var _ - -> NexpSet.empty -| Typ_fn (t1,t2,_) -> NexpSet.union (typeclass_nexps t1) (typeclass_nexps t2) -| Typ_tup ts -> List.fold_left NexpSet.union NexpSet.empty (List.map typeclass_nexps ts) -| Typ_app (Id_aux (Id "vector",_), - [_;Typ_arg_aux (Typ_arg_nexp size_nexp,_); - _;Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)]) -| Typ_app (Id_aux (Id "itself",_), - [Typ_arg_aux (Typ_arg_nexp size_nexp,_)]) -> - let size_nexp = nexp_simp size_nexp in - if is_nexp_constant size_nexp then NexpSet.empty else - NexpSet.singleton (orig_nexp size_nexp) -| Typ_app _ -> NexpSet.empty -| Typ_exist (kids,_,t) -> NexpSet.empty (* todo *) +let rec typeclass_nexps (Typ_aux(t,_)) = + if !opt_mwords then + match t with + | Typ_id _ + | Typ_var _ + -> NexpSet.empty + | Typ_fn (t1,t2,_) -> NexpSet.union (typeclass_nexps t1) (typeclass_nexps t2) + | Typ_tup ts -> List.fold_left NexpSet.union NexpSet.empty (List.map typeclass_nexps ts) + | Typ_app (Id_aux (Id "vector",_), + [_;Typ_arg_aux (Typ_arg_nexp size_nexp,_); + _;Typ_arg_aux (Typ_arg_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)]) + | Typ_app (Id_aux (Id "itself",_), + [Typ_arg_aux (Typ_arg_nexp size_nexp,_)]) -> + let size_nexp = nexp_simp size_nexp in + if is_nexp_constant size_nexp then NexpSet.empty else + NexpSet.singleton (orig_nexp size_nexp) + | Typ_app _ -> NexpSet.empty + | Typ_exist (kids,_,t) -> NexpSet.empty (* todo *) + else NexpSet.empty let doc_typclasses_lem t = - if !opt_mwords then - let nexps = typeclass_nexps t in - if NexpSet.is_empty nexps then (empty, NexpSet.empty) else - (separate_map comma_sp (fun nexp -> string "Size " ^^ doc_nexp_lem nexp) (NexpSet.elements nexps) ^^ string " => ", nexps) - else (empty, NexpSet.empty) + let nexps = typeclass_nexps t in + if NexpSet.is_empty nexps then (empty, NexpSet.empty) else + (separate_map comma_sp (fun nexp -> string "Size " ^^ doc_nexp_lem nexp) (NexpSet.elements nexps) ^^ string " => ", nexps) let doc_typschm_lem quants (TypSchm_aux(TypSchm_ts(tq,t),_)) = let pt = doc_typ_lem t in -- cgit v1.2.3 From 01a7c0eb317610acee9a79b24a5fa36ee78b6e07 Mon Sep 17 00:00:00 2001 From: Thomas Bauereiss Date: Wed, 17 Jan 2018 14:46:46 +0000 Subject: Rewrite topological sorting Use Tarjan's algorithm for finding strongly connected components (and finding a topological sorting of components at the same time), in order to properly take into account mutually recursive functions. The sorting is stable, i.e., definitions are only moved when necessary. Exceptions to this are statements that do not have any dependencies: default bitvector order declarations, operator fixity declarations, and top-level comments. These are moved to the beginning (like with the previous sorting implementation). Any dependency cycles that are found are additionally printed to the console in dot-format, for easy visualisation with graphviz. --- src/spec_analysis.ml | 387 ++++++++++++++++++--------------------------------- 1 file changed, 136 insertions(+), 251 deletions(-) (limited to 'src') diff --git a/src/spec_analysis.ml b/src/spec_analysis.ml index b35bc48f..19f7b085 100644 --- a/src/spec_analysis.ml +++ b/src/spec_analysis.ml @@ -64,133 +64,6 @@ let set_to_string n = list_to_string (Nameset.elements n) -(*Query a spec for its default order if one is provided. Assumes Inc if not *) -(* let get_default_order_sp (DT_aux(spec,_)) = - match spec with - | DT_order (Ord_aux(o,_)) -> - (match o with - | Ord_inc -> Some {order = Oinc} - | Ord_dec -> Some { order = Odec} - | _ -> Some {order = Oinc}) - | _ -> None - -let get_default_order_def = function - | DEF_default def_spec -> get_default_order_sp def_spec - | _ -> None - -let rec default_order (Defs defs) = - match defs with - | [] -> { order = Oinc } (*When no order is specified, we assume that it's inc*) - | def::defs -> - match get_default_order_def def with - | None -> default_order (Defs defs) - | Some o -> o *) - -(*Is within range*) - -(* let check_in_range (candidate : big_int) (range : typ) : bool = - match range.t with - | Tapp("range", [TA_nexp min; TA_nexp max]) | Tabbrev(_,{t=Tapp("range", [TA_nexp min; TA_nexp max])}) -> - let min,max = - match min.nexp,max.nexp with - | (Nconst min, Nconst max) - | (Nconst min, N2n(_, Some max)) - | (N2n(_, Some min), Nconst max) - | (N2n(_, Some min), N2n(_, Some max)) - -> min, max - | (Nneg n, Nconst max) | (Nneg n, N2n(_, Some max))-> - (match n.nexp with - | Nconst abs_min | N2n(_,Some abs_min) -> - (minus_big_int abs_min), max - | _ -> assert false (*Put a better error message here*)) - | (Nconst min,Nneg n) | (N2n(_, Some min), Nneg n) -> - (match n.nexp with - | Nconst abs_max | N2n(_,Some abs_max) -> - min, (minus_big_int abs_max) - | _ -> assert false (*Put a better error message here*)) - | (Nneg nmin, Nneg nmax) -> - ((match nmin.nexp with - | Nconst abs_min | N2n(_,Some abs_min) -> (minus_big_int abs_min) - | _ -> assert false (*Put a better error message here*)), - (match nmax.nexp with - | Nconst abs_max | N2n(_,Some abs_max) -> (minus_big_int abs_max) - | _ -> assert false (*Put a better error message here*))) - | _ -> assert false - in le_big_int min candidate && le_big_int candidate max - | _ -> assert false - -(*Rmove me when switch to zarith*) -let rec power_big_int b n = - if eq_big_int n zero_big_int - then unit_big_int - else mult_big_int b (power_big_int b (sub_big_int n unit_big_int)) - -let unpower_of_2 b = - let two = big_int_of_int 2 in - let four = big_int_of_int 4 in - let eight = big_int_of_int 8 in - let sixteen = big_int_of_int 16 in - let thirty_two = big_int_of_int 32 in - let sixty_four = big_int_of_int 64 in - let onetwentyeight = big_int_of_int 128 in - let twofiftysix = big_int_of_int 256 in - let fivetwelve = big_int_of_int 512 in - let oneotwentyfour = big_int_of_int 1024 in - let to_the_sixteen = big_int_of_int 65536 in - let to_the_thirtytwo = big_int_of_string "4294967296" in - let to_the_sixtyfour = big_int_of_string "18446744073709551616" in - let ck i = eq_big_int b i in - if ck unit_big_int then zero_big_int - else if ck two then unit_big_int - else if ck four then two - else if ck eight then big_int_of_int 3 - else if ck sixteen then four - else if ck thirty_two then big_int_of_int 5 - else if ck sixty_four then big_int_of_int 6 - else if ck onetwentyeight then big_int_of_int 7 - else if ck twofiftysix then eight - else if ck fivetwelve then big_int_of_int 9 - else if ck oneotwentyfour then big_int_of_int 10 - else if ck to_the_sixteen then sixteen - else if ck to_the_thirtytwo then thirty_two - else if ck to_the_sixtyfour then sixty_four - else let rec unpower b power = - if eq_big_int b unit_big_int - then power - else (unpower (div_big_int b two) (succ_big_int power)) in - unpower b zero_big_int - -let is_within_range candidate range constraints = - let candidate_actual = match candidate.t with - | Tabbrev(_,t) -> t - | _ -> candidate in - match candidate_actual.t with - | Tapp("atom", [TA_nexp n]) -> - (match n.nexp with - | Nconst i | N2n(_,Some i) -> if check_in_range i range then Yes else No - | _ -> Maybe) - | Tapp("range", [TA_nexp bot; TA_nexp top]) -> - (match bot.nexp,top.nexp with - | Nconst b, Nconst t | Nconst b, N2n(_,Some t) | N2n(_, Some b), Nconst t | N2n(_,Some b), N2n(_, Some t) -> - let at_least_in = check_in_range b range in - let at_most_in = check_in_range t range in - if at_least_in && at_most_in - then Yes - else if at_least_in || at_most_in - then Maybe - else No - | _ -> Maybe) - | Tapp("vector", [_; TA_nexp size ; _; _]) -> - (match size.nexp with - | Nconst i | N2n(_, Some i) -> - if check_in_range (power_big_int (big_int_of_int 2) i) range - then Yes - else No - | _ -> Maybe) - | _ -> Maybe - -let is_within_machine64 candidate constraints = is_within_range candidate int64_t constraints *) - (************************************************************************************************) (*FV finding analysis: identifies the free variables of a function, expression, etc *) @@ -313,9 +186,11 @@ let rec fv_of_exp consider_var bound used set (E_aux (e,(_,tannot))) : (Nameset. fv_of_exp consider_var bound u set e | E_app(id,es) -> let us = conditional_add_exp bound used id in + let us = conditional_add_exp bound us (prepend_id "val:" id) in list_fv bound us set es | E_app_infix(l,id,r) -> let us = conditional_add_exp bound used id in + let us = conditional_add_exp bound us (prepend_id "val:" id) in list_fv bound us set [l;r] | E_if(c,t,e) -> list_fv bound used set [c;t;e] | E_for(id,from,to_,by,_,body) -> @@ -462,8 +337,12 @@ let fv_of_fun consider_var (FD_aux (FD_function(rec_opt,tannot_opt,_,funcls),_) | [] -> failwith "fv_of_fun fell off the end looking for the function name" | FCL_aux(FCL_Funcl(id,_),_)::_ -> string_of_id id in let base_bounds = match rec_opt with + (* Current Sail does not have syntax for declaring functions as recursive, + and type checker does not check whether functions are recursive, so + just always add a self-dependency of functions on themselves | Rec_aux(Ast.Rec_rec,_) -> init_env fun_name - | _ -> mt in + | _ -> mt*) + | _ -> init_env fun_name in let base_bounds,ns_r = match tannot_opt with | Typ_annot_opt_aux(Typ_annot_opt_some (typq, typ),_) -> let bindings = if consider_var then typq_bindings typq else mt in @@ -575,7 +454,9 @@ let fv_of_def consider_var consider_scatter_as_one all_defs = function | DEF_val lebind -> ((fun (b,u,_) -> (b,u)) (fv_of_let consider_var mt mt mt lebind)) | DEF_spec vspec -> fv_of_vspec consider_var vspec | DEF_fixity _ -> mt,mt - | DEF_overload (id,ids) -> init_env (string_of_id id), List.fold_left (fun ns id -> Nameset.add (string_of_id id) ns) mt ids + | DEF_overload (id,ids) -> + init_env (string_of_id id), + List.fold_left (fun ns id -> Nameset.add ("val:" ^ string_of_id id) ns) mt ids | DEF_default def -> mt,mt | DEF_scattered sdef -> fv_of_scattered consider_var consider_scatter_as_one all_defs sdef | DEF_reg_dec rdec -> fv_of_rd consider_var rdec @@ -584,129 +465,133 @@ let fv_of_def consider_var consider_scatter_as_one all_defs = function let group_defs consider_scatter_as_one (Ast.Defs defs) = List.map (fun d -> (fv_of_def false consider_scatter_as_one defs d,d)) defs -(******************************************************************************* - * Reorder defs take 2 -*) -(*remove all of ns1 instances from ns2*) -let remove_all ns1 ns2 = - List.fold_right Nameset.remove (Nameset.elements ns1) ns2 - -let remove_from_all_uses bs dbts = - List.map (fun ((b,uses),d) -> (b,remove_all bs uses),d) dbts - -let remove_local_or_lib_vars dbts = - let bound_in_dbts = List.fold_right (fun ((b,_),_) bounds -> Nameset.union b bounds) dbts mt in - let is_bound_in_defs s = Nameset.mem s bound_in_dbts in - let rec remove_from_uses = function - | [] -> [] - | ((b,uses),d)::defs -> - ((b,(Nameset.filter is_bound_in_defs uses)),d)::remove_from_uses defs in - remove_from_uses dbts +(* + * Sorting definitions, take 3 + *) + +module Namemap = Map.Make(String) +(* Nodes are labeled with strings. A graph is represented as a map associating + each node with its sucessors *) +type graph = Nameset.t Namemap.t +type node_idx = { index : int; root : int } + +(* Find strongly connected components using Tarjan's algorithm. + This algorithm also returns a topological sorting of the graph components. *) +let scc ?(original_order : string list option) (g : graph) = + let components = ref [] in + let index = ref 0 in + + let stack = ref [] in + let push v = (stack := v :: !stack) in + let pop () = + begin + let v = List.hd !stack in + stack := List.tl !stack; + v + end + in -let compare_dbts ((_,u1),_) ((_,u2),_) = Pervasives.compare (Nameset.cardinal u1) (Nameset.cardinal u2) + let node_indices = Hashtbl.create (Namemap.cardinal g) in + let get_index v = (Hashtbl.find node_indices v).index in + let get_root v = (Hashtbl.find node_indices v).root in + let set_root v r = + Hashtbl.replace node_indices v { (Hashtbl.find node_indices v) with root = r } in + + let rec visit_node v = + begin + Hashtbl.add node_indices v { index = !index; root = !index }; + index := !index + 1; + push v; + if Namemap.mem v g then Nameset.iter (visit_edge v) (Namemap.find v g); + if get_root v = get_index v then (* v is the root of a SCC *) + begin + let component = ref [] in + let finished = ref false in + while not !finished do + let w = pop () in + component := w :: !component; + if String.compare v w = 0 then finished := true; + done; + components := !component :: !components; + end + end + and visit_edge v w = + if not (Hashtbl.mem node_indices w) then + begin + visit_node w; + if Hashtbl.mem node_indices w then set_root v (min (get_root v) (get_root w)); + end else begin + if List.mem w !stack then set_root v (min (get_root v) (get_index w)) + end + in -let rec print_dependencies orig_queue work_queue names = - match work_queue with - | [] -> () - | ((binds,uses),_)::wq -> - (if not(Nameset.is_empty(Nameset.inter names binds)) - then ((Printf.eprintf "binds of %s has uses of %s\n" (set_to_string binds) (set_to_string uses)); - print_dependencies orig_queue orig_queue uses)); - print_dependencies orig_queue wq names - -let merge_mutrecs defs = - let merge_aux ((binds', uses'), def) ((binds, uses), fundefs) = - let fundefs = match def with - | DEF_fundef fundef -> fundef :: fundefs - | DEF_internal_mutrec fundefs' -> fundefs' @ fundefs - | _ -> - (* let _ = Pretty_print_sail2.pp_defs stderr (Defs [def]) in *) - raise (Reporting_basic.err_unreachable (def_loc def) - "Trying to merge non-function definition with mutually recursive functions") in - (* let _ = Printf.eprintf " - Merging %s (using %s)\n" (set_to_string binds') (set_to_string uses') in *) - ((Nameset.union binds' binds, Nameset.union uses' uses), fundefs) in - let ((binds, uses), fundefs) = List.fold_right merge_aux defs ((mt, mt), []) in - ((binds, uses), DEF_internal_mutrec fundefs) - -let rec topological_sort work_queue defs = - match work_queue with - | [] -> List.rev defs - | ((binds,uses),def)::wq -> - (*Assumes work queue given in sorted order, invariant mantained on appropriate recursive calls*) - if (Nameset.cardinal uses = 0) - then (*let _ = Printf.eprintf "Adding def that binds %s to definitions\n" (set_to_string binds) in*) - topological_sort (remove_from_all_uses binds wq) (def::defs) - else if not(Nameset.is_empty(Nameset.inter binds uses)) - then topological_sort (((binds,(remove_all binds uses)),def)::wq) defs - else - match List.stable_sort compare_dbts work_queue with (*We wait to sort until there are no 0 dependency nodes on top*) - | [] -> failwith "sort shrunk the list???" - | (((n,uses),def)::rest) as wq -> - if (Nameset.cardinal uses = 0) - then topological_sort wq defs - else - let _ = Printf.eprintf "Merging (potentially) mutually recursive definitions %s and %s\n" (set_to_string n) (set_to_string uses) in - let is_used ((binds', uses'), def') = not(Nameset.is_empty(Nameset.inter binds' uses)) in - let (used, rest) = List.partition is_used rest in - let wq = merge_mutrecs (((n,uses),def)::used) :: rest in - topological_sort wq defs - -let rec add_to_partial_order ((binds,uses),def) = function - | [] -> -(* let _ = Printf.eprintf "add_to_partial_order for def with bindings %s, uses %s.\n Eol case.\n" (set_to_string binds) (set_to_string uses) in*) - [(binds,uses),def] - | (((bf,uf),deff)::defs as full_defs) -> - (*let _ = Printf.eprintf "add_to_partial_order for def with bindings %s, uses %s.\n None eol case. With first def binding %s, uses %s\n" (set_to_string binds) (set_to_string uses) (set_to_string bf) (set_to_string uf) in*) - if Nameset.is_empty uses - then ((binds,uses),def)::full_defs - else if Nameset.subset binds uf (*deff relies on def, so def must be defined first*) - then ((binds,uses),def)::((bf,(remove_all binds uf)),deff)::defs - else if Nameset.subset bf uses (*def relies at least on deff, but maybe more, push in*) - then ((bf,uf),deff)::(add_to_partial_order ((binds,(remove_all bf uses)),def) defs) - else (*These two are unrelated but new def might need to go further in*) - ((bf,uf),deff)::(add_to_partial_order ((binds,uses),def) defs) - -let rec gather_defs name already_included def_bind_triples = - match def_bind_triples with - | [] -> [],already_included,mt - | ((binds,uses),def)::def_bind_triples -> - let (defs,already_included,requires) = gather_defs name already_included def_bind_triples in - let bound_names = Nameset.elements binds in - if List.mem name already_included || List.exists (fun b -> List.mem b already_included) bound_names - then (defs,already_included,requires) - else - let uses = List.fold_right Nameset.remove already_included uses in - if Nameset.mem name binds - then (def::defs,(bound_names@already_included), Nameset.remove name (Nameset.union uses requires)) - else (defs,already_included,requires) - -let rec gather_all names already_included def_bind_triples = - let rec gather ns already_included defs reqs = match ns with - | [] -> defs,already_included,reqs - | name::ns -> - if List.mem name already_included - then gather ns already_included defs (Nameset.remove name reqs) - else - let (new_defs,already_included,new_reqs) = gather_defs name already_included def_bind_triples in - gather ns already_included (new_defs@defs) (Nameset.remove name (Nameset.union new_reqs reqs)) + let nodes = match original_order with + | Some nodes -> nodes + | None -> List.map fst (Namemap.bindings g) in - let (defs,already_included,reqs) = gather names already_included [] mt in - if Nameset.is_empty reqs - then defs - else (gather_all (Nameset.elements reqs) already_included def_bind_triples)@defs - -let restrict_defs defs name_list = - let defsno = gather_all name_list [] (group_defs false defs) in - let rdbts = group_defs true (Defs defsno) in - (*let partial_order = - List.fold_left (fun po d -> add_to_partial_order d po) [] rdbts in - let defs = List.map snd partial_order in*) - let defs = topological_sort (List.sort compare_dbts (remove_local_or_lib_vars rdbts)) [] in - Defs defs - - -let top_sort_defs defs = - let rdbts = group_defs true defs in - let defs = topological_sort (List.stable_sort compare_dbts (remove_local_or_lib_vars rdbts)) [] in - Defs defs + List.iter (fun v -> if not (Hashtbl.mem node_indices v) then visit_node v) nodes; + List.rev !components + +let add_def_to_graph (prelude, original_order, defset, graph) d = + let bound, used = fv_of_def false true [] d in + try + (* A definition may bind multiple identifiers, e.g. "let (x, y) = ...". + We add all identifiers to the dependency graph as a cycle. The actual + definition is attached to only one of the identifiers, so it will not + be duplicated in the final output. *) + let id = Nameset.choose bound in + let other_ids = Nameset.remove id bound in + let graph_id = Namemap.add id (Nameset.union used other_ids) graph in + let add_other_node id' g = Namemap.add id' (Nameset.singleton id) g in + prelude, + original_order @ [id], + Namemap.add id d defset, + Nameset.fold add_other_node other_ids graph_id + with + | Not_found -> + (* Some definitions do not bind any identifiers at all. This *should* + only happen for default bitvector order declarations, operator fixity + declarations, and comments. The sorting does not (currently) attempt + to preserve the positions of these AST nodes; they are collected + separately and placed at the beginning of the output. Comments are + currently ignored by the Lem and OCaml backends, anyway. For + default order and fixity declarations, this means that specifications + currently have to assume those declarations are moved to the + beginning when using a backend that requires topological sorting. *) + prelude @ [d], original_order, defset, graph + +let print_dot graph component : unit = + match component with + | root :: _ -> + print_endline ("// Dependency cycle including " ^ root); + print_endline ("digraph cycle_" ^ root ^ " {"); + List.iter (fun caller -> + let print_edge callee = print_endline (" \"" ^ caller ^ "\" -> \"" ^ callee ^ "\";") in + Namemap.find caller graph + |> Nameset.filter (fun id -> List.mem id component) + |> Nameset.iter print_edge) component; + print_endline "}" + | [] -> () + +let def_of_component graph defset comp = + let get_def id = if Namemap.mem id defset then [Namemap.find id defset] else [] in + match List.concat (List.map get_def comp) with + | [] -> [] + | [def] -> [def] + | (def :: _) as defs -> + let get_fundefs = function + | DEF_fundef fundef -> [fundef] + | DEF_internal_mutrec fundefs -> fundefs + | _ -> + raise (Reporting_basic.err_unreachable (def_loc def) + "Trying to merge non-function definition with mutually recursive functions") in + let fundefs = List.concat (List.map get_fundefs defs) in + print_dot graph (List.map (fun fd -> string_of_id (id_of_fundef fd)) fundefs); + [DEF_internal_mutrec fundefs] + +let top_sort_defs (Defs defs) = + let prelude, original_order, defset, graph = + List.fold_left add_def_to_graph ([], [], Namemap.empty, Namemap.empty) defs in + let components = scc ~original_order:original_order graph in + Defs (prelude @ List.concat (List.map (def_of_component graph defset) components)) -- cgit v1.2.3 From 3216b9307830895b8c76725d8aea8936a0aca181 Mon Sep 17 00:00:00 2001 From: Thomas Bauereiss Date: Wed, 17 Jan 2018 15:08:28 +0000 Subject: Try to remove early returns more aggressively In particular, there is an ASL pattern with single-branch if-expressions containing an early return (and an empty else-branch), e.g. { ... if (error) then return(Fault) else (); ... return(Success); } The rewriting pass now tries to fold the rest of the block into the else-branch, since it is unreachable after the then-branch, e.g. { ... if (error) then return(Fault) else { ... return(Success); } } In combination with other rewriting rules, this allows the rewriting pass to pull out the early returns if *all* branches end in a return statement. If it is possible to pull out the return all the way, i.e., so that the function body is a single return statement, then the return can be removed. If that is not possible, then the function body is left as it was originally. --- src/rewrites.ml | 63 ++++++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 49 insertions(+), 14 deletions(-) (limited to 'src') diff --git a/src/rewrites.ml b/src/rewrites.ml index 7e852092..86560415 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -1682,7 +1682,33 @@ let rewrite_defs_early_return (Defs defs) = | E_return e -> e | _ -> exp in - let e_block es = + let e_if (e1, e2, e3) = + if is_return e2 && is_return e3 then + let (E_aux (_, annot)) = get_return e2 in + E_return (E_aux (E_if (e1, get_return e2, get_return e3), annot)) + else E_if (e1, e2, e3) in + + let rec e_block es = + (* If one of the branches of an if-expression in a block is an early + return, fold the rest of the block after the if-expression into the + other branch *) + let fold_if_return exp block = match exp with + | E_aux (E_if (c, t, e), annot) when is_return t -> + let annot = match block with + | [] -> annot + | _ -> let (E_aux (_, annot)) = Util.last block in annot + in + let e' = E_aux (e_block (e :: block), annot) in + [E_aux (e_if (c, t, e'), annot)] + | E_aux (E_if (c, t, e), annot) when is_return e -> + let annot = match block with + | [] -> annot + | _ -> let (E_aux (_, annot)) = Util.last block in annot + in + let t' = E_aux (e_block (t :: block), annot) in + [E_aux (e_if (c, t', e), annot)] + | _ -> exp :: block in + let es = List.fold_right fold_if_return es [] in match es with | [E_aux (e, _)] -> e | _ :: _ when is_return (Util.last es) -> @@ -1690,12 +1716,6 @@ let rewrite_defs_early_return (Defs defs) = E_return (E_aux (E_block (Util.butlast es @ [get_return e]), annot)) | _ -> E_block es in - let e_if (e1, e2, e3) = - if is_return e2 && is_return e3 then - let (E_aux (_, annot)) = get_return e2 in - E_return (E_aux (E_if (e1, get_return e2, get_return e3), annot)) - else E_if (e1, e2, e3) in - let e_case (e, pes) = let is_return_pexp (Pat_aux (pexp, _)) = match pexp with | Pat_exp (_, e) | Pat_when (_, _, e) -> is_return e in @@ -1710,6 +1730,17 @@ let rewrite_defs_early_return (Defs defs) = then E_return (E_aux (E_case (e, List.map get_return_pexp pes), annot)) else E_case (e, pes) in + let e_let (lb, exp) = + let (E_aux (_, annot) as ret_exp) = get_return exp in + if is_return exp then E_return (E_aux (E_let (lb, ret_exp), annot)) + else E_let (lb, exp) in + + let e_internal_let (pat, exp1, exp2) = + let (E_aux (_, annot) as ret_exp2) = get_return exp2 in + if is_return exp2 then + E_return (E_aux (E_internal_let (pat, exp1, ret_exp2), annot)) + else E_internal_let (pat, exp1, exp2) in + let e_aux (exp, (l, annot)) = let full_exp = propagate_exp_effect (E_aux (exp, (l, annot))) in let env = env_of full_exp in @@ -1724,14 +1755,18 @@ let rewrite_defs_early_return (Defs defs) = let rewrite_funcl_early_return _ (FCL_aux (FCL_Funcl (id, pexp), a)) = let pat,guard,exp,pannot = destruct_pexp pexp in + (* Try to pull out early returns as far as possible *) + let exp' = + fold_exp + { id_exp_alg with e_block = e_block; e_if = e_if; e_case = e_case; + e_let = e_let; e_internal_let = e_internal_let } + exp in + (* Remove early return if we can pull it out completely, and rewrite + remaining early returns to "early_return" calls *) let exp = - exp - (* Pull early returns out as far as possible *) - |> fold_exp { id_exp_alg with e_block = e_block; e_if = e_if; e_case = e_case } - (* Remove singleton E_return *) - |> get_return - (* Fix effect annotations *) - |> fold_exp { id_exp_alg with e_aux = e_aux } in + fold_exp + { id_exp_alg with e_aux = e_aux } + (if is_return exp' then get_return exp' else exp) in let a = match a with | (l, Some (env, typ, eff)) -> (l, Some (env, typ, union_effects eff (effect_of exp))) -- cgit v1.2.3 From 373b081bc4b9669bbc17accf24e0dd392489f762 Mon Sep 17 00:00:00 2001 From: Thomas Bauereiss Date: Wed, 17 Jan 2018 19:41:25 +0000 Subject: Use right effect annotations in early return rewriting Also drop redundant unit expressions when concatenating with an empty block. --- src/rewrites.ml | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) (limited to 'src') diff --git a/src/rewrites.ml b/src/rewrites.ml index 86560415..a143175d 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -1674,6 +1674,10 @@ let rewrite_defs_separate_numbs defs = rewrite_defs_base rewriting of early returns *) let rewrite_defs_early_return (Defs defs) = + let is_unit (E_aux (exp, _)) = match exp with + | E_lit (L_aux (L_unit, _)) -> true + | _ -> false in + let is_return (E_aux (exp, _)) = match exp with | E_return _ -> true | _ -> false in @@ -1693,19 +1697,21 @@ let rewrite_defs_early_return (Defs defs) = return, fold the rest of the block after the if-expression into the other branch *) let fold_if_return exp block = match exp with - | E_aux (E_if (c, t, e), annot) when is_return t -> + | E_aux (E_if (c, t, (E_aux (_, annot) as e)), _) when is_return t -> let annot = match block with | [] -> annot | _ -> let (E_aux (_, annot)) = Util.last block in annot in - let e' = E_aux (e_block (e :: block), annot) in + let block = if is_unit e then block else e :: block in + let e' = E_aux (e_block block, annot) in [E_aux (e_if (c, t, e'), annot)] - | E_aux (E_if (c, t, e), annot) when is_return e -> + | E_aux (E_if (c, (E_aux (_, annot) as t), e), _) when is_return e -> let annot = match block with | [] -> annot | _ -> let (E_aux (_, annot)) = Util.last block in annot in - let t' = E_aux (e_block (t :: block), annot) in + let block = if is_unit t then block else t :: block in + let t' = E_aux (e_block block, annot) in [E_aux (e_if (c, t', e), annot)] | _ -> exp :: block in let es = List.fold_right fold_if_return es [] in -- cgit v1.2.3