From e4fce3ffd02b69e36b42ffe3c868570c45aef986 Mon Sep 17 00:00:00 2001 From: Thomas Bauereiss Date: Thu, 10 Aug 2017 17:42:27 +0100 Subject: Add support for early return to Lem backend Implemented using the exception monad, by throwing and catching the return value --- src/gen_lib/sail_values.lem | 2 +- src/gen_lib/state.lem | 45 +++++++++++++++---- src/pretty_print_lem.ml | 103 ++++++++++++++++++++++++++------------------ src/rewriter.ml | 66 ++++++++++++++++++++++++++-- src/rewriter.mli | 7 +++ 5 files changed, 167 insertions(+), 56 deletions(-) diff --git a/src/gen_lib/sail_values.lem b/src/gen_lib/sail_values.lem index f1541466..b4a15432 100644 --- a/src/gen_lib/sail_values.lem +++ b/src/gen_lib/sail_values.lem @@ -1028,7 +1028,7 @@ let assert' b msg_opt = | Just msg -> msg | Nothing -> "unspecified error" end in - if bitU_to_bool b then () else failwith msg + if b then () else failwith msg (* convert numbers unsafely to naturals *) diff --git a/src/gen_lib/state.lem b/src/gen_lib/state.lem index 1bc1ad55..2e11e8a9 100644 --- a/src/gen_lib/state.lem +++ b/src/gen_lib/state.lem @@ -14,12 +14,28 @@ type sequential_state = <| regstate : regstate; write_ea : maybe (write_kind * integer * integer); last_exclusive_operation_was_load : bool|> -type M 'a = sequential_state -> list ((either 'a string) * sequential_state) +(* State, nondeterminism and exception monad with result type 'a + and exception type 'e. *) +type ME 'a 'e = sequential_state -> list ((either 'a 'e) * sequential_state) -val return : forall 'a. 'a -> M 'a +(* Most of the time, we don't distinguish between different types of exceptions *) +type M 'a = ME 'a unit + +(* For early return, we abuse exceptions by throwing and catching + the return value. The exception type is "maybe 'r", where "Nothing" + represents a proper exception and "Just r" an early return of value "r". *) +type MR 'a 'r = ME 'a (maybe 'r) + +val liftR : forall 'a 'r. M 'a -> MR 'a 'r +let liftR m s = List.map (function + | (Left a, s') -> (Left a, s') + | (Right (), s') -> (Right Nothing, s') + end) (m s) + +val return : forall 'a 'e. 'a -> ME 'a 'e let return a s = [(Left a,s)] -val bind : forall 'a 'b. M 'a -> ('a -> M 'b) -> M 'b +val bind : forall 'a 'b 'e. ME 'a 'e -> ('a -> ME 'b 'e) -> ME 'b 'e let bind m f (s : sequential_state) = List.concatMap (function | (Left a, s') -> f a s' @@ -27,12 +43,23 @@ let bind m f (s : sequential_state) = end) (m s) let inline (>>=) = bind -val (>>): forall 'b. M unit -> M 'b -> M 'b +val (>>): forall 'b 'e. ME unit 'e -> ME 'b 'e -> ME 'b 'e let inline (>>) m n = m >>= fun _ -> n val exit : forall 'e 'a. 'e -> M 'a -let exit _ s = [(Right "exit",s)] +let exit _ s = [(Right (), s)] + +val early_return : forall 'r. 'r -> MR unit 'r +let early_return r s = [(Right (Just r), s)] +val catch_early_return : forall 'a 'r. MR 'a 'a -> M 'a +let catch_early_return m s = + List.map + (function + | (Right (Just a), s') -> (Left a, s') + | (Right Nothing, s') -> (Right (), s') + | (Left a, s') -> (Left a, s') + end) (m s) val range : integer -> integer -> list integer let rec range i j = @@ -174,8 +201,8 @@ val footprint : M unit let footprint = return () -val foreachM_inc : forall 'vars. (integer * integer * integer) -> 'vars -> - (integer -> 'vars -> M 'vars) -> M 'vars +val foreachM_inc : forall 'vars 'e. (integer * integer * integer) -> 'vars -> + (integer -> 'vars -> ME 'vars 'e) -> ME 'vars 'e let rec foreachM_inc (i,stop,by) vars body = if i <= stop then @@ -184,8 +211,8 @@ let rec foreachM_inc (i,stop,by) vars body = else return vars -val foreachM_dec : forall 'vars. (integer * integer * integer) -> 'vars -> - (integer -> 'vars -> M 'vars) -> M 'vars +val foreachM_dec : forall 'vars 'e. (integer * integer * integer) -> 'vars -> + (integer -> 'vars -> ME 'vars 'e) -> ME 'vars 'e let rec foreachM_dec (i,stop,by) vars body = if i >= stop then diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml index 5fe6b69d..586773ca 100644 --- a/src/pretty_print_lem.ml +++ b/src/pretty_print_lem.ml @@ -170,6 +170,7 @@ let doc_typ_lem, doc_atomic_typ_lem = | Typ_aux (Typ_id (Id_aux (Id "bit",_)),_) -> (match simplify_nexp m with | (Nexp_aux(Nexp_constant i,_)) -> string "bitvector ty" ^^ doc_int i + | (Nexp_aux(Nexp_var _, _)) -> separate space [string "bitvector"; doc_nexp m] | _ -> raise (Reporting_basic.err_unreachable l "cannot pretty-print bitvector type with non-constant length")) | _ -> string "vector" ^^ space ^^ typ regtypes elem_typ in @@ -340,13 +341,23 @@ and contains_bitvector_typ_arg (Typ_arg_aux (targ, _)) = match targ with | Typ_arg_typ t -> contains_bitvector_typ t | _ -> false +let contains_early_return exp = + fst (fold_exp + { (Rewriter.compute_exp_alg false (||)) + with e_return = (fun (_, r) -> (true, E_return r)) } exp) + let prefix_recordtype = true let report = Reporting_basic.err_unreachable let doc_exp_lem, doc_let_lem = - let rec top_exp regtypes (aexp_needed : bool) (E_aux (e, (l,annot)) as full_exp) = - let expY = top_exp regtypes true in - let expN = top_exp regtypes false in - let expV = top_exp regtypes in + let rec top_exp regtypes (early_ret : bool) (aexp_needed : bool) + (E_aux (e, (l,annot)) as full_exp) = + let expY = top_exp regtypes early_ret true in + let expN = top_exp regtypes early_ret false in + let expV = top_exp regtypes early_ret in + let liftR doc = + if early_ret && effectful (effect_of full_exp) + then separate space [string "liftR"; parens (doc)] + else doc in match e with | E_assign((LEXP_aux(le_act,tannot) as le), e) -> (* can only be register writes *) @@ -358,14 +369,14 @@ let doc_exp_lem, doc_let_lem = if is_bit_typ (typ_of_annot lannot) then raise (report l "indexing a register's (single bit) bitfield not supported") else - (prefix 2 1) + liftR ((prefix 2 1) (string "write_reg_field_range") - (align (doc_lexp_deref_lem regtypes le ^^ space^^ - string_lit (doc_id_lem id) ^/^ expY e2 ^/^ expY e3 ^/^ expY e)) + (align (doc_lexp_deref_lem regtypes early_ret le ^^ space^^ + string_lit (doc_id_lem id) ^/^ expY e2 ^/^ expY e3 ^/^ expY e))) | _ -> - (prefix 2 1) + liftR ((prefix 2 1) (string "write_reg_range") - (align (doc_lexp_deref_lem regtypes le ^^ space ^^ expY e2 ^/^ expY e3 ^/^ expY e)) + (align (doc_lexp_deref_lem regtypes early_ret le ^^ space ^^ expY e2 ^/^ expY e3 ^/^ expY e))) ) | LEXP_vector (le,e2) when is_bit_typ t -> (match le with @@ -373,23 +384,23 @@ let doc_exp_lem, doc_let_lem = if is_bit_typ (typ_of_annot lannot) then raise (report l "indexing a register's (single bit) bitfield not supported") else - (prefix 2 1) + liftR ((prefix 2 1) (string "write_reg_field_bit") - (align (doc_lexp_deref_lem regtypes le ^^ space ^^ doc_id_lem id ^/^ expY e2 ^/^ expY e)) + (align (doc_lexp_deref_lem regtypes early_ret le ^^ space ^^ doc_id_lem id ^/^ expY e2 ^/^ expY e))) | _ -> - (prefix 2 1) + liftR ((prefix 2 1) (string "write_reg_bit") - (doc_lexp_deref_lem regtypes le ^^ space ^^ expY e2 ^/^ expY e) + (doc_lexp_deref_lem regtypes early_ret le ^^ space ^^ expY e2 ^/^ expY e)) ) | LEXP_field (le,id) when is_bit_typ t -> - (prefix 2 1) + liftR ((prefix 2 1) (string "write_reg_bitfield") - (doc_lexp_deref_lem regtypes le ^^ space ^^ string_lit(doc_id_lem id) ^/^ expY e) + (doc_lexp_deref_lem regtypes early_ret le ^^ space ^^ string_lit(doc_id_lem id) ^/^ expY e)) | LEXP_field (le,id) -> - (prefix 2 1) + liftR ((prefix 2 1) (string "write_reg_field") - (doc_lexp_deref_lem regtypes le ^^ space ^^ - string_lit(doc_id_lem id) ^/^ expY e) + (doc_lexp_deref_lem regtypes early_ret le ^^ space ^^ + string_lit(doc_id_lem id) ^/^ expY e)) (* | (LEXP_id id | LEXP_cast (_,id)), t, Alias alias_info -> (match alias_info with | Alias_field(reg,field) -> @@ -404,7 +415,7 @@ let doc_exp_lem, doc_let_lem = string "write_two_regs" ^^ space ^^ string reg1 ^^ space ^^ string reg2 ^^ space ^^ expY e) *) | _ -> - (prefix 2 1) (string "write_reg") (doc_lexp_deref_lem regtypes le ^/^ expY e)) + liftR ((prefix 2 1) (string "write_reg") (doc_lexp_deref_lem regtypes early_ret le ^/^ expY e))) | E_vector_append(le,re) -> raise (Reporting_basic.err_unreachable l "E_vector_access should have been rewritten before pretty-printing") @@ -430,7 +441,7 @@ let doc_exp_lem, doc_let_lem = | E_for(id,exp1,exp2,exp3,(Ord_aux(order,_)),exp4) -> raise (report l "E_for should have been removed till now") | E_let(leb,e) -> - let epp = let_exp regtypes leb ^^ space ^^ string "in" ^^ hardline ^^ expN e in + let epp = let_exp regtypes early_ret leb ^^ space ^^ string "in" ^^ hardline ^^ expN e in if aexp_needed then parens epp else epp | E_app(f,args) -> begin match f with @@ -509,7 +520,7 @@ let doc_exp_lem, doc_let_lem = not (contains_t_pp_var t) then (align epp ^^ (doc_tannot_lem regtypes (effectful eff) t), true) else (epp, aexp_needed) in - if aexp_needed then parens (align taepp) else taepp + liftR (if aexp_needed then parens (align taepp) else taepp) end end | E_vector_access (v,e) -> @@ -578,8 +589,8 @@ let doc_exp_lem, doc_let_lem = if has_effect eff BE_rreg then let epp = separate space [string "read_reg";doc_id_lem id] in if is_bitvector_typ base_typ && not (contains_t_pp_var base_typ) - then parens (epp ^^ doc_tannot_lem regtypes true base_typ) - else epp + then liftR (parens (epp ^^ doc_tannot_lem regtypes true base_typ)) + else liftR epp else if is_ctor env id then doc_id_lem_ctor id else doc_id_lem id (*| Base((_,t),Alias alias_info,_,eff,_,_) -> @@ -656,14 +667,14 @@ let doc_exp_lem, doc_let_lem = | _ -> raise (report l "cannot get record type") in let epp = anglebars (space ^^ (align (separate_map (semi_sp ^^ break 1) - (doc_fexp regtypes recordtyp) fexps)) ^^ space) in + (doc_fexp regtypes early_ret recordtyp) fexps)) ^^ space) in if aexp_needed then parens epp else epp | E_record_update(e,(FES_aux(FES_Fexps(fexps,_),_))) -> let recordtyp = match annot with | Some (env, Typ_aux (Typ_id tid,_), _) when Env.is_record tid env -> tid | _ -> raise (report l "cannot get record type") in - anglebars (doc_op (string "with") (expY e) (separate_map semi_sp (doc_fexp regtypes recordtyp) fexps)) + anglebars (doc_op (string "with") (expY e) (separate_map semi_sp (doc_fexp regtypes early_ret recordtyp) fexps)) | E_vector exps -> let t = Env.base_typ_of (env_of full_exp) (typ_of full_exp) in let (start, len, order, etyp) = @@ -792,10 +803,10 @@ let doc_exp_lem, doc_let_lem = pattern-matching on integers *) let epp = group ((separate space [string "match"; only_integers e; string "with"]) ^/^ - (separate_map (break 1) (doc_case regtypes) pexps) ^/^ + (separate_map (break 1) (doc_case regtypes early_ret) pexps) ^/^ (string "end")) in if aexp_needed then parens (align epp) else align epp - | E_exit e -> separate space [string "exit"; expY e;] + | E_exit e -> liftR (separate space [string "exit"; expY e;]) | E_assert (e1,e2) -> let epp = separate space [string "assert'"; expY e1; expY e2] in if aexp_needed then parens (align epp) else align epp @@ -917,41 +928,40 @@ let doc_exp_lem, doc_let_lem = | _ -> raise (Reporting_basic.err_unreachable l "pretty-printing non-constant sizeof expressions to Lem not supported")) - | E_return _ -> - raise (Reporting_basic.err_todo l - "pretty-printing early return statements to Lem not yet supported") + | E_return r -> + align (string "early_return" ^//^ expV true r) | E_constraint _ | E_comment _ | E_comment_struc _ -> empty | E_internal_cast _ | E_internal_exp _ | E_sizeof_internal _ | E_internal_exp_user _ -> raise (Reporting_basic.err_unreachable l "unsupported internal expression encountered while pretty-printing") - and let_exp regtypes (LB_aux(lb,_)) = match lb with + and let_exp regtypes early_ret (LB_aux(lb,_)) = match lb with | LB_val_explicit(_,pat,e) | LB_val_implicit(pat,e) -> prefix 2 1 (separate space [string "let"; doc_pat_lem regtypes true pat; equals]) - (top_exp regtypes false e) + (top_exp regtypes early_ret false e) - and doc_fexp regtypes recordtyp (FE_aux(FE_Fexp(id,e),_)) = + and doc_fexp regtypes early_ret recordtyp (FE_aux(FE_Fexp(id,e),_)) = let fname = if prefix_recordtype then (string (string_of_id recordtyp ^ "_")) ^^ doc_id_lem id else doc_id_lem id in - group (doc_op equals fname (top_exp regtypes true e)) + group (doc_op equals fname (top_exp regtypes early_ret true e)) - and doc_case regtypes = function + and doc_case regtypes early_ret = function | Pat_aux(Pat_exp(pat,e),_) -> group (prefix 3 1 (separate space [pipe; doc_pat_lem regtypes false pat;arrow]) - (group (top_exp regtypes false e))) + (group (top_exp regtypes early_ret false e))) | Pat_aux(Pat_when(_,_,_),(l,_)) -> raise (Reporting_basic.err_unreachable l "guarded pattern expression should have been rewritten before pretty-printing") - and doc_lexp_deref_lem regtypes ((LEXP_aux(lexp,(l,annot))) as le) = match lexp with + and doc_lexp_deref_lem regtypes early_ret ((LEXP_aux(lexp,(l,annot))) as le) = match lexp with | LEXP_field (le,id) -> - parens (separate empty [doc_lexp_deref_lem regtypes le;dot;doc_id_lem id]) + parens (separate empty [doc_lexp_deref_lem regtypes early_ret le;dot;doc_id_lem id]) | LEXP_vector(le,e) -> - parens ((separate space) [string "access";doc_lexp_deref_lem regtypes le; - top_exp regtypes true e]) + parens ((separate space) [string "access";doc_lexp_deref_lem regtypes early_ret le; + top_exp regtypes early_ret true e]) | LEXP_id id -> doc_id_lem id | LEXP_cast (typ,id) -> doc_id_lem id | _ -> @@ -1225,9 +1235,16 @@ let doc_rec_lem (Rec_aux(r,_)) = match r with let doc_tannot_opt_lem regtypes (Typ_annot_opt_aux(t,_)) = match t with | Typ_annot_opt_some(tq,typ) -> (*doc_typquant_lem tq*) (doc_typ_lem regtypes typ) +let doc_fun_body_lem regtypes exp = + let early_ret = contains_early_return exp in + let doc_exp = doc_exp_lem regtypes early_ret false exp in + if early_ret + then align (string "catch_early_return" ^//^ parens (doc_exp)) + else doc_exp + let doc_funcl_lem regtypes (FCL_aux(FCL_Funcl(id,pat,exp),_)) = group (prefix 3 1 ((doc_pat_lem regtypes false pat) ^^ space ^^ arrow) - (doc_exp_lem regtypes false exp)) + (doc_fun_body_lem regtypes exp)) let get_id = function | [] -> failwith "FD_function with empty list" @@ -1244,7 +1261,7 @@ let rec doc_fundef_lem regtypes (FD_aux(FD_function(r, typa, efa, fcls),fannot)) [(string "let") ^^ (doc_rec_lem r) ^^ (doc_id_lem id); (doc_pat_lem regtypes true pat); equals]) - (doc_exp_lem regtypes false exp) + (doc_fun_body_lem regtypes exp) | _ -> let id = get_id fcls in (* let sep = hardline ^^ pipe ^^ space in *) @@ -1358,7 +1375,7 @@ let rec doc_def_lem regtypes def = match def with | DEF_default df -> (empty,empty) | DEF_fundef f_def -> (empty,group (doc_fundef_lem regtypes f_def) ^/^ hardline) - | DEF_val lbind -> (empty,group (doc_let_lem regtypes lbind) ^/^ hardline) + | DEF_val lbind -> (empty,group (doc_let_lem regtypes false lbind) ^/^ hardline) | DEF_scattered sdef -> failwith "doc_def_lem: shoulnd't have DEF_scattered at this point" | DEF_kind _ -> (empty,empty) diff --git a/src/rewriter.ml b/src/rewriter.ml index 1f8452ba..8da8aacf 100644 --- a/src/rewriter.ml +++ b/src/rewriter.ml @@ -149,8 +149,8 @@ let fix_eff_exp (E_aux (e,((l,_) as annot))) = match snd annot with List.fold_left union_effects (effect_of e) (List.map effect_of_pexp pexps) | E_let (lb,e) -> union_effects (effect_of_lb lb) (effect_of e) | E_assign (lexp,e) -> union_effects (effect_of_lexp lexp) (effect_of e) - | E_exit e -> effect_of e - | E_return e -> effect_of e + | E_exit e -> union_effects eff (effect_of e) + | E_return e -> union_effects eff (effect_of e) | E_sizeof _ | E_sizeof_internal _ | E_constraint _ -> no_effect | E_assert (c,m) -> eff | E_comment _ | E_comment_struc _ -> no_effect @@ -2074,7 +2074,7 @@ let rewrite_exp_lift_assign_intro rewriters ((E_aux (exp,((l,_) as annot))) as f | (E_aux(E_assign((LEXP_aux ((LEXP_id id | LEXP_cast (_,id)),_)) as le,e), ((l, Some (env,typ,eff)) as annot)) as exp)::exps -> (match Env.lookup_id id env with - | Unbound -> + | Unbound | Local _ -> let le' = rewriters.rewrite_lexp rewriters le in let e' = rewrite_base e in let exps' = walker exps in @@ -2211,6 +2211,65 @@ let rewrite_defs_separate_numbs defs = rewrite_defs_base rewrite_def = rewrite_def; rewrite_defs = rewrite_defs_base} defs*) +let rewrite_defs_early_return = + let is_return (E_aux (exp, _)) = match exp with + | E_return _ -> true + | _ -> false in + + let get_return (E_aux (e, (l, _)) as exp) = match e with + | E_return e -> e + | _ -> exp in + + let e_block es = + (* let rec walker = function + | e :: es -> if is_return e then [e] else e :: walker es + | [] -> [] in + let es = walker es in *) + match es with + | [E_aux (e, _)] -> e + | _ -> E_block es in + + let e_if (e1, e2, e3) = + if is_return e2 && is_return e3 then E_if (e1, get_return e2, get_return e3) + else E_if (e1, e2, e3) in + + let e_case (e, pes) = + let is_return_pexp (Pat_aux (pexp, _)) = match pexp with + | Pat_exp (_, e) | Pat_when (_, _, e) -> is_return e in + let get_return_pexp (Pat_aux (pexp, a)) = match pexp with + | Pat_exp (p, e) -> Pat_aux (Pat_exp (p, get_return e), a) + | Pat_when (p, g, e) -> Pat_aux (Pat_when (p, g, get_return e), a) in + if List.for_all is_return_pexp pes + then E_return (E_aux (E_case (e, List.map get_return_pexp pes), (Parse_ast.Unknown, None))) + else E_case (e, pes) in + + let e_aux (exp, (l, annot)) = + let full_exp = fix_eff_exp (E_aux (exp, (l, annot))) in + match annot with + | Some (env, typ, eff) when is_return full_exp -> + (* Add escape effect annotation, since we use the exception mechanism + of the state monad to implement early return in the Lem backend *) + let annot' = Some (env, typ, union_effects eff (mk_effect [BE_escape])) in + E_aux (exp, (l, annot')) + | _ -> full_exp in + + let rewrite_funcl_early_return _ (FCL_aux (FCL_Funcl (id, pat, exp), a)) = + let exp = fold_exp + { id_exp_alg with e_block = e_block; e_if = e_if; e_case = e_case; + e_aux = e_aux } exp in + let a = match a with + | (l, Some (env, typ, eff)) -> + (l, Some (env, typ, union_effects eff (effect_of exp))) + | _ -> a in + FCL_aux (FCL_Funcl (id, pat, get_return exp), a) in + + let rewrite_fun_early_return rewriters + (FD_aux (FD_function (rec_opt, tannot_opt, effect_opt, funcls), a)) = + FD_aux (FD_function (rec_opt, tannot_opt, effect_opt, + List.map (rewrite_funcl_early_return rewriters) funcls), a) in + + rewrite_defs_base { rewriters_base with rewrite_fun = rewrite_fun_early_return } + let rewrite_defs_ocaml = [ top_sort_defs; rewrite_defs_remove_vector_concat; @@ -3053,6 +3112,7 @@ let rewrite_defs_lem =[ rewrite_defs_remove_bitvector_pats; rewrite_defs_guarded_pats; (* recheck_defs; *) + rewrite_defs_early_return; rewrite_defs_exp_lift_assign; rewrite_defs_remove_blocks; rewrite_defs_letbind_effects; diff --git a/src/rewriter.mli b/src/rewriter.mli index 9c91b55b..9dbdee3d 100644 --- a/src/rewriter.mli +++ b/src/rewriter.mli @@ -154,3 +154,10 @@ val fold_exp : ('a,'exp,'exp_aux,'lexp,'lexp_aux,'fexp,'fexp_aux,'fexps,'fexps_a 'pat,'pat_aux,'fpat,'fpat_aux) exp_alg -> 'a exp -> 'exp val id_pat_alg : ('a,'a pat, 'a pat_aux, 'a fpat, 'a fpat_aux) pat_alg + +val compute_exp_alg : 'b -> ('b -> 'b -> 'b) -> + ('a,('b * 'a exp),('b * 'a exp_aux),('b * 'a lexp),('b * 'a lexp_aux),('b * 'a fexp), + ('b * 'a fexp_aux),('b * 'a fexps),('b * 'a fexps_aux), + ('b * 'a opt_default_aux),('b * 'a opt_default),('b * 'a pexp),('b * 'a pexp_aux), + ('b * 'a letbind_aux),('b * 'a letbind), + ('b * 'a pat),('b * 'a pat_aux),('b * 'a fpat),('b * 'a fpat_aux)) exp_alg -- cgit v1.2.3