From 1f421b865a87a161a82550443a0cf39aa2642d9c Mon Sep 17 00:00:00 2001 From: Brian Campbell Date: Mon, 15 Apr 2019 12:08:28 +0100 Subject: Basic loop termination measures for Coq Currently only supports pure termination measures for loops with effects. The user syntax uses separate termination measure declarations, as in the previous recursive termination measures, which are rewritten into the loop AST nodes before type checking (because it would be rather difficult to calculate the correct environment to type check the separate declaration in). --- src/ast_util.ml | 33 +++++-- src/constant_propagation.ml | 10 +- src/initial_check.ml | 20 +++- src/interpreter.ml | 4 +- src/jib/anf.ml | 2 +- src/jib/jib_compile.ml | 1 + src/monomorphise.ml | 10 +- src/ocaml_backend.ml | 4 +- src/parse_ast.ml | 16 ++- src/parser.mly | 29 +++++- src/pretty_print_coq.ml | 95 +++++++++++------- src/pretty_print_lem.ml | 8 +- src/pretty_print_sail.ml | 21 +++- src/rewriter.ml | 235 +++++++++++++++++++++++++++++++++++++++++--- src/rewriter.mli | 6 +- src/rewrites.ml | 82 ++++++++++++++-- src/rewrites.mli | 3 + src/sail.ml | 3 + src/spec_analysis.ml | 14 ++- src/type_check.ml | 24 ++++- 20 files changed, 525 insertions(+), 95 deletions(-) (limited to 'src') diff --git a/src/ast_util.ml b/src/ast_util.ml index d0efc0de..254afccf 100644 --- a/src/ast_util.ml +++ b/src/ast_util.ml @@ -519,7 +519,7 @@ and map_exp_annot_aux f = function | E_tuple xs -> E_tuple (List.map (map_exp_annot f) xs) | E_if (cond, t, e) -> E_if (map_exp_annot f cond, map_exp_annot f t, map_exp_annot f e) | E_for (v, e1, e2, e3, o, e4) -> E_for (v, map_exp_annot f e1, map_exp_annot f e2, map_exp_annot f e3, o, map_exp_annot f e4) - | E_loop (loop_type, e1, e2) -> E_loop (loop_type, map_exp_annot f e1, map_exp_annot f e2) + | E_loop (loop_type, measure, e1, e2) -> E_loop (loop_type, map_measure_annot f measure, map_exp_annot f e1, map_exp_annot f e2) | E_vector exps -> E_vector (List.map (map_exp_annot f) exps) | E_vector_access (exp1, exp2) -> E_vector_access (map_exp_annot f exp1, map_exp_annot f exp2) | E_vector_subrange (exp1, exp2, exp3) -> E_vector_subrange (map_exp_annot f exp1, map_exp_annot f exp2, map_exp_annot f exp3) @@ -546,6 +546,10 @@ and map_exp_annot_aux f = function | E_var (lexp, exp1, exp2) -> E_var (map_lexp_annot f lexp, map_exp_annot f exp1, map_exp_annot f exp2) | E_internal_plet (pat, exp1, exp2) -> E_internal_plet (map_pat_annot f pat, map_exp_annot f exp1, map_exp_annot f exp2) | E_internal_return exp -> E_internal_return (map_exp_annot f exp) +and map_measure_annot f (Measure_aux (m, l)) = Measure_aux (map_measure_annot_aux f m, l) +and map_measure_annot_aux f = function + | Measure_none -> Measure_none + | Measure_some exp -> Measure_some (map_exp_annot f exp) and map_opt_default_annot f (Def_val_aux (df, annot)) = Def_val_aux (map_opt_default_annot_aux f df, f annot) and map_opt_default_annot_aux f = function | Def_val_empty -> Def_val_empty @@ -648,6 +652,7 @@ let def_loc = function | DEF_internal_mutrec _ -> Parse_ast.Unknown | DEF_pragma (_, _, l) -> l | DEF_measure (id, _, _) -> id_loc id + | DEF_loop_measures (id, _) -> id_loc id let string_of_id = function | Id_aux (Id v, _) -> v @@ -838,8 +843,8 @@ let rec string_of_exp (E_aux (exp, _)) = ^ " by " ^ string_of_exp u ^ " order " ^ string_of_order ord ^ ") { " ^ string_of_exp body - | E_loop (While, cond, body) -> "while " ^ string_of_exp cond ^ " do " ^ string_of_exp body - | E_loop (Until, cond, body) -> "repeat " ^ string_of_exp body ^ " until " ^ string_of_exp cond + | E_loop (While, measure, cond, body) -> "while " ^ string_of_measure measure ^ string_of_exp cond ^ " do " ^ string_of_exp body + | E_loop (Until, measure, cond, body) -> "repeat " ^ string_of_measure measure ^ string_of_exp body ^ " until " ^ string_of_exp cond | E_assert (test, msg) -> "assert(" ^ string_of_exp test ^ ", " ^ string_of_exp msg ^ ")" | E_exit exp -> "exit " ^ string_of_exp exp | E_throw exp -> "throw " ^ string_of_exp exp @@ -855,6 +860,11 @@ let rec string_of_exp (E_aux (exp, _)) = | E_nondet _ -> "NONDET" | E_internal_value _ -> "INTERNAL VALUE" +and string_of_measure (Measure_aux (m,_)) = + match m with + | Measure_none -> "" + | Measure_some exp -> "termination_measure { " ^ string_of_exp exp ^ "}" + and string_of_fexp (FE_aux (FE_Fexp (field, exp), _)) = string_of_id field ^ " = " ^ string_of_exp exp and string_of_pexp (Pat_aux (pexp, _)) = @@ -1448,8 +1458,8 @@ let rec subst id value (E_aux (e_aux, annot) as exp) = | E_if (cond, then_exp, else_exp) -> E_if (subst id value cond, subst id value then_exp, subst id value else_exp) - | E_loop (loop, cond, body) -> - E_loop (loop, subst id value cond, subst id value body) + | E_loop (loop, measure, cond, body) -> + E_loop (loop, subst_measure id value measure, subst id value cond, subst id value body) | E_for (id', exp1, exp2, exp3, order, body) when Id.compare id id' = 0 -> E_for (id', exp1, exp2, exp3, order, body) | E_for (id', exp1, exp2, exp3, order, body) -> @@ -1503,6 +1513,11 @@ let rec subst id value (E_aux (e_aux, annot) as exp) = in wrap e_aux +and subst_measure id value (Measure_aux (m_aux, l)) = + match m_aux with + | Measure_none -> Measure_aux (Measure_none, l) + | Measure_some exp -> Measure_aux (Measure_some (subst id value exp), l) + and subst_pexp id value (Pat_aux (pexp_aux, annot)) = let pexp_aux = match pexp_aux with | Pat_exp (pat, exp) when IdSet.mem id (pat_ids pat) -> Pat_exp (pat, exp) @@ -1660,7 +1675,7 @@ let rec locate : 'a. (l -> l) -> 'a exp -> 'a exp = fun f (E_aux (e_aux, (l, ann | E_app_infix (exp1, op, exp2) -> E_app_infix (locate f exp1, locate_id f op, locate f exp2) | E_tuple exps -> E_tuple (List.map (locate f) exps) | E_if (cond_exp, then_exp, else_exp) -> E_if (locate f cond_exp, locate f then_exp, locate f else_exp) - | E_loop (loop, cond, body) -> E_loop (loop, locate f cond, locate f body) + | E_loop (loop, measure, cond, body) -> E_loop (loop, locate_measure f measure, locate f cond, locate f body) | E_for (id, exp1, exp2, exp3, ord, exp4) -> E_for (locate_id f id, locate f exp1, locate f exp2, locate f exp3, ord, locate f exp4) | E_vector exps -> E_vector (List.map (locate f) exps) @@ -1694,6 +1709,12 @@ let rec locate : 'a. (l -> l) -> 'a exp -> 'a exp = fun f (E_aux (e_aux, (l, ann in E_aux (e_aux, (f l, annot)) +and locate_measure : 'a. (l -> l) -> 'a internal_loop_measure -> 'a internal_loop_measure = fun f (Measure_aux (m, l)) -> + let m = match m with + | Measure_none -> Measure_none + | Measure_some exp -> Measure_some (locate f exp) + in Measure_aux (m, f l) + and locate_letbind : 'a. (l -> l) -> 'a letbind -> 'a letbind = fun f (LB_aux (LB_val (pat, exp), (l, annot))) -> LB_aux (LB_val (locate_pat f pat, locate f exp), (f l, annot)) diff --git a/src/constant_propagation.ml b/src/constant_propagation.ml index 5c99a534..ee5678a8 100644 --- a/src/constant_propagation.ml +++ b/src/constant_propagation.ml @@ -440,11 +440,17 @@ let const_props defs ref_vars = let assigns = isubst_minus_set assigns (assigned_vars e4) in let e4',_ = const_prop_exp (Bindings.remove id (fst substs),snd substs) assigns e4 in re (E_for (id,e1',e2',e3',ord,e4')) assigns - | E_loop (loop,e1,e2) -> + | E_loop (loop,m,e1,e2) -> let assigns = isubst_minus_set assigns (IdSet.union (assigned_vars e1) (assigned_vars e2)) in + let m' = match m with + | Measure_aux (Measure_none,_) -> m + | Measure_aux (Measure_some exp,l) -> + let exp',_ = const_prop_exp substs assigns exp in + Measure_aux (Measure_some exp',l) + in let e1',_ = const_prop_exp substs assigns e1 in let e2',_ = const_prop_exp substs assigns e2 in - re (E_loop (loop,e1',e2')) assigns + re (E_loop (loop,m',e1',e2')) assigns | E_vector es -> let es',assigns = non_det_exp_list es in begin diff --git a/src/initial_check.ml b/src/initial_check.ml index df9af97f..8129fd89 100644 --- a/src/initial_check.ml +++ b/src/initial_check.ml @@ -367,8 +367,8 @@ and to_ast_exp ctx (P.E_aux(exp,l) : P.exp) = | P.E_for(id,e1,e2,e3,atyp,e4) -> E_for(to_ast_id id,to_ast_exp ctx e1, to_ast_exp ctx e2, to_ast_exp ctx e3,to_ast_order ctx atyp, to_ast_exp ctx e4) - | P.E_loop (P.While, e1, e2) -> E_loop (While, to_ast_exp ctx e1, to_ast_exp ctx e2) - | P.E_loop (P.Until, e1, e2) -> E_loop (Until, to_ast_exp ctx e1, to_ast_exp ctx e2) + | P.E_loop (P.While, m, e1, e2) -> E_loop (While, to_ast_measure ctx m, to_ast_exp ctx e1, to_ast_exp ctx e2) + | P.E_loop (P.Until, m, e1, e2) -> E_loop (Until, to_ast_measure ctx m, to_ast_exp ctx e1, to_ast_exp ctx e2) | P.E_vector(exps) -> E_vector(List.map (to_ast_exp ctx) exps) | P.E_vector_access(vexp,exp) -> E_vector_access(to_ast_exp ctx vexp, to_ast_exp ctx exp) | P.E_vector_subrange(vex,exp1,exp2) -> @@ -414,6 +414,16 @@ and to_ast_exp ctx (P.E_aux(exp,l) : P.exp) = | _ -> raise (Reporting.err_unreachable l __POS__ "Unparsable construct in to_ast_exp") ), (l,())) +and to_ast_measure ctx (P.Measure_aux(m,l)) : unit internal_loop_measure = + let m = match m with + | P.Measure_none -> Measure_none + | P.Measure_some exp -> + if !opt_magic_hash then + Measure_some (to_ast_exp ctx exp) + else + raise (Reporting.err_general l "Internal loop termination measure found without -dmagic_hash") + in Measure_aux (m,l) + and to_ast_lexp ctx (P.E_aux(exp,l) : P.exp) : unit lexp = let lexp = match exp with | P.E_id id -> LEXP_id (to_ast_id id) @@ -762,6 +772,10 @@ let to_ast_prec = function | P.InfixL -> InfixL | P.InfixR -> InfixR +let to_ast_loop_measure ctx = function + | P.Loop (P.While, exp) -> Loop (While, to_ast_exp ctx exp) + | P.Loop (P.Until, exp) -> Loop (Until, to_ast_exp ctx exp) + let to_ast_def ctx def : unit def list ctx_out = match def with | P.DEF_overload (id, ids) -> @@ -800,6 +814,8 @@ let to_ast_def ctx def : unit def list ctx_out = [DEF_scattered sdef], ctx | P.DEF_measure (id, pat, exp) -> [DEF_measure (to_ast_id id, to_ast_pat ctx pat, to_ast_exp ctx exp)], ctx + | P.DEF_loop_measures (id, measures) -> + [DEF_loop_measures (to_ast_id id, List.map (to_ast_loop_measure ctx) measures)], ctx let rec remove_mutrec = function | [] -> [] diff --git a/src/interpreter.ml b/src/interpreter.ml index f01a3846..b198c59b 100644 --- a/src/interpreter.ml +++ b/src/interpreter.ml @@ -243,8 +243,8 @@ let rec step (E_aux (e_aux, annot) as orig_exp) = | E_if (exp, then_exp, else_exp) -> step exp >>= fun exp' -> wrap (E_if (exp', then_exp, else_exp)) - | E_loop (While, exp, body) -> wrap (E_if (exp, E_aux (E_block [body; orig_exp], annot), exp_of_value V_unit)) - | E_loop (Until, exp, body) -> wrap (E_block [body; E_aux (E_if (exp, exp_of_value V_unit, orig_exp), annot)]) + | E_loop (While, _, exp, body) -> wrap (E_if (exp, E_aux (E_block [body; orig_exp], annot), exp_of_value V_unit)) + | E_loop (Until, _, exp, body) -> wrap (E_block [body; E_aux (E_if (exp, exp_of_value V_unit, orig_exp), annot)]) | E_assert (exp, msg) when is_true exp -> wrap unit_exp | E_assert (exp, msg) when is_false exp && is_value msg -> diff --git a/src/jib/anf.ml b/src/jib/anf.ml index 0a410249..c7f86cd4 100644 --- a/src/jib/anf.ml +++ b/src/jib/anf.ml @@ -518,7 +518,7 @@ let rec anf (E_aux (e_aux, ((l, _) as exp_annot)) as exp) = raise (Reporting.err_unreachable l __POS__ ("Encountered complex l-expression " ^ string_of_lexp lexp ^ " when converting to ANF")) - | E_loop (loop_typ, cond, exp) -> + | E_loop (loop_typ, _, cond, exp) -> let acond = anf cond in let aexp = anf exp in mk_aexp (AE_loop (loop_typ, acond, aexp)) diff --git a/src/jib/jib_compile.ml b/src/jib/jib_compile.ml index 4a72ffff..f977193a 100644 --- a/src/jib/jib_compile.ml +++ b/src/jib/jib_compile.ml @@ -1264,6 +1264,7 @@ and compile_def' n total ctx = function (* Termination measures only needed for Coq, and other theorem prover output *) | DEF_measure _ -> [], ctx + | DEF_loop_measures _ -> [], ctx | DEF_internal_mutrec fundefs -> let defs = List.map (fun fdef -> DEF_fundef fdef) fundefs in diff --git a/src/monomorphise.ml b/src/monomorphise.ml index 4d7119d7..92c58e5d 100644 --- a/src/monomorphise.ml +++ b/src/monomorphise.ml @@ -980,7 +980,7 @@ let split_defs all_errors splits defs = | E_tuple es -> re (E_tuple (List.map map_exp es)) | E_if (e1,e2,e3) -> re (E_if (map_exp e1, map_exp e2, map_exp e3)) | E_for (id,e1,e2,e3,ord,e4) -> re (E_for (id,map_exp e1,map_exp e2,map_exp e3,ord,map_exp e4)) - | E_loop (loop,e1,e2) -> re (E_loop (loop,map_exp e1,map_exp e2)) + | E_loop (loop,m,e1,e2) -> re (E_loop (loop,m,map_exp e1,map_exp e2)) | E_vector es -> re (E_vector (List.map map_exp es)) | E_vector_access (e1,e2) -> re (E_vector_access (map_exp e1,map_exp e2)) | E_vector_subrange (e1,e2,e3) -> re (E_vector_subrange (map_exp e1,map_exp e2,map_exp e3)) @@ -1115,10 +1115,14 @@ let split_defs all_errors splits defs = | DEF_internal_mutrec _ -> [d] | DEF_fundef fd -> [DEF_fundef (map_fundef fd)] - | DEF_mapdef (MD_aux (_, (l, _))) -> Reporting.unreachable l __POS__ "mappings should be gone by now" + | DEF_mapdef (MD_aux (_, (l, _))) -> + Reporting.unreachable l __POS__ "mappings should be gone by now" | DEF_val lb -> [DEF_val (map_letbind lb)] | DEF_scattered sd -> List.map (fun x -> DEF_scattered x) (map_scattered_def sd) | DEF_measure (id,pat,exp) -> [DEF_measure (id,pat,map_exp exp)] + | DEF_loop_measures (id,_) -> + Reporting.unreachable (id_loc id) __POS__ + "Loop termination measures should have been rewritten before now" in Defs (List.concat (List.map map_def defs)) in @@ -2065,7 +2069,7 @@ let rec analyse_exp fn_id env assigns (E_aux (e,(l,annot)) as exp) = let d3,a3,r3 = analyse_exp fn_id env assigns e3 in let assigns = add_dep_to_assigned d1 (dep_bindings_merge a2 a3) [e2;e3] in (dmerge d1 (dmerge d2 d3), assigns, merge r1 (merge r2 r3)) - | E_loop (_,e1,e2) -> + | E_loop (_,_,e1,e2) -> (* We remove all of the variables assigned in the loop, so we don't need to add control dependencies *) let assigns = remove_assigns [e1;e2] " assigned in a loop" in diff --git a/src/ocaml_backend.ml b/src/ocaml_backend.ml index c68a258d..8361d5f5 100644 --- a/src/ocaml_backend.ml +++ b/src/ocaml_backend.ml @@ -256,7 +256,7 @@ let rec ocaml_exp ctx (E_aux (exp_aux, _) as exp) = separate space [string "let"; ocaml_atomic_lexp ctx lexp; equals; string "ref"; parens (ocaml_atomic_exp ctx exp1 ^^ space ^^ colon ^^ space ^^ ocaml_typ ctx (Rewrites.simple_typ (typ_of exp1))); string "in"] ^/^ ocaml_exp ctx exp2 - | E_loop (Until, cond, body) -> + | E_loop (Until, _, cond, body) -> let loop_body = (ocaml_atomic_exp ctx body ^^ semi) ^/^ @@ -267,7 +267,7 @@ let rec ocaml_exp ctx (E_aux (exp_aux, _) as exp) = (string "let rec loop () =" ^//^ loop_body) ^/^ string "in" ^/^ string "loop ()" - | E_loop (While, cond, body) -> + | E_loop (While, _, cond, body) -> let loop_body = separate space [string "if"; ocaml_atomic_exp ctx cond; string "then"; parens (ocaml_atomic_exp ctx body ^^ semi ^^ space ^^ string "loop ()"); diff --git a/src/parse_ast.ml b/src/parse_ast.ml index 818c9340..896a860d 100644 --- a/src/parse_ast.ml +++ b/src/parse_ast.ml @@ -238,7 +238,14 @@ and fpat = type loop = While | Until -type +type measure_aux = (* optional termination measure for a loop *) + | Measure_none + | Measure_some of exp + +and measure = + | Measure_aux of measure_aux * l + +and exp_aux = (* Expression *) E_block of (exp) list (* block (parsing conflict with structs?) *) | E_nondet of (exp) list (* block that can evaluate the contained expressions in any ordering *) @@ -251,7 +258,7 @@ exp_aux = (* Expression *) | E_app_infix of exp * id * exp (* infix function application *) | E_tuple of (exp) list (* tuple *) | E_if of exp * exp * exp (* conditional *) - | E_loop of loop * exp * exp + | E_loop of loop * measure * exp * exp | E_for of id * exp * exp * exp * atyp * exp (* loop *) | E_vector of (exp) list (* vector (indexed from 0) *) | E_vector_access of exp * exp (* vector access *) @@ -489,6 +496,10 @@ dec_spec = DEC_aux of dec_spec_aux * l +type loop_measure = + | Loop of loop * exp + + type scattered_def = SD_aux of scattered_def_aux * l @@ -509,6 +520,7 @@ def = (* Top-level definition *) | DEF_default of default_typing_spec (* default kind and type assumptions *) | DEF_scattered of scattered_def (* scattered definition *) | DEF_measure of id * pat * exp (* separate termination measure declaration *) + | DEF_loop_measures of id * loop_measure list (* separate termination measure declaration *) | DEF_reg_dec of dec_spec (* register declaration *) | DEF_pragma of string * string * l | DEF_internal_mutrec of fundef list diff --git a/src/parser.mly b/src/parser.mly index 39ca75ff..5e448a05 100644 --- a/src/parser.mly +++ b/src/parser.mly @@ -92,6 +92,7 @@ let mk_typ t n m = ATyp_aux (t, loc n m) let mk_pat p n m = P_aux (p, loc n m) let mk_pexp p n m = Pat_aux (p, loc n m) let mk_exp e n m = E_aux (e, loc n m) +let mk_measure meas n m = Measure_aux (meas, loc n m) let mk_lit l n m = L_aux (l, loc n m) let mk_lit_exp l n m = mk_exp (E_lit (mk_lit l n m)) n m let mk_typschm tq t n m = TypSchm_aux (TypSchm_ts (tq, t), loc n m) @@ -748,6 +749,13 @@ exp_eof: | exp Eof { $1 } +/* Internal syntax for loop measures, rejected in normal code by initial_check */ +internal_loop_measure: + | + { mk_measure Measure_none $startpos $endpos } + | TerminationMeasure Lcurly exp Rcurly + { mk_measure (Measure_some $3) $startpos $endpos } + exp: | exp0 { $1 } @@ -802,10 +810,10 @@ exp: else ATyp_aux(ATyp_dec,loc $startpos($6) $endpos($6)) in mk_exp (E_for ($3, $5, $7, step, ord, $9)) $startpos $endpos } - | Repeat exp Until exp - { mk_exp (E_loop (Until, $4, $2)) $startpos $endpos } - | While exp Do exp - { mk_exp (E_loop (While, $2, $4)) $startpos $endpos } + | Repeat internal_loop_measure exp Until exp + { mk_exp (E_loop (Until, $2, $5, $3)) $startpos $endpos } + | While internal_loop_measure exp Do exp + { mk_exp (E_loop (While, $2, $3, $5)) $startpos $endpos } /* Debugging only, will be rejected in initial_check if debugging isn't on */ | InternalPLet pat Eq exp In exp @@ -1401,6 +1409,17 @@ scattered_clause: | Function_ Clause funcl { mk_sd (SD_funcl $3) $startpos $endpos } +loop_measure: + | Until exp + { Loop (Until, $2) } + | While exp + { Loop (While, $2) } + +loop_measures: + | loop_measure + { [$1] } + | loop_measure Comma loop_measures + { $1::$3 } def: | fun_def @@ -1439,6 +1458,8 @@ def: { DEF_pragma (fst $1, snd $1, loc $startpos $endpos) } | TerminationMeasure id pat Eq exp { DEF_measure ($2, $3, $5) } + | TerminationMeasure id loop_measures + { DEF_loop_measures ($2,$3) } defs_list: | def diff --git a/src/pretty_print_coq.ml b/src/pretty_print_coq.ml index b4f32dce..dabb7b56 100644 --- a/src/pretty_print_coq.ml +++ b/src/pretty_print_coq.ml @@ -1504,43 +1504,58 @@ let doc_exp, doc_let = | _ -> raise (Reporting.err_unreachable l __POS__ "Unexpected number of arguments for loop combinator") end - | Id_aux (Id (("while#" | "until#") as combinator), _) -> - let combinator = String.sub combinator 0 (String.length combinator - 1) in + | Id_aux (Id (("while#" | "until#" | "while#t" | "until#t") as combinator), _) -> + let combinator = String.sub combinator 0 (String.index combinator '#') in begin - match args with - | [cond; varstuple; body] -> - let return (E_aux (e, a)) = E_aux (E_internal_return (E_aux (e, a)), a) in - let csuffix, cond, body = - match effectful (effect_of cond), effectful (effect_of body) with - | false, false -> "", cond, body - | false, true -> "M", return cond, body - | true, false -> "M", cond, return body - | true, true -> "M", cond, body - in - let used_vars_body = find_e_ids body in - let lambda = - (* Work around indentation issues in Lem when translating - tuple or literal unit patterns to Isabelle *) - match fst (uncast_exp varstuple) with - | E_aux (E_tuple _, _) - when not (IdSet.mem (mk_id "varstup") used_vars_body)-> - separate space [string "fun varstup"; bigarrow] ^^ break 1 ^^ - separate space [string "let"; squote ^^ expY varstuple; string ":= varstup in"] - | E_aux (E_lit (L_aux (L_unit, _)), _) - when not (IdSet.mem (mk_id "unit_var") used_vars_body) -> - separate space [string "fun unit_var"; bigarrow] - | _ -> - separate space [string "fun"; expY varstuple; bigarrow] - in - parens ( - (prefix 2 1) - ((separate space) [string (combinator ^ csuffix); expY varstuple]) - ((prefix 0 1) - (parens (prefix 2 1 (group lambda) (expN cond))) - (parens (prefix 2 1 (group lambda) (expN body)))) - ) - | _ -> raise (Reporting.err_unreachable l __POS__ - "Unexpected number of arguments for loop combinator") + let cond, varstuple, body, measure = + match args with + | [cond; varstuple; body] -> cond, varstuple, body, None + | [cond; varstuple; body; measure] -> cond, varstuple, body, Some measure + | _ -> raise (Reporting.err_unreachable l __POS__ + "Unexpected number of arguments for loop combinator") + in + let return (E_aux (e, (l,a))) = + let a' = mk_tannot (env_of_annot (l,a)) bool_typ no_effect in + E_aux (E_internal_return (E_aux (e, (l,a))), (l,a')) + in + let simple_bool (E_aux (_, (l,a)) as exp) = + let a' = mk_tannot (env_of_annot (l,a)) bool_typ no_effect in + E_aux (E_cast (bool_typ, exp), (l,a')) + in + let csuffix, cond, body = + match effectful (effect_of cond), effectful (effect_of body) with + | false, false -> "", cond, body + | false, true -> "M", return cond, body + | true, false -> "M", simple_bool cond, return body + | true, true -> "M", simple_bool cond, body + in + let msuffix, measure_pp = + match measure with + | None -> "", [] + | Some exp -> "T", [expY exp] + in + let used_vars_body = find_e_ids body in + let lambda = + (* Work around indentation issues in Lem when translating + tuple or literal unit patterns to Isabelle *) + match fst (uncast_exp varstuple) with + | E_aux (E_tuple _, _) + when not (IdSet.mem (mk_id "varstup") used_vars_body)-> + separate space [string "fun varstup"; bigarrow] ^^ break 1 ^^ + separate space [string "let"; squote ^^ expY varstuple; string ":= varstup in"] + | E_aux (E_lit (L_aux (L_unit, _)), _) + when not (IdSet.mem (mk_id "unit_var") used_vars_body) -> + separate space [string "fun unit_var"; bigarrow] + | _ -> + separate space [string "fun"; expY varstuple; bigarrow] + in + parens ( + (prefix 2 1) + ((separate space) (string (combinator ^ csuffix ^ msuffix)::measure_pp@[expY varstuple])) + ((prefix 0 1) + (parens (prefix 2 1 (group lambda) (expN cond))) + (parens (prefix 2 1 (group lambda) (expN body)))) + ) end | Id_aux (Id "early_return", _) -> begin @@ -2191,7 +2206,9 @@ let types_used_with_generic_eq defs = | DEF_mapdef (MD_aux (_,(l,_))) | DEF_scattered (SD_aux (_,(l,_))) | DEF_measure (Id_aux (_,l),_,_) - -> unreachable l __POS__ "Internal definition found in the Coq back-end" + | DEF_loop_measures (Id_aux (_,l),_) + -> unreachable l __POS__ + "Definition found in the Coq back-end that should have been rewritten away" | DEF_internal_mutrec fds -> List.fold_left IdSet.union IdSet.empty (List.map typs_req_fundef fds) | DEF_val lb -> @@ -2925,6 +2942,10 @@ let rec doc_def unimplemented generic_eq_types def = | DEF_measure (id,_,_) -> unreachable (id_loc id) __POS__ ("Termination measure for " ^ string_of_id id ^ " should have been rewritten before backend") + | DEF_loop_measures (id,_) -> + unreachable (id_loc id) __POS__ + ("Loop termination measures for " ^ string_of_id id ^ + " should have been rewritten before backend") let find_exc_typ defs = let is_exc_typ_def = function diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml index 6479a028..2daee940 100644 --- a/src/pretty_print_lem.ml +++ b/src/pretty_print_lem.ml @@ -715,11 +715,12 @@ let doc_exp_lem, doc_let_lem = | _ -> raise (Reporting.err_unreachable l __POS__ "Unexpected number of arguments for loop combinator") end - | Id_aux (Id (("while#" | "until#") as combinator), _) -> - let combinator = String.sub combinator 0 (String.length combinator - 1) in + | Id_aux (Id (("while#" | "until#" | "while#t" | "until#t") as combinator), _) -> + let combinator = String.sub combinator 0 (String.index combinator '#') in begin match args with - | [cond; varstuple; body] -> + | [cond; varstuple; body] + | [cond; varstuple; body; _] -> (* Ignore termination measures - not used in Lem *) let return (E_aux (e, a)) = E_aux (E_internal_return (E_aux (e, a)), a) in let csuffix, cond, body = match effectful (effect_of cond), effectful (effect_of body) with @@ -1485,6 +1486,7 @@ let rec doc_def_lem type_env def = | DEF_mapdef (MD_aux (_, (l, _))) -> unreachable l __POS__ "Lem doesn't support mappings" | DEF_pragma _ -> empty | DEF_measure _ -> empty (* we might use these in future *) + | DEF_loop_measures _ -> empty let find_exc_typ defs = let is_exc_typ_def = function diff --git a/src/pretty_print_sail.ml b/src/pretty_print_sail.ml index 7f3a2b63..e46f784e 100644 --- a/src/pretty_print_sail.ml +++ b/src/pretty_print_sail.ml @@ -387,10 +387,10 @@ let rec doc_exp (E_aux (e_aux, _) as exp) = | E_list exps -> string "[|" ^^ separate_map (comma ^^ space) doc_exp exps ^^ string "|]" | E_cons (exp1, exp2) -> doc_atomic_exp exp1 ^^ space ^^ string "::" ^^ space ^^ doc_exp exp2 | E_record fexps -> separate space [string "struct"; string "{"; doc_fexps fexps; string "}"] - | E_loop (While, cond, exp) -> - separate space [string "while"; doc_exp cond; string "do"; doc_exp exp] - | E_loop (Until, cond, exp) -> - separate space [string "repeat"; doc_exp exp; string "until"; doc_exp cond] + | E_loop (While, measure, cond, exp) -> + separate space ([string "while"] @ doc_measure measure @ [doc_exp cond; string "do"; doc_exp exp]) + | E_loop (Until, measure, cond, exp) -> + separate space ([string "repeat"] @ doc_measure measure @ [doc_exp exp; string "until"; doc_exp cond]) | E_record_update (exp, fexps) -> separate space [string "{"; doc_exp exp; string "with"; doc_fexps fexps; string "}"] | E_vector_append (exp1, exp2) -> separate space [doc_atomic_exp exp1; string "@"; doc_atomic_exp exp2] @@ -429,6 +429,10 @@ let rec doc_exp (E_aux (e_aux, _) as exp) = | E_app (id, [exp]) when Id.compare (mk_id "pow2") id == 0 -> separate space [string "2"; string "^"; doc_atomic_exp exp] | _ -> doc_atomic_exp exp +and doc_measure (Measure_aux (m_aux, _)) = + match m_aux with + | Measure_none -> [] + | Measure_some exp -> [string "termination_measure"; braces (doc_exp exp)] and doc_infix n (E_aux (e_aux, _) as exp) = match e_aux with | E_app_infix (l, op, r) when n < 10 -> @@ -643,6 +647,13 @@ let doc_prec = function | InfixL -> string "infixl" | InfixR -> string "infixr" +let doc_loop_measures l = + separate_map (comma ^^ break 1) + (function (Loop (l,e)) -> + string (match l with While -> "while" | Until -> "until") ^^ + space ^^ doc_exp e) + l + let rec doc_scattered (SD_aux (sd_aux, _)) = match sd_aux with | SD_function (_, _, _, id) -> @@ -679,6 +690,8 @@ let rec doc_def def = group (match def with | DEF_measure (id,pat,exp) -> string "termination_measure" ^^ space ^^ doc_id id ^/^ doc_pat pat ^^ space ^^ equals ^/^ doc_exp exp + | DEF_loop_measures (id,measures) -> + string "termination_measure" ^^ space ^^ doc_id id ^/^ doc_loop_measures measures | DEF_pragma (pragma, arg, l) -> string ("$" ^ pragma ^ " " ^ arg) | DEF_fixity (prec, n, id) -> diff --git a/src/rewriter.ml b/src/rewriter.ml index edf0d4a5..2573a135 100644 --- a/src/rewriter.ml +++ b/src/rewriter.ml @@ -151,7 +151,7 @@ let fix_eff_exp (E_aux (e,((l,_) as annot))) = match destruct_tannot (snd annot) union_effects (fun_app_effects f env) (union_eff_exps [e1;e2]) | E_if (e1,e2,e3) -> union_eff_exps [e1;e2;e3] | E_for (_,e1,e2,e3,_,e4) -> union_eff_exps [e1;e2;e3;e4] - | E_loop (_,e1,e2) -> union_eff_exps [e1;e2] + | E_loop (_,_,e1,e2) -> union_eff_exps [e1;e2] | E_vector es -> union_eff_exps es | E_vector_access (e1,e2) -> union_eff_exps [e1;e2] | E_vector_subrange (e1,e2,e3) -> union_eff_exps [e1;e2;e3] @@ -280,8 +280,12 @@ let rewrite_exp rewriters (E_aux (exp,(l,annot)) as orig_exp) = | E_if (c,t,e) -> rewrap (E_if (rewrite c,rewrite t, rewrite e)) | E_for (id, e1, e2, e3, o, body) -> rewrap (E_for (id, rewrite e1, rewrite e2, rewrite e3, o, rewrite body)) - | E_loop (loop, e1, e2) -> - rewrap (E_loop (loop, rewrite e1, rewrite e2)) + | E_loop (loop, m, e1, e2) -> + let m = match m with + | Measure_aux (Measure_none,_) -> m + | Measure_aux (Measure_some exp,l) -> Measure_aux (Measure_some (rewrite exp),l) + in + rewrap (E_loop (loop, m, rewrite e1, rewrite e2)) | E_vector exps -> rewrap (E_vector (List.map rewrite exps)) | E_vector_access (vec,index) -> rewrap (E_vector_access (rewrite vec,rewrite index)) | E_vector_subrange (vec,i1,i2) -> @@ -362,8 +366,9 @@ let rewrite_def rewriters d = match d with | DEF_internal_mutrec fdefs -> DEF_internal_mutrec (List.map (rewriters.rewrite_fun rewriters) fdefs) | DEF_val letbind -> DEF_val (rewriters.rewrite_let rewriters letbind) | DEF_pragma (pragma, arg, l) -> DEF_pragma (pragma, arg, l) - | DEF_scattered _ -> raise (Reporting.err_unreachable Parse_ast.Unknown __POS__ "DEF_scattered survived to rewritter") + | DEF_scattered _ -> raise (Reporting.err_unreachable Parse_ast.Unknown __POS__ "DEF_scattered survived to rewriter") | DEF_measure (id,pat,exp) -> DEF_measure (id,rewriters.rewrite_pat rewriters pat, rewriters.rewrite_exp rewriters exp) + | DEF_loop_measures (id,_) -> raise (Reporting.err_unreachable (id_loc id) __POS__ "DEF_loop_measures survived to rewriter") let rewrite_defs_base rewriters (Defs defs) = let rec rewrite ds = match ds with @@ -539,7 +544,7 @@ type ('a,'exp,'exp_aux,'lexp,'lexp_aux,'fexp,'fexp_aux, ; e_tuple : 'exp list -> 'exp_aux ; e_if : 'exp * 'exp * 'exp -> 'exp_aux ; e_for : id * 'exp * 'exp * 'exp * Ast.order * 'exp -> 'exp_aux - ; e_loop : loop * 'exp * 'exp -> 'exp_aux + ; e_loop : loop * ('exp option * Parse_ast.l) * 'exp * 'exp -> 'exp_aux ; e_vector : 'exp list -> 'exp_aux ; e_vector_access : 'exp * 'exp -> 'exp_aux ; e_vector_subrange : 'exp * 'exp * 'exp -> 'exp_aux @@ -602,8 +607,12 @@ let rec fold_exp_aux alg = function | E_if (e1,e2,e3) -> alg.e_if (fold_exp alg e1, fold_exp alg e2, fold_exp alg e3) | E_for (id,e1,e2,e3,order,e4) -> alg.e_for (id,fold_exp alg e1, fold_exp alg e2, fold_exp alg e3, order, fold_exp alg e4) - | E_loop (loop_type, e1, e2) -> - alg.e_loop (loop_type, fold_exp alg e1, fold_exp alg e2) + | E_loop (loop_type, m, e1, e2) -> + let m = match m with + | Measure_aux (Measure_none,l) -> None,l + | Measure_aux (Measure_some exp,l) -> Some (fold_exp alg exp),l + in + alg.e_loop (loop_type, m, fold_exp alg e1, fold_exp alg e2) | E_vector es -> alg.e_vector (List.map (fold_exp alg) es) | E_vector_access (e1,e2) -> alg.e_vector_access (fold_exp alg e1, fold_exp alg e2) | E_vector_subrange (e1,e2,e3) -> @@ -681,7 +690,9 @@ let id_exp_alg = ; e_tuple = (fun es -> E_tuple es) ; e_if = (fun (e1,e2,e3) -> E_if (e1,e2,e3)) ; e_for = (fun (id,e1,e2,e3,order,e4) -> E_for (id,e1,e2,e3,order,e4)) - ; e_loop = (fun (lt, e1, e2) -> E_loop (lt, e1, e2)) + ; e_loop = (fun (lt, (m,l), e1, e2) -> + let m = match m with None -> Measure_none | Some e -> Measure_some e in + E_loop (lt, Measure_aux (m,l), e1, e2)) ; e_vector = (fun es -> E_vector es) ; e_vector_access = (fun (e1,e2) -> E_vector_access (e1,e2)) ; e_vector_subrange = (fun (e1,e2,e3) -> E_vector_subrange (e1,e2,e3)) @@ -776,8 +787,12 @@ let compute_exp_alg bot join = ; e_if = (fun ((v1,e1),(v2,e2),(v3,e3)) -> (join_list [v1;v2;v3], E_if (e1,e2,e3))) ; e_for = (fun (id,(v1,e1),(v2,e2),(v3,e3),order,(v4,e4)) -> (join_list [v1;v2;v3;v4], E_for (id,e1,e2,e3,order,e4))) - ; e_loop = (fun (lt, (v1, e1), (v2, e2)) -> - (join_list [v1;v2], E_loop (lt, e1, e2))) + ; e_loop = (fun (lt, (m,l), (v1, e1), (v2, e2)) -> + let vs,m = match m with + | None -> [], Measure_none + | Some (v,e) -> [v], Measure_some e + in + (join_list (vs@[v1;v2]), E_loop (lt, Measure_aux (m,l), e1, e2))) ; e_vector = split_join (fun es -> E_vector es) ; e_vector_access = (fun ((v1,e1),(v2,e2)) -> (join v1 v2, E_vector_access (e1,e2))) ; e_vector_subrange = (fun ((v1,e1),(v2,e2),(v3,e3)) -> (join_list [v1;v2;v3], E_vector_subrange (e1,e2,e3))) @@ -878,7 +893,8 @@ let pure_exp_alg bot join = ; e_tuple = join_list ; e_if = (fun (v1,v2,v3) -> join_list [v1;v2;v3]) ; e_for = (fun (id,v1,v2,v3,order,v4) -> join_list [v1;v2;v3;v4]) - ; e_loop = (fun (lt, v1, v2) -> join v1 v2) + ; e_loop = (fun (lt, (m,_), v1, v2) -> + let v = join v1 v2 in match m with None -> v | Some v' -> join v v') ; e_vector = join_list ; e_vector_access = (fun (v1,v2) -> join v1 v2) ; e_vector_subrange = (fun (v1,v2,v3) -> join_list [v1;v2;v3]) @@ -927,3 +943,200 @@ let pure_exp_alg bot join = ; lB_aux = (fun (vl,annot) -> vl) ; pat_alg = pure_pat_alg bot join } + +let default_fold_fexp f x (FE_aux (FE_Fexp (id,e),annot)) = + let x,e = f x e in + x, FE_aux (FE_Fexp (id,e),annot) + +let default_fold_pexp f x (Pat_aux (pe,ann)) = + let x,pe = match pe with + | Pat_exp (p,e) -> + let x,e = f x e in + x,Pat_exp (p,e) + | Pat_when (p,e1,e2) -> + let x,e1 = f x e1 in + let x,e2 = f x e2 in + x,Pat_when (p,e1,e2) + in x, Pat_aux (pe,ann) + +let default_fold_letbind f x (LB_aux (LB_val (p,e),ann)) = + let x,e = f x e in + x, LB_aux (LB_val (p,e),ann) + +let rec default_fold_lexp f x (LEXP_aux (le,ann) as lexp) = + let re le = LEXP_aux (le,ann) in + match le with + | LEXP_id _ + | LEXP_cast _ + -> x, lexp + | LEXP_deref e -> + let x, e = f x e in + x, re (LEXP_deref e) + | LEXP_memory (id,es) -> + let x,es = List.fold_left (fun (x,es) e -> + let x,e' = f x e in x,e'::es) (x,[]) es in + x, re (LEXP_memory (id, List.rev es)) + | LEXP_tup les -> + let x,les = List.fold_left (fun (x,les) le -> + let x,le' = default_fold_lexp f x le in x,le'::les) (x,[]) les in + x, re (LEXP_tup (List.rev les)) + | LEXP_vector_concat les -> + let x,les = List.fold_left (fun (x,les) le -> + let x,le' = default_fold_lexp f x le in x,le'::les) (x,[]) les in + x, re (LEXP_vector_concat (List.rev les)) + | LEXP_vector (le,e) -> + let x, le = default_fold_lexp f x le in + let x, e = f x e in + x, re (LEXP_vector (le,e)) + | LEXP_vector_range (le,e1,e2) -> + let x, le = default_fold_lexp f x le in + let x, e1 = f x e1 in + let x, e2 = f x e2 in + x, re (LEXP_vector_range (le,e1,e2)) + | LEXP_field (le,id) -> + let x, le = default_fold_lexp f x le in + x, re (LEXP_field (le,id)) + +let default_fold_exp f x (E_aux (e,ann) as exp) = + let re e = E_aux (e,ann) in + match e with + | E_block es -> + let x,es = List.fold_left (fun (x,es) e -> + let x,e' = f x e in x,e'::es) (x,[]) es in + x, re (E_block (List.rev es)) + | E_nondet es -> + let x,es = List.fold_left (fun (x,es) e -> + let x,e' = f x e in x,e'::es) (x,[]) es in + x, re (E_nondet (List.rev es)) + | E_id _ + | E_ref _ + | E_lit _ -> x, exp + | E_cast (typ,e) -> + let x,e = f x e in + x, re (E_cast (typ,e)) + | E_app (id,es) -> + let x,es = List.fold_left (fun (x,es) e -> + let x,e' = f x e in x,e'::es) (x,[]) es in + x, re (E_app (id, List.rev es)) + | E_app_infix (e1,id,e2) -> + let x,e1 = f x e1 in + let x,e2 = f x e2 in + x, re (E_app_infix (e1,id,e2)) + | E_tuple es -> + let x,es = List.fold_left (fun (x,es) e -> + let x,e' = f x e in x,e'::es) (x,[]) es in + x, re (E_tuple (List.rev es)) + | E_if (e1,e2,e3) -> + let x,e1 = f x e1 in + let x,e2 = f x e2 in + let x,e3 = f x e3 in + x, re (E_if (e1,e2,e3)) + | E_for (id,e1,e2,e3,order,e4) -> + let x,e1 = f x e1 in + let x,e2 = f x e2 in + let x,e3 = f x e3 in + let x,e4 = f x e4 in + x, re (E_for (id,e1,e2,e3,order,e4)) + | E_loop (loop_type, m, e1, e2) -> + let x,m = match m with + | Measure_aux (Measure_none,_) -> x,m + | Measure_aux (Measure_some exp,l) -> + let x, exp = f x exp in + x, Measure_aux (Measure_some exp,l) + in + let x,e1 = f x e1 in + let x,e2 = f x e2 in + x, re (E_loop (loop_type, m, e1, e2)) + | E_vector es -> + let x,es = List.fold_left (fun (x,es) e -> + let x,e' = f x e in x,e'::es) (x,[]) es in + x, re (E_vector (List.rev es)) + | E_vector_access (e1,e2) -> + let x,e1 = f x e1 in + let x,e2 = f x e2 in + x, re (E_vector_access (e1,e2)) + | E_vector_subrange (e1,e2,e3) -> + let x,e1 = f x e1 in + let x,e2 = f x e2 in + let x,e3 = f x e3 in + x, re (E_vector_subrange (e1,e2,e3)) + | E_vector_update (e1,e2,e3) -> + let x,e1 = f x e1 in + let x,e2 = f x e2 in + let x,e3 = f x e3 in + x, re (E_vector_update (e1,e2,e3)) + | E_vector_update_subrange (e1,e2,e3,e4) -> + let x,e1 = f x e1 in + let x,e2 = f x e2 in + let x,e3 = f x e3 in + let x,e4 = f x e4 in + x, re (E_vector_update_subrange (e1,e2,e3,e4)) + | E_vector_append (e1,e2) -> + let x,e1 = f x e1 in + let x,e2 = f x e2 in + x, re (E_vector_append (e1,e2)) + | E_list es -> + let x,es = List.fold_left (fun (x,es) e -> + let x,e' = f x e in x,e'::es) (x,[]) es in + x, re (E_list (List.rev es)) + | E_cons (e1,e2) -> + let x,e1 = f x e1 in + let x,e2 = f x e2 in + x, re (E_cons (e1,e2)) + | E_record fexps -> + let x,fexps = List.fold_left (fun (x,fes) fe -> + let x,fe' = default_fold_fexp f x fe in x,fe'::fes) (x,[]) fexps in + x, re (E_record (List.rev fexps)) + | E_record_update (e,fexps) -> + let x,e = f x e in + let x,fexps = List.fold_left (fun (x,fes) fe -> + let x,fe' = default_fold_fexp f x fe in x,fe'::fes) (x,[]) fexps in + x, re (E_record_update (e, List.rev fexps)) + | E_field (e,id) -> + let x,e = f x e in x, re (E_field (e,id)) + | E_case (e,pexps) -> + let x,e = f x e in + let x,pexps = List.fold_left (fun (x,pes) pe -> + let x,pe' = default_fold_pexp f x pe in x,pe'::pes) (x,[]) pexps in + x, re (E_case (e, List.rev pexps)) + | E_try (e,pexps) -> + let x,e = f x e in + let x,pexps = List.fold_left (fun (x,pes) pe -> + let x,pe' = default_fold_pexp f x pe in x,pe'::pes) (x,[]) pexps in + x, re (E_try (e, List.rev pexps)) + | E_let (letbind,e) -> + let x,letbind = default_fold_letbind f x letbind in + let x,e = f x e in + x, re (E_let (letbind,e)) + | E_assign (lexp,e) -> + let x,lexp = default_fold_lexp f x lexp in + let x,e = f x e in + x, re (E_assign (lexp,e)) + | E_sizeof _ + | E_constraint _ + -> x,exp + | E_exit e -> + let x,e = f x e in x, re (E_exit e) + | E_throw e -> + let x,e = f x e in x, re (E_throw e) + | E_return e -> + let x,e = f x e in x, re (E_return e) + | E_assert(e1,e2) -> + let x,e1 = f x e1 in + let x,e2 = f x e2 in + x, re (E_assert (e1,e2)) + | E_var (lexp,e1,e2) -> + let x,lexp = default_fold_lexp f x lexp in + let x,e1 = f x e1 in + let x,e2 = f x e2 in + x, re (E_var (lexp,e1,e2)) + | E_internal_plet (pat,e1,e2) -> + let x,e1 = f x e1 in + let x,e2 = f x e2 in + x, re (E_internal_plet (pat,e1,e2)) + | E_internal_return e -> + let x,e = f x e in x, re (E_internal_return e) + | E_internal_value _ -> x,exp + +let rec foldin_exp f x e = f (default_fold_exp (foldin_exp f)) x e +let rec foldin_pexp f x e = default_fold_pexp (foldin_exp f) x e diff --git a/src/rewriter.mli b/src/rewriter.mli index ab29d1d9..878e0d15 100644 --- a/src/rewriter.mli +++ b/src/rewriter.mli @@ -128,7 +128,7 @@ type ('a,'exp,'exp_aux,'lexp,'lexp_aux,'fexp,'fexp_aux, ; e_tuple : 'exp list -> 'exp_aux ; e_if : 'exp * 'exp * 'exp -> 'exp_aux ; e_for : id * 'exp * 'exp * 'exp * Ast.order * 'exp -> 'exp_aux - ; e_loop : loop * 'exp * 'exp -> 'exp_aux + ; e_loop : loop * ('exp option * Parse_ast.l) * 'exp * 'exp -> 'exp_aux ; e_vector : 'exp list -> 'exp_aux ; e_vector_access : 'exp * 'exp -> 'exp_aux ; e_vector_subrange : 'exp * 'exp * 'exp -> 'exp_aux @@ -254,3 +254,7 @@ val fix_eff_pexp : tannot pexp -> tannot pexp val fix_eff_fexp : tannot fexp -> tannot fexp val fix_eff_opt_default : tannot opt_default -> tannot opt_default + +(* In-order fold over expressions *) +val foldin_exp : (('a -> 'b exp -> 'a * 'b exp) -> 'a -> 'b exp -> 'a * 'b exp) -> 'a -> 'b exp -> 'a * 'b exp +val foldin_pexp : (('a -> 'b exp -> 'a * 'b exp) -> 'a -> 'b exp -> 'a * 'b exp) -> 'a -> 'b pexp -> 'a * 'b pexp diff --git a/src/rewrites.ml b/src/rewrites.ml index e148cee4..39920e33 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -2369,10 +2369,15 @@ let rewrite_defs_letbind_effects env = n_exp_name by (fun by -> let body = n_exp_term (effectful body) body in k (rewrap (E_for (id,start,stop,by,dir,body)))))) - | E_loop (loop, cond, body) -> + | E_loop (loop, measure, cond, body) -> + let measure = match measure with + | Measure_aux (Measure_none,_) -> measure + | Measure_aux (Measure_some exp,l) -> + Measure_aux (Measure_some (n_exp_term false exp),l) + in let cond = n_exp_term (effectful cond) cond in let body = n_exp_term (effectful body) body in - k (rewrap (E_loop (loop,cond,body))) + k (rewrap (E_loop (loop,measure,cond,body))) | E_vector exps -> n_exp_nameL exps (fun exps -> k (rewrap (E_vector exps))) @@ -3448,7 +3453,7 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = el env (typ_of exp4)))) el env (typ_of exp4)) in Added_vars (v, tuple_pat (if overwrite then varpats else pat :: varpats)) - | E_loop(loop,cond,body) -> + | E_loop(loop,Measure_aux (measure,_),cond,body) -> (* Find variables that might be updated in the loop body and are used either after or within the loop. *) let vars, varpats = @@ -3458,11 +3463,14 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = in let body = rewrite_var_updates (add_vars overwrite body vars) in let (E_aux (_,(_,bannot))) = body in - let fname = match loop with - | While -> "while#" - | Until -> "until#" in + let fname, measure = match loop, measure with + | While, Measure_none -> "while#", [] + | Until, Measure_none -> "until#", [] + | While, Measure_some exp -> "while#t", [exp] + | Until, Measure_some exp -> "until#t", [exp] + in let funcl = Id_aux (Id fname,gen_loc el) in - let v = E_aux (E_app (funcl,[cond;tuple_exp vars;body]), (gen_loc el, bannot)) in + let v = E_aux (E_app (funcl,[cond;tuple_exp vars;body]@measure), (gen_loc el, bannot)) in Added_vars (v, tuple_pat (if overwrite then varpats else pat :: varpats)) | E_if (c,e1,e2) -> let vars, varpats = @@ -3763,8 +3771,10 @@ let rewrite_defs_remove_e_assign env (Defs defs) = let (Defs loop_specs) = fst (Type_error.check Env.empty (Defs (List.map gen_vs [("foreach#", "forall ('vars : Type). (int, int, int, bool, 'vars, 'vars) -> 'vars"); ("while#", "forall ('vars : Type). (bool, 'vars, 'vars) -> 'vars"); - ("until#", "forall ('vars : Type). (bool, 'vars, 'vars) -> 'vars")]))) in - let rewrite_exp _ e = + ("until#", "forall ('vars : Type). (bool, 'vars, 'vars) -> 'vars"); + ("while#t", "forall ('vars : Type). (bool, 'vars, 'vars, int) -> 'vars"); + ("until#t", "forall ('vars : Type). (bool, 'vars, 'vars, int) -> 'vars")]))) in + let rewrite_exp _ e = replace_memwrite_e_assign (remove_reference_types (rewrite_var_updates e)) in rewrite_defs_base { rewrite_exp = rewrite_exp @@ -4438,6 +4448,7 @@ let minimise_recursive_functions env (Defs defs) = | d -> d in Defs (List.map rewrite_def defs) +(* Move recursive function termination measures into the function definitions. *) let move_termination_measures env (Defs defs) = let scan_for id defs = let rec aux = function @@ -4632,6 +4643,59 @@ let remove_mapping_valspecs env (Defs defs) = Defs (List.filter allowed_def defs) +(* Move loop termination measures into loop AST nodes. This is used before + type checking so that we avoid the complexity of type checking separate + measures. *) +let rec move_loop_measures (Defs defs) = + let loop_measures = + List.fold_left + (fun m d -> + match d with + | DEF_loop_measures (id, measures) -> + (* Allow multiple measure definitions, concatenating them *) + Bindings.add id + (match Bindings.find_opt id m with + | None -> measures + | Some m -> m @ measures) + m + | _ -> m) Bindings.empty defs + in + let do_exp exp_rec measures (E_aux (e,ann) as exp) = + match e, measures with + | E_loop (loop, _, e1, e2), (Loop (loop',exp))::t when loop = loop' -> + let t,e1 = exp_rec t e1 in + let t,e2 = exp_rec t e2 in + t,E_aux (E_loop (loop, Measure_aux (Measure_some exp, exp_loc exp), e1, e2),ann) + | _ -> exp_rec measures exp + in + let do_funcl (m,acc) (FCL_aux (FCL_Funcl (id, pexp),ann) as fcl) = + match Bindings.find_opt id m with + | Some measures -> + let measures,pexp = foldin_pexp do_exp measures pexp in + Bindings.add id measures m, (FCL_aux (FCL_Funcl (id, pexp),ann))::acc + | None -> m, fcl::acc + in + let unused,rev_defs = + List.fold_left + (fun (m,acc) d -> + match d with + | DEF_loop_measures _ -> m, acc + | DEF_fundef (FD_aux (FD_function (r,t,e,fcls),ann)) -> + let m,rfcls = List.fold_left do_funcl (m,[]) fcls in + m, (DEF_fundef (FD_aux (FD_function (r,t,e,List.rev rfcls),ann)))::acc + | _ -> m, d::acc) + (loop_measures,[]) defs + in let () = Bindings.iter + (fun id -> function + | [] -> () + | _::_ -> + Reporting.print_err (id_loc id) "Warning" + ("unused loop measure for function " ^ string_of_id id)) + unused + in Defs (List.rev rev_defs) + + + let opt_mono_rewrites = ref false let opt_mono_complex_nexps = ref true diff --git a/src/rewrites.mli b/src/rewrites.mli index 330f10b4..e30a4206 100644 --- a/src/rewrites.mli +++ b/src/rewrites.mli @@ -63,6 +63,9 @@ val opt_dmono_continue : bool ref (* Generate a fresh id with the given prefix *) val fresh_id : string -> l -> id +(* Move loop termination measures into loop AST nodes *) +val move_loop_measures : 'a defs -> 'a defs + (* Re-write undefined to functions created by -undefined_gen flag *) val rewrite_undefined : bool -> Env.t -> tannot defs -> tannot defs diff --git a/src/sail.ml b/src/sail.ml index 9c3a3d5c..b8be79f6 100644 --- a/src/sail.ml +++ b/src/sail.ml @@ -333,6 +333,9 @@ let load_files ?check:(check=false) type_envs files = -> Parse_ast.Defs (ast_nodes@later_nodes)) parsed (Parse_ast.Defs []) in let ast = Process_file.preprocess_ast options ast in let ast = Initial_check.process_ast ~generate:(not check) ast in + (* The separate loop measures declarations would be awkward to type check, so + move them into the definitions beforehand. *) + let ast = Rewrites.move_loop_measures ast in Profile.finish "parsing" t; let t = Profile.start () in diff --git a/src/spec_analysis.ml b/src/spec_analysis.ml index 80bff0dd..8afc985d 100644 --- a/src/spec_analysis.ml +++ b/src/spec_analysis.ml @@ -205,7 +205,9 @@ let rec fv_of_exp consider_var bound used set (E_aux (e,(_,tannot))) : (Nameset. | E_for(id,from,to_,by,_,body) -> let _,used,set = list_fv bound used set [from;to_;by] in fv_of_exp consider_var (Nameset.add (string_of_id id) bound) used set body - | E_loop(_, cond, body) -> list_fv bound used set [cond; body] + | E_loop(_, measure, cond, body) -> + let m = match measure with Measure_aux (Measure_some exp,_) -> [exp] | _ -> [] in + list_fv bound used set (m @ [cond; body]) | E_vector_access(v,i) -> list_fv bound used set [v;i] | E_vector_subrange(v,i1,i2) -> list_fv bound used set [v;i1;i2] | E_vector_update(v,i,e) -> list_fv bound used set [v;i;e] @@ -509,6 +511,8 @@ let fv_of_def consider_var consider_scatter_as_one all_defs = function ((fun (_,u,_) -> Nameset.singleton ("measure:"^i),u) (fv_of_pes consider_var mt used mt [Pat_aux(Pat_exp (pat,exp),(Unknown,Type_check.empty_tannot))])) + | DEF_loop_measures(id,_) -> + Reporting.unreachable (id_loc id) __POS__ "Loop termination measures should be rewritten before now" let group_defs consider_scatter_as_one (Ast.Defs defs) = @@ -823,7 +827,7 @@ let nexp_subst_fns substs = | E_tuple es -> re (E_tuple (List.map s_exp es)) | E_if (e1,e2,e3) -> re (E_if (s_exp e1, s_exp e2, s_exp e3)) | E_for (id,e1,e2,e3,ord,e4) -> re (E_for (id,s_exp e1,s_exp e2,s_exp e3,ord,s_exp e4)) - | E_loop (loop,e1,e2) -> re (E_loop (loop,s_exp e1,s_exp e2)) + | E_loop (loop,m,e1,e2) -> re (E_loop (loop,s_measure m,s_exp e1,s_exp e2)) | E_vector es -> re (E_vector (List.map s_exp es)) | E_vector_access (e1,e2) -> re (E_vector_access (s_exp e1,s_exp e2)) | E_vector_subrange (e1,e2,e3) -> re (E_vector_subrange (s_exp e1,s_exp e2,s_exp e3)) @@ -846,6 +850,12 @@ let nexp_subst_fns substs = | E_internal_return e -> re (E_internal_return (s_exp e)) | E_throw e -> re (E_throw (s_exp e)) | E_try (e,cases) -> re (E_try (s_exp e, List.map s_pexp cases)) + and s_measure (Measure_aux (m,l)) = + let m = match m with + | Measure_none -> m + | Measure_some exp -> Measure_some (s_exp exp) + in + Measure_aux (m,l) and s_fexp (FE_aux (FE_Fexp (id,e), (l,annot))) = FE_aux (FE_Fexp (id,s_exp e),(l,s_tannot annot)) and s_pexp = function diff --git a/src/type_check.ml b/src/type_check.ml index d5d42316..53bf02fa 100644 --- a/src/type_check.ml +++ b/src/type_check.ml @@ -3696,10 +3696,15 @@ and infer_exp env (E_aux (exp_aux, (l, ())) as exp) = in try_overload ([], Env.get_overloads f env) | E_app (f, xs) -> infer_funapp l env f xs None - | E_loop (loop_type, cond, body) -> + | E_loop (loop_type, measure, cond, body) -> let checked_cond = crule check_exp env cond bool_typ in + let checked_measure = match measure with + | Measure_aux (Measure_none,l) -> Measure_aux (Measure_none,l) + | Measure_aux (Measure_some exp,l) -> + Measure_aux (Measure_some (crule check_exp env exp int_typ),l) + in let checked_body = crule check_exp (add_opt_constraint (assert_constraint env true checked_cond) env) body unit_typ in - annot_exp (E_loop (loop_type, checked_cond, checked_body)) unit_typ + annot_exp (E_loop (loop_type, checked_measure, checked_cond, checked_body)) unit_typ | E_for (v, f, t, step, ord, body) -> begin let f, t, is_dec = match ord with @@ -4372,10 +4377,18 @@ and propagate_exp_effect_aux = function let p_body = propagate_exp_effect body in E_for (v, p_f, p_t, p_step, ord, p_body), collect_effects [p_f; p_t; p_step; p_body] - | E_loop (loop_type, cond, body) -> + | E_loop (loop_type, measure, cond, body) -> let p_cond = propagate_exp_effect cond in + let () = match measure with + | Measure_aux (Measure_some exp,l) -> + let eff = effect_of (propagate_exp_effect exp) in + if (BESet.is_empty (effect_set eff) || !opt_no_effects) + then () + else typ_error (env_of exp) l ("Loop termination measure with effects " ^ string_of_effect eff) + | _ -> () + in let p_body = propagate_exp_effect body in - E_loop (loop_type, p_cond, p_body), + E_loop (loop_type, measure, p_cond, p_body), union_effects (effect_of p_cond) (effect_of p_body) | E_let (letbind, exp) -> let p_lb, eff = propagate_letbind_effect letbind in @@ -5016,6 +5029,9 @@ and check_def : 'a. Env.t -> 'a def -> (tannot def) list * Env.t = | DEF_reg_dec (DEC_aux (DEC_typ_alias (typ, id, aspec), (l, tannot))) -> cd_err () | DEF_scattered sdef -> check_scattered env sdef | DEF_measure (id, pat, exp) -> [check_termination_measure_decl env (id, pat, exp)], env + | DEF_loop_measures (id, _) -> + Reporting.unreachable (id_loc id) __POS__ + "Loop termination measures should have been rewritten before type checking" and check_defs : 'a. int -> int -> Env.t -> 'a def list -> tannot defs * Env.t = fun n total env defs -> -- cgit v1.2.3 From 4529e0acc377bed4d1bab4230f4023e4bee3ae85 Mon Sep 17 00:00:00 2001 From: Alasdair Armstrong Date: Mon, 15 Apr 2019 14:35:02 +0100 Subject: Fix: Allow zero-length vector literals --- src/jib/jib_compile.ml | 37 +++++++++++++++++++++++-------------- src/parser.mly | 2 ++ 2 files changed, 25 insertions(+), 14 deletions(-) (limited to 'src') diff --git a/src/jib/jib_compile.ml b/src/jib/jib_compile.ml index f977193a..a59f6c80 100644 --- a/src/jib/jib_compile.ml +++ b/src/jib/jib_compile.ml @@ -254,23 +254,32 @@ let rec compile_aval l ctx = function (F_id gs, ctyp), [iclear ctyp gs] - | AV_vector ([], _) -> - raise (Reporting.err_general l "Encountered empty vector literal") + | AV_vector ([], typ) -> + let vector_ctyp = ctyp_of_typ ctx typ in + begin match ctyp_of_typ ctx typ with + | CT_fbits (0, _) -> + [], (F_lit (V_bits []), vector_ctyp), [] + | _ -> + let gs = ngensym () in + [idecl vector_ctyp gs; + iextern (CL_id (gs, vector_ctyp)) (mk_id "internal_vector_init") [(F_lit (V_int Big_int.zero), CT_fint 64)]], + (F_id gs, vector_ctyp), + [iclear vector_ctyp gs] + end (* Convert a small bitvector to a uint64_t literal. *) | AV_vector (avals, typ) when is_bitvector avals && List.length avals <= 64 -> - begin - let bitstring = F_lit (V_bits (List.map value_of_aval_bit avals)) in - let len = List.length avals in - match destruct_vector ctx.tc_env typ with - | Some (_, Ord_aux (Ord_inc, _), _) -> - [], (bitstring, CT_fbits (len, false)), [] - | Some (_, Ord_aux (Ord_dec, _), _) -> - [], (bitstring, CT_fbits (len, true)), [] - | Some _ -> - raise (Reporting.err_general l "Encountered order polymorphic bitvector literal") - | None -> - raise (Reporting.err_general l "Encountered vector literal without vector type") + let bitstring = F_lit (V_bits (List.map value_of_aval_bit avals)) in + let len = List.length avals in + begin match destruct_vector ctx.tc_env typ with + | Some (_, Ord_aux (Ord_inc, _), _) -> + [], (bitstring, CT_fbits (len, false)), [] + | Some (_, Ord_aux (Ord_dec, _), _) -> + [], (bitstring, CT_fbits (len, true)), [] + | Some _ -> + raise (Reporting.err_general l "Encountered order polymorphic bitvector literal") + | None -> + raise (Reporting.err_general l "Encountered vector literal without vector type") end (* Convert a bitvector literal that is larger than 64-bits to a diff --git a/src/parser.mly b/src/parser.mly index 5e448a05..8c4475c4 100644 --- a/src/parser.mly +++ b/src/parser.mly @@ -1064,6 +1064,8 @@ atomic_exp: { mk_exp (E_record $3) $startpos $endpos } | Lcurly exp With fexp_exp_list Rcurly { mk_exp (E_record_update ($2, $4)) $startpos $endpos } + | Lsquare Rsquare + { mk_exp (E_vector []) $startpos $endpos } | Lsquare exp_list Rsquare { mk_exp (E_vector $2) $startpos $endpos } | Lsquare exp With atomic_exp Eq exp Rsquare -- cgit v1.2.3