diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/ast_util.ml | 33 | ||||
| -rw-r--r-- | src/constant_propagation.ml | 10 | ||||
| -rw-r--r-- | src/initial_check.ml | 20 | ||||
| -rw-r--r-- | src/interpreter.ml | 4 | ||||
| -rw-r--r-- | src/jib/anf.ml | 2 | ||||
| -rw-r--r-- | src/jib/jib_compile.ml | 1 | ||||
| -rw-r--r-- | src/monomorphise.ml | 10 | ||||
| -rw-r--r-- | src/ocaml_backend.ml | 4 | ||||
| -rw-r--r-- | src/parse_ast.ml | 16 | ||||
| -rw-r--r-- | src/parser.mly | 29 | ||||
| -rw-r--r-- | src/pretty_print_coq.ml | 95 | ||||
| -rw-r--r-- | src/pretty_print_lem.ml | 8 | ||||
| -rw-r--r-- | src/pretty_print_sail.ml | 21 | ||||
| -rw-r--r-- | src/rewriter.ml | 235 | ||||
| -rw-r--r-- | src/rewriter.mli | 6 | ||||
| -rw-r--r-- | src/rewrites.ml | 82 | ||||
| -rw-r--r-- | src/rewrites.mli | 3 | ||||
| -rw-r--r-- | src/sail.ml | 3 | ||||
| -rw-r--r-- | src/spec_analysis.ml | 14 | ||||
| -rw-r--r-- | src/type_check.ml | 24 |
20 files changed, 525 insertions, 95 deletions
diff --git a/src/ast_util.ml b/src/ast_util.ml index d0efc0de..254afccf 100644 --- a/src/ast_util.ml +++ b/src/ast_util.ml @@ -519,7 +519,7 @@ and map_exp_annot_aux f = function | E_tuple xs -> E_tuple (List.map (map_exp_annot f) xs) | E_if (cond, t, e) -> E_if (map_exp_annot f cond, map_exp_annot f t, map_exp_annot f e) | E_for (v, e1, e2, e3, o, e4) -> E_for (v, map_exp_annot f e1, map_exp_annot f e2, map_exp_annot f e3, o, map_exp_annot f e4) - | E_loop (loop_type, e1, e2) -> E_loop (loop_type, map_exp_annot f e1, map_exp_annot f e2) + | E_loop (loop_type, measure, e1, e2) -> E_loop (loop_type, map_measure_annot f measure, map_exp_annot f e1, map_exp_annot f e2) | E_vector exps -> E_vector (List.map (map_exp_annot f) exps) | E_vector_access (exp1, exp2) -> E_vector_access (map_exp_annot f exp1, map_exp_annot f exp2) | E_vector_subrange (exp1, exp2, exp3) -> E_vector_subrange (map_exp_annot f exp1, map_exp_annot f exp2, map_exp_annot f exp3) @@ -546,6 +546,10 @@ and map_exp_annot_aux f = function | E_var (lexp, exp1, exp2) -> E_var (map_lexp_annot f lexp, map_exp_annot f exp1, map_exp_annot f exp2) | E_internal_plet (pat, exp1, exp2) -> E_internal_plet (map_pat_annot f pat, map_exp_annot f exp1, map_exp_annot f exp2) | E_internal_return exp -> E_internal_return (map_exp_annot f exp) +and map_measure_annot f (Measure_aux (m, l)) = Measure_aux (map_measure_annot_aux f m, l) +and map_measure_annot_aux f = function + | Measure_none -> Measure_none + | Measure_some exp -> Measure_some (map_exp_annot f exp) and map_opt_default_annot f (Def_val_aux (df, annot)) = Def_val_aux (map_opt_default_annot_aux f df, f annot) and map_opt_default_annot_aux f = function | Def_val_empty -> Def_val_empty @@ -648,6 +652,7 @@ let def_loc = function | DEF_internal_mutrec _ -> Parse_ast.Unknown | DEF_pragma (_, _, l) -> l | DEF_measure (id, _, _) -> id_loc id + | DEF_loop_measures (id, _) -> id_loc id let string_of_id = function | Id_aux (Id v, _) -> v @@ -838,8 +843,8 @@ let rec string_of_exp (E_aux (exp, _)) = ^ " by " ^ string_of_exp u ^ " order " ^ string_of_order ord ^ ") { " ^ string_of_exp body - | E_loop (While, cond, body) -> "while " ^ string_of_exp cond ^ " do " ^ string_of_exp body - | E_loop (Until, cond, body) -> "repeat " ^ string_of_exp body ^ " until " ^ string_of_exp cond + | E_loop (While, measure, cond, body) -> "while " ^ string_of_measure measure ^ string_of_exp cond ^ " do " ^ string_of_exp body + | E_loop (Until, measure, cond, body) -> "repeat " ^ string_of_measure measure ^ string_of_exp body ^ " until " ^ string_of_exp cond | E_assert (test, msg) -> "assert(" ^ string_of_exp test ^ ", " ^ string_of_exp msg ^ ")" | E_exit exp -> "exit " ^ string_of_exp exp | E_throw exp -> "throw " ^ string_of_exp exp @@ -855,6 +860,11 @@ let rec string_of_exp (E_aux (exp, _)) = | E_nondet _ -> "NONDET" | E_internal_value _ -> "INTERNAL VALUE" +and string_of_measure (Measure_aux (m,_)) = + match m with + | Measure_none -> "" + | Measure_some exp -> "termination_measure { " ^ string_of_exp exp ^ "}" + and string_of_fexp (FE_aux (FE_Fexp (field, exp), _)) = string_of_id field ^ " = " ^ string_of_exp exp and string_of_pexp (Pat_aux (pexp, _)) = @@ -1448,8 +1458,8 @@ let rec subst id value (E_aux (e_aux, annot) as exp) = | E_if (cond, then_exp, else_exp) -> E_if (subst id value cond, subst id value then_exp, subst id value else_exp) - | E_loop (loop, cond, body) -> - E_loop (loop, subst id value cond, subst id value body) + | E_loop (loop, measure, cond, body) -> + E_loop (loop, subst_measure id value measure, subst id value cond, subst id value body) | E_for (id', exp1, exp2, exp3, order, body) when Id.compare id id' = 0 -> E_for (id', exp1, exp2, exp3, order, body) | E_for (id', exp1, exp2, exp3, order, body) -> @@ -1503,6 +1513,11 @@ let rec subst id value (E_aux (e_aux, annot) as exp) = in wrap e_aux +and subst_measure id value (Measure_aux (m_aux, l)) = + match m_aux with + | Measure_none -> Measure_aux (Measure_none, l) + | Measure_some exp -> Measure_aux (Measure_some (subst id value exp), l) + and subst_pexp id value (Pat_aux (pexp_aux, annot)) = let pexp_aux = match pexp_aux with | Pat_exp (pat, exp) when IdSet.mem id (pat_ids pat) -> Pat_exp (pat, exp) @@ -1660,7 +1675,7 @@ let rec locate : 'a. (l -> l) -> 'a exp -> 'a exp = fun f (E_aux (e_aux, (l, ann | E_app_infix (exp1, op, exp2) -> E_app_infix (locate f exp1, locate_id f op, locate f exp2) | E_tuple exps -> E_tuple (List.map (locate f) exps) | E_if (cond_exp, then_exp, else_exp) -> E_if (locate f cond_exp, locate f then_exp, locate f else_exp) - | E_loop (loop, cond, body) -> E_loop (loop, locate f cond, locate f body) + | E_loop (loop, measure, cond, body) -> E_loop (loop, locate_measure f measure, locate f cond, locate f body) | E_for (id, exp1, exp2, exp3, ord, exp4) -> E_for (locate_id f id, locate f exp1, locate f exp2, locate f exp3, ord, locate f exp4) | E_vector exps -> E_vector (List.map (locate f) exps) @@ -1694,6 +1709,12 @@ let rec locate : 'a. (l -> l) -> 'a exp -> 'a exp = fun f (E_aux (e_aux, (l, ann in E_aux (e_aux, (f l, annot)) +and locate_measure : 'a. (l -> l) -> 'a internal_loop_measure -> 'a internal_loop_measure = fun f (Measure_aux (m, l)) -> + let m = match m with + | Measure_none -> Measure_none + | Measure_some exp -> Measure_some (locate f exp) + in Measure_aux (m, f l) + and locate_letbind : 'a. (l -> l) -> 'a letbind -> 'a letbind = fun f (LB_aux (LB_val (pat, exp), (l, annot))) -> LB_aux (LB_val (locate_pat f pat, locate f exp), (f l, annot)) diff --git a/src/constant_propagation.ml b/src/constant_propagation.ml index 5c99a534..ee5678a8 100644 --- a/src/constant_propagation.ml +++ b/src/constant_propagation.ml @@ -440,11 +440,17 @@ let const_props defs ref_vars = let assigns = isubst_minus_set assigns (assigned_vars e4) in let e4',_ = const_prop_exp (Bindings.remove id (fst substs),snd substs) assigns e4 in re (E_for (id,e1',e2',e3',ord,e4')) assigns - | E_loop (loop,e1,e2) -> + | E_loop (loop,m,e1,e2) -> let assigns = isubst_minus_set assigns (IdSet.union (assigned_vars e1) (assigned_vars e2)) in + let m' = match m with + | Measure_aux (Measure_none,_) -> m + | Measure_aux (Measure_some exp,l) -> + let exp',_ = const_prop_exp substs assigns exp in + Measure_aux (Measure_some exp',l) + in let e1',_ = const_prop_exp substs assigns e1 in let e2',_ = const_prop_exp substs assigns e2 in - re (E_loop (loop,e1',e2')) assigns + re (E_loop (loop,m',e1',e2')) assigns | E_vector es -> let es',assigns = non_det_exp_list es in begin diff --git a/src/initial_check.ml b/src/initial_check.ml index df9af97f..8129fd89 100644 --- a/src/initial_check.ml +++ b/src/initial_check.ml @@ -367,8 +367,8 @@ and to_ast_exp ctx (P.E_aux(exp,l) : P.exp) = | P.E_for(id,e1,e2,e3,atyp,e4) -> E_for(to_ast_id id,to_ast_exp ctx e1, to_ast_exp ctx e2, to_ast_exp ctx e3,to_ast_order ctx atyp, to_ast_exp ctx e4) - | P.E_loop (P.While, e1, e2) -> E_loop (While, to_ast_exp ctx e1, to_ast_exp ctx e2) - | P.E_loop (P.Until, e1, e2) -> E_loop (Until, to_ast_exp ctx e1, to_ast_exp ctx e2) + | P.E_loop (P.While, m, e1, e2) -> E_loop (While, to_ast_measure ctx m, to_ast_exp ctx e1, to_ast_exp ctx e2) + | P.E_loop (P.Until, m, e1, e2) -> E_loop (Until, to_ast_measure ctx m, to_ast_exp ctx e1, to_ast_exp ctx e2) | P.E_vector(exps) -> E_vector(List.map (to_ast_exp ctx) exps) | P.E_vector_access(vexp,exp) -> E_vector_access(to_ast_exp ctx vexp, to_ast_exp ctx exp) | P.E_vector_subrange(vex,exp1,exp2) -> @@ -414,6 +414,16 @@ and to_ast_exp ctx (P.E_aux(exp,l) : P.exp) = | _ -> raise (Reporting.err_unreachable l __POS__ "Unparsable construct in to_ast_exp") ), (l,())) +and to_ast_measure ctx (P.Measure_aux(m,l)) : unit internal_loop_measure = + let m = match m with + | P.Measure_none -> Measure_none + | P.Measure_some exp -> + if !opt_magic_hash then + Measure_some (to_ast_exp ctx exp) + else + raise (Reporting.err_general l "Internal loop termination measure found without -dmagic_hash") + in Measure_aux (m,l) + and to_ast_lexp ctx (P.E_aux(exp,l) : P.exp) : unit lexp = let lexp = match exp with | P.E_id id -> LEXP_id (to_ast_id id) @@ -762,6 +772,10 @@ let to_ast_prec = function | P.InfixL -> InfixL | P.InfixR -> InfixR +let to_ast_loop_measure ctx = function + | P.Loop (P.While, exp) -> Loop (While, to_ast_exp ctx exp) + | P.Loop (P.Until, exp) -> Loop (Until, to_ast_exp ctx exp) + let to_ast_def ctx def : unit def list ctx_out = match def with | P.DEF_overload (id, ids) -> @@ -800,6 +814,8 @@ let to_ast_def ctx def : unit def list ctx_out = [DEF_scattered sdef], ctx | P.DEF_measure (id, pat, exp) -> [DEF_measure (to_ast_id id, to_ast_pat ctx pat, to_ast_exp ctx exp)], ctx + | P.DEF_loop_measures (id, measures) -> + [DEF_loop_measures (to_ast_id id, List.map (to_ast_loop_measure ctx) measures)], ctx let rec remove_mutrec = function | [] -> [] diff --git a/src/interpreter.ml b/src/interpreter.ml index f01a3846..b198c59b 100644 --- a/src/interpreter.ml +++ b/src/interpreter.ml @@ -243,8 +243,8 @@ let rec step (E_aux (e_aux, annot) as orig_exp) = | E_if (exp, then_exp, else_exp) -> step exp >>= fun exp' -> wrap (E_if (exp', then_exp, else_exp)) - | E_loop (While, exp, body) -> wrap (E_if (exp, E_aux (E_block [body; orig_exp], annot), exp_of_value V_unit)) - | E_loop (Until, exp, body) -> wrap (E_block [body; E_aux (E_if (exp, exp_of_value V_unit, orig_exp), annot)]) + | E_loop (While, _, exp, body) -> wrap (E_if (exp, E_aux (E_block [body; orig_exp], annot), exp_of_value V_unit)) + | E_loop (Until, _, exp, body) -> wrap (E_block [body; E_aux (E_if (exp, exp_of_value V_unit, orig_exp), annot)]) | E_assert (exp, msg) when is_true exp -> wrap unit_exp | E_assert (exp, msg) when is_false exp && is_value msg -> diff --git a/src/jib/anf.ml b/src/jib/anf.ml index 0a410249..c7f86cd4 100644 --- a/src/jib/anf.ml +++ b/src/jib/anf.ml @@ -518,7 +518,7 @@ let rec anf (E_aux (e_aux, ((l, _) as exp_annot)) as exp) = raise (Reporting.err_unreachable l __POS__ ("Encountered complex l-expression " ^ string_of_lexp lexp ^ " when converting to ANF")) - | E_loop (loop_typ, cond, exp) -> + | E_loop (loop_typ, _, cond, exp) -> let acond = anf cond in let aexp = anf exp in mk_aexp (AE_loop (loop_typ, acond, aexp)) diff --git a/src/jib/jib_compile.ml b/src/jib/jib_compile.ml index 4a72ffff..f977193a 100644 --- a/src/jib/jib_compile.ml +++ b/src/jib/jib_compile.ml @@ -1264,6 +1264,7 @@ and compile_def' n total ctx = function (* Termination measures only needed for Coq, and other theorem prover output *) | DEF_measure _ -> [], ctx + | DEF_loop_measures _ -> [], ctx | DEF_internal_mutrec fundefs -> let defs = List.map (fun fdef -> DEF_fundef fdef) fundefs in diff --git a/src/monomorphise.ml b/src/monomorphise.ml index 4d7119d7..92c58e5d 100644 --- a/src/monomorphise.ml +++ b/src/monomorphise.ml @@ -980,7 +980,7 @@ let split_defs all_errors splits defs = | E_tuple es -> re (E_tuple (List.map map_exp es)) | E_if (e1,e2,e3) -> re (E_if (map_exp e1, map_exp e2, map_exp e3)) | E_for (id,e1,e2,e3,ord,e4) -> re (E_for (id,map_exp e1,map_exp e2,map_exp e3,ord,map_exp e4)) - | E_loop (loop,e1,e2) -> re (E_loop (loop,map_exp e1,map_exp e2)) + | E_loop (loop,m,e1,e2) -> re (E_loop (loop,m,map_exp e1,map_exp e2)) | E_vector es -> re (E_vector (List.map map_exp es)) | E_vector_access (e1,e2) -> re (E_vector_access (map_exp e1,map_exp e2)) | E_vector_subrange (e1,e2,e3) -> re (E_vector_subrange (map_exp e1,map_exp e2,map_exp e3)) @@ -1115,10 +1115,14 @@ let split_defs all_errors splits defs = | DEF_internal_mutrec _ -> [d] | DEF_fundef fd -> [DEF_fundef (map_fundef fd)] - | DEF_mapdef (MD_aux (_, (l, _))) -> Reporting.unreachable l __POS__ "mappings should be gone by now" + | DEF_mapdef (MD_aux (_, (l, _))) -> + Reporting.unreachable l __POS__ "mappings should be gone by now" | DEF_val lb -> [DEF_val (map_letbind lb)] | DEF_scattered sd -> List.map (fun x -> DEF_scattered x) (map_scattered_def sd) | DEF_measure (id,pat,exp) -> [DEF_measure (id,pat,map_exp exp)] + | DEF_loop_measures (id,_) -> + Reporting.unreachable (id_loc id) __POS__ + "Loop termination measures should have been rewritten before now" in Defs (List.concat (List.map map_def defs)) in @@ -2065,7 +2069,7 @@ let rec analyse_exp fn_id env assigns (E_aux (e,(l,annot)) as exp) = let d3,a3,r3 = analyse_exp fn_id env assigns e3 in let assigns = add_dep_to_assigned d1 (dep_bindings_merge a2 a3) [e2;e3] in (dmerge d1 (dmerge d2 d3), assigns, merge r1 (merge r2 r3)) - | E_loop (_,e1,e2) -> + | E_loop (_,_,e1,e2) -> (* We remove all of the variables assigned in the loop, so we don't need to add control dependencies *) let assigns = remove_assigns [e1;e2] " assigned in a loop" in diff --git a/src/ocaml_backend.ml b/src/ocaml_backend.ml index c68a258d..8361d5f5 100644 --- a/src/ocaml_backend.ml +++ b/src/ocaml_backend.ml @@ -256,7 +256,7 @@ let rec ocaml_exp ctx (E_aux (exp_aux, _) as exp) = separate space [string "let"; ocaml_atomic_lexp ctx lexp; equals; string "ref"; parens (ocaml_atomic_exp ctx exp1 ^^ space ^^ colon ^^ space ^^ ocaml_typ ctx (Rewrites.simple_typ (typ_of exp1))); string "in"] ^/^ ocaml_exp ctx exp2 - | E_loop (Until, cond, body) -> + | E_loop (Until, _, cond, body) -> let loop_body = (ocaml_atomic_exp ctx body ^^ semi) ^/^ @@ -267,7 +267,7 @@ let rec ocaml_exp ctx (E_aux (exp_aux, _) as exp) = (string "let rec loop () =" ^//^ loop_body) ^/^ string "in" ^/^ string "loop ()" - | E_loop (While, cond, body) -> + | E_loop (While, _, cond, body) -> let loop_body = separate space [string "if"; ocaml_atomic_exp ctx cond; string "then"; parens (ocaml_atomic_exp ctx body ^^ semi ^^ space ^^ string "loop ()"); diff --git a/src/parse_ast.ml b/src/parse_ast.ml index 818c9340..896a860d 100644 --- a/src/parse_ast.ml +++ b/src/parse_ast.ml @@ -238,7 +238,14 @@ and fpat = type loop = While | Until -type +type measure_aux = (* optional termination measure for a loop *) + | Measure_none + | Measure_some of exp + +and measure = + | Measure_aux of measure_aux * l + +and exp_aux = (* Expression *) E_block of (exp) list (* block (parsing conflict with structs?) *) | E_nondet of (exp) list (* block that can evaluate the contained expressions in any ordering *) @@ -251,7 +258,7 @@ exp_aux = (* Expression *) | E_app_infix of exp * id * exp (* infix function application *) | E_tuple of (exp) list (* tuple *) | E_if of exp * exp * exp (* conditional *) - | E_loop of loop * exp * exp + | E_loop of loop * measure * exp * exp | E_for of id * exp * exp * exp * atyp * exp (* loop *) | E_vector of (exp) list (* vector (indexed from 0) *) | E_vector_access of exp * exp (* vector access *) @@ -489,6 +496,10 @@ dec_spec = DEC_aux of dec_spec_aux * l +type loop_measure = + | Loop of loop * exp + + type scattered_def = SD_aux of scattered_def_aux * l @@ -509,6 +520,7 @@ def = (* Top-level definition *) | DEF_default of default_typing_spec (* default kind and type assumptions *) | DEF_scattered of scattered_def (* scattered definition *) | DEF_measure of id * pat * exp (* separate termination measure declaration *) + | DEF_loop_measures of id * loop_measure list (* separate termination measure declaration *) | DEF_reg_dec of dec_spec (* register declaration *) | DEF_pragma of string * string * l | DEF_internal_mutrec of fundef list diff --git a/src/parser.mly b/src/parser.mly index 39ca75ff..5e448a05 100644 --- a/src/parser.mly +++ b/src/parser.mly @@ -92,6 +92,7 @@ let mk_typ t n m = ATyp_aux (t, loc n m) let mk_pat p n m = P_aux (p, loc n m) let mk_pexp p n m = Pat_aux (p, loc n m) let mk_exp e n m = E_aux (e, loc n m) +let mk_measure meas n m = Measure_aux (meas, loc n m) let mk_lit l n m = L_aux (l, loc n m) let mk_lit_exp l n m = mk_exp (E_lit (mk_lit l n m)) n m let mk_typschm tq t n m = TypSchm_aux (TypSchm_ts (tq, t), loc n m) @@ -748,6 +749,13 @@ exp_eof: | exp Eof { $1 } +/* Internal syntax for loop measures, rejected in normal code by initial_check */ +internal_loop_measure: + | + { mk_measure Measure_none $startpos $endpos } + | TerminationMeasure Lcurly exp Rcurly + { mk_measure (Measure_some $3) $startpos $endpos } + exp: | exp0 { $1 } @@ -802,10 +810,10 @@ exp: else ATyp_aux(ATyp_dec,loc $startpos($6) $endpos($6)) in mk_exp (E_for ($3, $5, $7, step, ord, $9)) $startpos $endpos } - | Repeat exp Until exp - { mk_exp (E_loop (Until, $4, $2)) $startpos $endpos } - | While exp Do exp - { mk_exp (E_loop (While, $2, $4)) $startpos $endpos } + | Repeat internal_loop_measure exp Until exp + { mk_exp (E_loop (Until, $2, $5, $3)) $startpos $endpos } + | While internal_loop_measure exp Do exp + { mk_exp (E_loop (While, $2, $3, $5)) $startpos $endpos } /* Debugging only, will be rejected in initial_check if debugging isn't on */ | InternalPLet pat Eq exp In exp @@ -1401,6 +1409,17 @@ scattered_clause: | Function_ Clause funcl { mk_sd (SD_funcl $3) $startpos $endpos } +loop_measure: + | Until exp + { Loop (Until, $2) } + | While exp + { Loop (While, $2) } + +loop_measures: + | loop_measure + { [$1] } + | loop_measure Comma loop_measures + { $1::$3 } def: | fun_def @@ -1439,6 +1458,8 @@ def: { DEF_pragma (fst $1, snd $1, loc $startpos $endpos) } | TerminationMeasure id pat Eq exp { DEF_measure ($2, $3, $5) } + | TerminationMeasure id loop_measures + { DEF_loop_measures ($2,$3) } defs_list: | def diff --git a/src/pretty_print_coq.ml b/src/pretty_print_coq.ml index b4f32dce..dabb7b56 100644 --- a/src/pretty_print_coq.ml +++ b/src/pretty_print_coq.ml @@ -1504,43 +1504,58 @@ let doc_exp, doc_let = | _ -> raise (Reporting.err_unreachable l __POS__ "Unexpected number of arguments for loop combinator") end - | Id_aux (Id (("while#" | "until#") as combinator), _) -> - let combinator = String.sub combinator 0 (String.length combinator - 1) in + | Id_aux (Id (("while#" | "until#" | "while#t" | "until#t") as combinator), _) -> + let combinator = String.sub combinator 0 (String.index combinator '#') in begin - match args with - | [cond; varstuple; body] -> - let return (E_aux (e, a)) = E_aux (E_internal_return (E_aux (e, a)), a) in - let csuffix, cond, body = - match effectful (effect_of cond), effectful (effect_of body) with - | false, false -> "", cond, body - | false, true -> "M", return cond, body - | true, false -> "M", cond, return body - | true, true -> "M", cond, body - in - let used_vars_body = find_e_ids body in - let lambda = - (* Work around indentation issues in Lem when translating - tuple or literal unit patterns to Isabelle *) - match fst (uncast_exp varstuple) with - | E_aux (E_tuple _, _) - when not (IdSet.mem (mk_id "varstup") used_vars_body)-> - separate space [string "fun varstup"; bigarrow] ^^ break 1 ^^ - separate space [string "let"; squote ^^ expY varstuple; string ":= varstup in"] - | E_aux (E_lit (L_aux (L_unit, _)), _) - when not (IdSet.mem (mk_id "unit_var") used_vars_body) -> - separate space [string "fun unit_var"; bigarrow] - | _ -> - separate space [string "fun"; expY varstuple; bigarrow] - in - parens ( - (prefix 2 1) - ((separate space) [string (combinator ^ csuffix); expY varstuple]) - ((prefix 0 1) - (parens (prefix 2 1 (group lambda) (expN cond))) - (parens (prefix 2 1 (group lambda) (expN body)))) - ) - | _ -> raise (Reporting.err_unreachable l __POS__ - "Unexpected number of arguments for loop combinator") + let cond, varstuple, body, measure = + match args with + | [cond; varstuple; body] -> cond, varstuple, body, None + | [cond; varstuple; body; measure] -> cond, varstuple, body, Some measure + | _ -> raise (Reporting.err_unreachable l __POS__ + "Unexpected number of arguments for loop combinator") + in + let return (E_aux (e, (l,a))) = + let a' = mk_tannot (env_of_annot (l,a)) bool_typ no_effect in + E_aux (E_internal_return (E_aux (e, (l,a))), (l,a')) + in + let simple_bool (E_aux (_, (l,a)) as exp) = + let a' = mk_tannot (env_of_annot (l,a)) bool_typ no_effect in + E_aux (E_cast (bool_typ, exp), (l,a')) + in + let csuffix, cond, body = + match effectful (effect_of cond), effectful (effect_of body) with + | false, false -> "", cond, body + | false, true -> "M", return cond, body + | true, false -> "M", simple_bool cond, return body + | true, true -> "M", simple_bool cond, body + in + let msuffix, measure_pp = + match measure with + | None -> "", [] + | Some exp -> "T", [expY exp] + in + let used_vars_body = find_e_ids body in + let lambda = + (* Work around indentation issues in Lem when translating + tuple or literal unit patterns to Isabelle *) + match fst (uncast_exp varstuple) with + | E_aux (E_tuple _, _) + when not (IdSet.mem (mk_id "varstup") used_vars_body)-> + separate space [string "fun varstup"; bigarrow] ^^ break 1 ^^ + separate space [string "let"; squote ^^ expY varstuple; string ":= varstup in"] + | E_aux (E_lit (L_aux (L_unit, _)), _) + when not (IdSet.mem (mk_id "unit_var") used_vars_body) -> + separate space [string "fun unit_var"; bigarrow] + | _ -> + separate space [string "fun"; expY varstuple; bigarrow] + in + parens ( + (prefix 2 1) + ((separate space) (string (combinator ^ csuffix ^ msuffix)::measure_pp@[expY varstuple])) + ((prefix 0 1) + (parens (prefix 2 1 (group lambda) (expN cond))) + (parens (prefix 2 1 (group lambda) (expN body)))) + ) end | Id_aux (Id "early_return", _) -> begin @@ -2191,7 +2206,9 @@ let types_used_with_generic_eq defs = | DEF_mapdef (MD_aux (_,(l,_))) | DEF_scattered (SD_aux (_,(l,_))) | DEF_measure (Id_aux (_,l),_,_) - -> unreachable l __POS__ "Internal definition found in the Coq back-end" + | DEF_loop_measures (Id_aux (_,l),_) + -> unreachable l __POS__ + "Definition found in the Coq back-end that should have been rewritten away" | DEF_internal_mutrec fds -> List.fold_left IdSet.union IdSet.empty (List.map typs_req_fundef fds) | DEF_val lb -> @@ -2925,6 +2942,10 @@ let rec doc_def unimplemented generic_eq_types def = | DEF_measure (id,_,_) -> unreachable (id_loc id) __POS__ ("Termination measure for " ^ string_of_id id ^ " should have been rewritten before backend") + | DEF_loop_measures (id,_) -> + unreachable (id_loc id) __POS__ + ("Loop termination measures for " ^ string_of_id id ^ + " should have been rewritten before backend") let find_exc_typ defs = let is_exc_typ_def = function diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml index 6479a028..2daee940 100644 --- a/src/pretty_print_lem.ml +++ b/src/pretty_print_lem.ml @@ -715,11 +715,12 @@ let doc_exp_lem, doc_let_lem = | _ -> raise (Reporting.err_unreachable l __POS__ "Unexpected number of arguments for loop combinator") end - | Id_aux (Id (("while#" | "until#") as combinator), _) -> - let combinator = String.sub combinator 0 (String.length combinator - 1) in + | Id_aux (Id (("while#" | "until#" | "while#t" | "until#t") as combinator), _) -> + let combinator = String.sub combinator 0 (String.index combinator '#') in begin match args with - | [cond; varstuple; body] -> + | [cond; varstuple; body] + | [cond; varstuple; body; _] -> (* Ignore termination measures - not used in Lem *) let return (E_aux (e, a)) = E_aux (E_internal_return (E_aux (e, a)), a) in let csuffix, cond, body = match effectful (effect_of cond), effectful (effect_of body) with @@ -1485,6 +1486,7 @@ let rec doc_def_lem type_env def = | DEF_mapdef (MD_aux (_, (l, _))) -> unreachable l __POS__ "Lem doesn't support mappings" | DEF_pragma _ -> empty | DEF_measure _ -> empty (* we might use these in future *) + | DEF_loop_measures _ -> empty let find_exc_typ defs = let is_exc_typ_def = function diff --git a/src/pretty_print_sail.ml b/src/pretty_print_sail.ml index 7f3a2b63..e46f784e 100644 --- a/src/pretty_print_sail.ml +++ b/src/pretty_print_sail.ml @@ -387,10 +387,10 @@ let rec doc_exp (E_aux (e_aux, _) as exp) = | E_list exps -> string "[|" ^^ separate_map (comma ^^ space) doc_exp exps ^^ string "|]" | E_cons (exp1, exp2) -> doc_atomic_exp exp1 ^^ space ^^ string "::" ^^ space ^^ doc_exp exp2 | E_record fexps -> separate space [string "struct"; string "{"; doc_fexps fexps; string "}"] - | E_loop (While, cond, exp) -> - separate space [string "while"; doc_exp cond; string "do"; doc_exp exp] - | E_loop (Until, cond, exp) -> - separate space [string "repeat"; doc_exp exp; string "until"; doc_exp cond] + | E_loop (While, measure, cond, exp) -> + separate space ([string "while"] @ doc_measure measure @ [doc_exp cond; string "do"; doc_exp exp]) + | E_loop (Until, measure, cond, exp) -> + separate space ([string "repeat"] @ doc_measure measure @ [doc_exp exp; string "until"; doc_exp cond]) | E_record_update (exp, fexps) -> separate space [string "{"; doc_exp exp; string "with"; doc_fexps fexps; string "}"] | E_vector_append (exp1, exp2) -> separate space [doc_atomic_exp exp1; string "@"; doc_atomic_exp exp2] @@ -429,6 +429,10 @@ let rec doc_exp (E_aux (e_aux, _) as exp) = | E_app (id, [exp]) when Id.compare (mk_id "pow2") id == 0 -> separate space [string "2"; string "^"; doc_atomic_exp exp] | _ -> doc_atomic_exp exp +and doc_measure (Measure_aux (m_aux, _)) = + match m_aux with + | Measure_none -> [] + | Measure_some exp -> [string "termination_measure"; braces (doc_exp exp)] and doc_infix n (E_aux (e_aux, _) as exp) = match e_aux with | E_app_infix (l, op, r) when n < 10 -> @@ -643,6 +647,13 @@ let doc_prec = function | InfixL -> string "infixl" | InfixR -> string "infixr" +let doc_loop_measures l = + separate_map (comma ^^ break 1) + (function (Loop (l,e)) -> + string (match l with While -> "while" | Until -> "until") ^^ + space ^^ doc_exp e) + l + let rec doc_scattered (SD_aux (sd_aux, _)) = match sd_aux with | SD_function (_, _, _, id) -> @@ -679,6 +690,8 @@ let rec doc_def def = group (match def with | DEF_measure (id,pat,exp) -> string "termination_measure" ^^ space ^^ doc_id id ^/^ doc_pat pat ^^ space ^^ equals ^/^ doc_exp exp + | DEF_loop_measures (id,measures) -> + string "termination_measure" ^^ space ^^ doc_id id ^/^ doc_loop_measures measures | DEF_pragma (pragma, arg, l) -> string ("$" ^ pragma ^ " " ^ arg) | DEF_fixity (prec, n, id) -> diff --git a/src/rewriter.ml b/src/rewriter.ml index edf0d4a5..2573a135 100644 --- a/src/rewriter.ml +++ b/src/rewriter.ml @@ -151,7 +151,7 @@ let fix_eff_exp (E_aux (e,((l,_) as annot))) = match destruct_tannot (snd annot) union_effects (fun_app_effects f env) (union_eff_exps [e1;e2]) | E_if (e1,e2,e3) -> union_eff_exps [e1;e2;e3] | E_for (_,e1,e2,e3,_,e4) -> union_eff_exps [e1;e2;e3;e4] - | E_loop (_,e1,e2) -> union_eff_exps [e1;e2] + | E_loop (_,_,e1,e2) -> union_eff_exps [e1;e2] | E_vector es -> union_eff_exps es | E_vector_access (e1,e2) -> union_eff_exps [e1;e2] | E_vector_subrange (e1,e2,e3) -> union_eff_exps [e1;e2;e3] @@ -280,8 +280,12 @@ let rewrite_exp rewriters (E_aux (exp,(l,annot)) as orig_exp) = | E_if (c,t,e) -> rewrap (E_if (rewrite c,rewrite t, rewrite e)) | E_for (id, e1, e2, e3, o, body) -> rewrap (E_for (id, rewrite e1, rewrite e2, rewrite e3, o, rewrite body)) - | E_loop (loop, e1, e2) -> - rewrap (E_loop (loop, rewrite e1, rewrite e2)) + | E_loop (loop, m, e1, e2) -> + let m = match m with + | Measure_aux (Measure_none,_) -> m + | Measure_aux (Measure_some exp,l) -> Measure_aux (Measure_some (rewrite exp),l) + in + rewrap (E_loop (loop, m, rewrite e1, rewrite e2)) | E_vector exps -> rewrap (E_vector (List.map rewrite exps)) | E_vector_access (vec,index) -> rewrap (E_vector_access (rewrite vec,rewrite index)) | E_vector_subrange (vec,i1,i2) -> @@ -362,8 +366,9 @@ let rewrite_def rewriters d = match d with | DEF_internal_mutrec fdefs -> DEF_internal_mutrec (List.map (rewriters.rewrite_fun rewriters) fdefs) | DEF_val letbind -> DEF_val (rewriters.rewrite_let rewriters letbind) | DEF_pragma (pragma, arg, l) -> DEF_pragma (pragma, arg, l) - | DEF_scattered _ -> raise (Reporting.err_unreachable Parse_ast.Unknown __POS__ "DEF_scattered survived to rewritter") + | DEF_scattered _ -> raise (Reporting.err_unreachable Parse_ast.Unknown __POS__ "DEF_scattered survived to rewriter") | DEF_measure (id,pat,exp) -> DEF_measure (id,rewriters.rewrite_pat rewriters pat, rewriters.rewrite_exp rewriters exp) + | DEF_loop_measures (id,_) -> raise (Reporting.err_unreachable (id_loc id) __POS__ "DEF_loop_measures survived to rewriter") let rewrite_defs_base rewriters (Defs defs) = let rec rewrite ds = match ds with @@ -539,7 +544,7 @@ type ('a,'exp,'exp_aux,'lexp,'lexp_aux,'fexp,'fexp_aux, ; e_tuple : 'exp list -> 'exp_aux ; e_if : 'exp * 'exp * 'exp -> 'exp_aux ; e_for : id * 'exp * 'exp * 'exp * Ast.order * 'exp -> 'exp_aux - ; e_loop : loop * 'exp * 'exp -> 'exp_aux + ; e_loop : loop * ('exp option * Parse_ast.l) * 'exp * 'exp -> 'exp_aux ; e_vector : 'exp list -> 'exp_aux ; e_vector_access : 'exp * 'exp -> 'exp_aux ; e_vector_subrange : 'exp * 'exp * 'exp -> 'exp_aux @@ -602,8 +607,12 @@ let rec fold_exp_aux alg = function | E_if (e1,e2,e3) -> alg.e_if (fold_exp alg e1, fold_exp alg e2, fold_exp alg e3) | E_for (id,e1,e2,e3,order,e4) -> alg.e_for (id,fold_exp alg e1, fold_exp alg e2, fold_exp alg e3, order, fold_exp alg e4) - | E_loop (loop_type, e1, e2) -> - alg.e_loop (loop_type, fold_exp alg e1, fold_exp alg e2) + | E_loop (loop_type, m, e1, e2) -> + let m = match m with + | Measure_aux (Measure_none,l) -> None,l + | Measure_aux (Measure_some exp,l) -> Some (fold_exp alg exp),l + in + alg.e_loop (loop_type, m, fold_exp alg e1, fold_exp alg e2) | E_vector es -> alg.e_vector (List.map (fold_exp alg) es) | E_vector_access (e1,e2) -> alg.e_vector_access (fold_exp alg e1, fold_exp alg e2) | E_vector_subrange (e1,e2,e3) -> @@ -681,7 +690,9 @@ let id_exp_alg = ; e_tuple = (fun es -> E_tuple es) ; e_if = (fun (e1,e2,e3) -> E_if (e1,e2,e3)) ; e_for = (fun (id,e1,e2,e3,order,e4) -> E_for (id,e1,e2,e3,order,e4)) - ; e_loop = (fun (lt, e1, e2) -> E_loop (lt, e1, e2)) + ; e_loop = (fun (lt, (m,l), e1, e2) -> + let m = match m with None -> Measure_none | Some e -> Measure_some e in + E_loop (lt, Measure_aux (m,l), e1, e2)) ; e_vector = (fun es -> E_vector es) ; e_vector_access = (fun (e1,e2) -> E_vector_access (e1,e2)) ; e_vector_subrange = (fun (e1,e2,e3) -> E_vector_subrange (e1,e2,e3)) @@ -776,8 +787,12 @@ let compute_exp_alg bot join = ; e_if = (fun ((v1,e1),(v2,e2),(v3,e3)) -> (join_list [v1;v2;v3], E_if (e1,e2,e3))) ; e_for = (fun (id,(v1,e1),(v2,e2),(v3,e3),order,(v4,e4)) -> (join_list [v1;v2;v3;v4], E_for (id,e1,e2,e3,order,e4))) - ; e_loop = (fun (lt, (v1, e1), (v2, e2)) -> - (join_list [v1;v2], E_loop (lt, e1, e2))) + ; e_loop = (fun (lt, (m,l), (v1, e1), (v2, e2)) -> + let vs,m = match m with + | None -> [], Measure_none + | Some (v,e) -> [v], Measure_some e + in + (join_list (vs@[v1;v2]), E_loop (lt, Measure_aux (m,l), e1, e2))) ; e_vector = split_join (fun es -> E_vector es) ; e_vector_access = (fun ((v1,e1),(v2,e2)) -> (join v1 v2, E_vector_access (e1,e2))) ; e_vector_subrange = (fun ((v1,e1),(v2,e2),(v3,e3)) -> (join_list [v1;v2;v3], E_vector_subrange (e1,e2,e3))) @@ -878,7 +893,8 @@ let pure_exp_alg bot join = ; e_tuple = join_list ; e_if = (fun (v1,v2,v3) -> join_list [v1;v2;v3]) ; e_for = (fun (id,v1,v2,v3,order,v4) -> join_list [v1;v2;v3;v4]) - ; e_loop = (fun (lt, v1, v2) -> join v1 v2) + ; e_loop = (fun (lt, (m,_), v1, v2) -> + let v = join v1 v2 in match m with None -> v | Some v' -> join v v') ; e_vector = join_list ; e_vector_access = (fun (v1,v2) -> join v1 v2) ; e_vector_subrange = (fun (v1,v2,v3) -> join_list [v1;v2;v3]) @@ -927,3 +943,200 @@ let pure_exp_alg bot join = ; lB_aux = (fun (vl,annot) -> vl) ; pat_alg = pure_pat_alg bot join } + +let default_fold_fexp f x (FE_aux (FE_Fexp (id,e),annot)) = + let x,e = f x e in + x, FE_aux (FE_Fexp (id,e),annot) + +let default_fold_pexp f x (Pat_aux (pe,ann)) = + let x,pe = match pe with + | Pat_exp (p,e) -> + let x,e = f x e in + x,Pat_exp (p,e) + | Pat_when (p,e1,e2) -> + let x,e1 = f x e1 in + let x,e2 = f x e2 in + x,Pat_when (p,e1,e2) + in x, Pat_aux (pe,ann) + +let default_fold_letbind f x (LB_aux (LB_val (p,e),ann)) = + let x,e = f x e in + x, LB_aux (LB_val (p,e),ann) + +let rec default_fold_lexp f x (LEXP_aux (le,ann) as lexp) = + let re le = LEXP_aux (le,ann) in + match le with + | LEXP_id _ + | LEXP_cast _ + -> x, lexp + | LEXP_deref e -> + let x, e = f x e in + x, re (LEXP_deref e) + | LEXP_memory (id,es) -> + let x,es = List.fold_left (fun (x,es) e -> + let x,e' = f x e in x,e'::es) (x,[]) es in + x, re (LEXP_memory (id, List.rev es)) + | LEXP_tup les -> + let x,les = List.fold_left (fun (x,les) le -> + let x,le' = default_fold_lexp f x le in x,le'::les) (x,[]) les in + x, re (LEXP_tup (List.rev les)) + | LEXP_vector_concat les -> + let x,les = List.fold_left (fun (x,les) le -> + let x,le' = default_fold_lexp f x le in x,le'::les) (x,[]) les in + x, re (LEXP_vector_concat (List.rev les)) + | LEXP_vector (le,e) -> + let x, le = default_fold_lexp f x le in + let x, e = f x e in + x, re (LEXP_vector (le,e)) + | LEXP_vector_range (le,e1,e2) -> + let x, le = default_fold_lexp f x le in + let x, e1 = f x e1 in + let x, e2 = f x e2 in + x, re (LEXP_vector_range (le,e1,e2)) + | LEXP_field (le,id) -> + let x, le = default_fold_lexp f x le in + x, re (LEXP_field (le,id)) + +let default_fold_exp f x (E_aux (e,ann) as exp) = + let re e = E_aux (e,ann) in + match e with + | E_block es -> + let x,es = List.fold_left (fun (x,es) e -> + let x,e' = f x e in x,e'::es) (x,[]) es in + x, re (E_block (List.rev es)) + | E_nondet es -> + let x,es = List.fold_left (fun (x,es) e -> + let x,e' = f x e in x,e'::es) (x,[]) es in + x, re (E_nondet (List.rev es)) + | E_id _ + | E_ref _ + | E_lit _ -> x, exp + | E_cast (typ,e) -> + let x,e = f x e in + x, re (E_cast (typ,e)) + | E_app (id,es) -> + let x,es = List.fold_left (fun (x,es) e -> + let x,e' = f x e in x,e'::es) (x,[]) es in + x, re (E_app (id, List.rev es)) + | E_app_infix (e1,id,e2) -> + let x,e1 = f x e1 in + let x,e2 = f x e2 in + x, re (E_app_infix (e1,id,e2)) + | E_tuple es -> + let x,es = List.fold_left (fun (x,es) e -> + let x,e' = f x e in x,e'::es) (x,[]) es in + x, re (E_tuple (List.rev es)) + | E_if (e1,e2,e3) -> + let x,e1 = f x e1 in + let x,e2 = f x e2 in + let x,e3 = f x e3 in + x, re (E_if (e1,e2,e3)) + | E_for (id,e1,e2,e3,order,e4) -> + let x,e1 = f x e1 in + let x,e2 = f x e2 in + let x,e3 = f x e3 in + let x,e4 = f x e4 in + x, re (E_for (id,e1,e2,e3,order,e4)) + | E_loop (loop_type, m, e1, e2) -> + let x,m = match m with + | Measure_aux (Measure_none,_) -> x,m + | Measure_aux (Measure_some exp,l) -> + let x, exp = f x exp in + x, Measure_aux (Measure_some exp,l) + in + let x,e1 = f x e1 in + let x,e2 = f x e2 in + x, re (E_loop (loop_type, m, e1, e2)) + | E_vector es -> + let x,es = List.fold_left (fun (x,es) e -> + let x,e' = f x e in x,e'::es) (x,[]) es in + x, re (E_vector (List.rev es)) + | E_vector_access (e1,e2) -> + let x,e1 = f x e1 in + let x,e2 = f x e2 in + x, re (E_vector_access (e1,e2)) + | E_vector_subrange (e1,e2,e3) -> + let x,e1 = f x e1 in + let x,e2 = f x e2 in + let x,e3 = f x e3 in + x, re (E_vector_subrange (e1,e2,e3)) + | E_vector_update (e1,e2,e3) -> + let x,e1 = f x e1 in + let x,e2 = f x e2 in + let x,e3 = f x e3 in + x, re (E_vector_update (e1,e2,e3)) + | E_vector_update_subrange (e1,e2,e3,e4) -> + let x,e1 = f x e1 in + let x,e2 = f x e2 in + let x,e3 = f x e3 in + let x,e4 = f x e4 in + x, re (E_vector_update_subrange (e1,e2,e3,e4)) + | E_vector_append (e1,e2) -> + let x,e1 = f x e1 in + let x,e2 = f x e2 in + x, re (E_vector_append (e1,e2)) + | E_list es -> + let x,es = List.fold_left (fun (x,es) e -> + let x,e' = f x e in x,e'::es) (x,[]) es in + x, re (E_list (List.rev es)) + | E_cons (e1,e2) -> + let x,e1 = f x e1 in + let x,e2 = f x e2 in + x, re (E_cons (e1,e2)) + | E_record fexps -> + let x,fexps = List.fold_left (fun (x,fes) fe -> + let x,fe' = default_fold_fexp f x fe in x,fe'::fes) (x,[]) fexps in + x, re (E_record (List.rev fexps)) + | E_record_update (e,fexps) -> + let x,e = f x e in + let x,fexps = List.fold_left (fun (x,fes) fe -> + let x,fe' = default_fold_fexp f x fe in x,fe'::fes) (x,[]) fexps in + x, re (E_record_update (e, List.rev fexps)) + | E_field (e,id) -> + let x,e = f x e in x, re (E_field (e,id)) + | E_case (e,pexps) -> + let x,e = f x e in + let x,pexps = List.fold_left (fun (x,pes) pe -> + let x,pe' = default_fold_pexp f x pe in x,pe'::pes) (x,[]) pexps in + x, re (E_case (e, List.rev pexps)) + | E_try (e,pexps) -> + let x,e = f x e in + let x,pexps = List.fold_left (fun (x,pes) pe -> + let x,pe' = default_fold_pexp f x pe in x,pe'::pes) (x,[]) pexps in + x, re (E_try (e, List.rev pexps)) + | E_let (letbind,e) -> + let x,letbind = default_fold_letbind f x letbind in + let x,e = f x e in + x, re (E_let (letbind,e)) + | E_assign (lexp,e) -> + let x,lexp = default_fold_lexp f x lexp in + let x,e = f x e in + x, re (E_assign (lexp,e)) + | E_sizeof _ + | E_constraint _ + -> x,exp + | E_exit e -> + let x,e = f x e in x, re (E_exit e) + | E_throw e -> + let x,e = f x e in x, re (E_throw e) + | E_return e -> + let x,e = f x e in x, re (E_return e) + | E_assert(e1,e2) -> + let x,e1 = f x e1 in + let x,e2 = f x e2 in + x, re (E_assert (e1,e2)) + | E_var (lexp,e1,e2) -> + let x,lexp = default_fold_lexp f x lexp in + let x,e1 = f x e1 in + let x,e2 = f x e2 in + x, re (E_var (lexp,e1,e2)) + | E_internal_plet (pat,e1,e2) -> + let x,e1 = f x e1 in + let x,e2 = f x e2 in + x, re (E_internal_plet (pat,e1,e2)) + | E_internal_return e -> + let x,e = f x e in x, re (E_internal_return e) + | E_internal_value _ -> x,exp + +let rec foldin_exp f x e = f (default_fold_exp (foldin_exp f)) x e +let rec foldin_pexp f x e = default_fold_pexp (foldin_exp f) x e diff --git a/src/rewriter.mli b/src/rewriter.mli index ab29d1d9..878e0d15 100644 --- a/src/rewriter.mli +++ b/src/rewriter.mli @@ -128,7 +128,7 @@ type ('a,'exp,'exp_aux,'lexp,'lexp_aux,'fexp,'fexp_aux, ; e_tuple : 'exp list -> 'exp_aux ; e_if : 'exp * 'exp * 'exp -> 'exp_aux ; e_for : id * 'exp * 'exp * 'exp * Ast.order * 'exp -> 'exp_aux - ; e_loop : loop * 'exp * 'exp -> 'exp_aux + ; e_loop : loop * ('exp option * Parse_ast.l) * 'exp * 'exp -> 'exp_aux ; e_vector : 'exp list -> 'exp_aux ; e_vector_access : 'exp * 'exp -> 'exp_aux ; e_vector_subrange : 'exp * 'exp * 'exp -> 'exp_aux @@ -254,3 +254,7 @@ val fix_eff_pexp : tannot pexp -> tannot pexp val fix_eff_fexp : tannot fexp -> tannot fexp val fix_eff_opt_default : tannot opt_default -> tannot opt_default + +(* In-order fold over expressions *) +val foldin_exp : (('a -> 'b exp -> 'a * 'b exp) -> 'a -> 'b exp -> 'a * 'b exp) -> 'a -> 'b exp -> 'a * 'b exp +val foldin_pexp : (('a -> 'b exp -> 'a * 'b exp) -> 'a -> 'b exp -> 'a * 'b exp) -> 'a -> 'b pexp -> 'a * 'b pexp diff --git a/src/rewrites.ml b/src/rewrites.ml index e148cee4..39920e33 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -2369,10 +2369,15 @@ let rewrite_defs_letbind_effects env = n_exp_name by (fun by -> let body = n_exp_term (effectful body) body in k (rewrap (E_for (id,start,stop,by,dir,body)))))) - | E_loop (loop, cond, body) -> + | E_loop (loop, measure, cond, body) -> + let measure = match measure with + | Measure_aux (Measure_none,_) -> measure + | Measure_aux (Measure_some exp,l) -> + Measure_aux (Measure_some (n_exp_term false exp),l) + in let cond = n_exp_term (effectful cond) cond in let body = n_exp_term (effectful body) body in - k (rewrap (E_loop (loop,cond,body))) + k (rewrap (E_loop (loop,measure,cond,body))) | E_vector exps -> n_exp_nameL exps (fun exps -> k (rewrap (E_vector exps))) @@ -3448,7 +3453,7 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = el env (typ_of exp4)))) el env (typ_of exp4)) in Added_vars (v, tuple_pat (if overwrite then varpats else pat :: varpats)) - | E_loop(loop,cond,body) -> + | E_loop(loop,Measure_aux (measure,_),cond,body) -> (* Find variables that might be updated in the loop body and are used either after or within the loop. *) let vars, varpats = @@ -3458,11 +3463,14 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = in let body = rewrite_var_updates (add_vars overwrite body vars) in let (E_aux (_,(_,bannot))) = body in - let fname = match loop with - | While -> "while#" - | Until -> "until#" in + let fname, measure = match loop, measure with + | While, Measure_none -> "while#", [] + | Until, Measure_none -> "until#", [] + | While, Measure_some exp -> "while#t", [exp] + | Until, Measure_some exp -> "until#t", [exp] + in let funcl = Id_aux (Id fname,gen_loc el) in - let v = E_aux (E_app (funcl,[cond;tuple_exp vars;body]), (gen_loc el, bannot)) in + let v = E_aux (E_app (funcl,[cond;tuple_exp vars;body]@measure), (gen_loc el, bannot)) in Added_vars (v, tuple_pat (if overwrite then varpats else pat :: varpats)) | E_if (c,e1,e2) -> let vars, varpats = @@ -3763,8 +3771,10 @@ let rewrite_defs_remove_e_assign env (Defs defs) = let (Defs loop_specs) = fst (Type_error.check Env.empty (Defs (List.map gen_vs [("foreach#", "forall ('vars : Type). (int, int, int, bool, 'vars, 'vars) -> 'vars"); ("while#", "forall ('vars : Type). (bool, 'vars, 'vars) -> 'vars"); - ("until#", "forall ('vars : Type). (bool, 'vars, 'vars) -> 'vars")]))) in - let rewrite_exp _ e = + ("until#", "forall ('vars : Type). (bool, 'vars, 'vars) -> 'vars"); + ("while#t", "forall ('vars : Type). (bool, 'vars, 'vars, int) -> 'vars"); + ("until#t", "forall ('vars : Type). (bool, 'vars, 'vars, int) -> 'vars")]))) in + let rewrite_exp _ e = replace_memwrite_e_assign (remove_reference_types (rewrite_var_updates e)) in rewrite_defs_base { rewrite_exp = rewrite_exp @@ -4438,6 +4448,7 @@ let minimise_recursive_functions env (Defs defs) = | d -> d in Defs (List.map rewrite_def defs) +(* Move recursive function termination measures into the function definitions. *) let move_termination_measures env (Defs defs) = let scan_for id defs = let rec aux = function @@ -4632,6 +4643,59 @@ let remove_mapping_valspecs env (Defs defs) = Defs (List.filter allowed_def defs) +(* Move loop termination measures into loop AST nodes. This is used before + type checking so that we avoid the complexity of type checking separate + measures. *) +let rec move_loop_measures (Defs defs) = + let loop_measures = + List.fold_left + (fun m d -> + match d with + | DEF_loop_measures (id, measures) -> + (* Allow multiple measure definitions, concatenating them *) + Bindings.add id + (match Bindings.find_opt id m with + | None -> measures + | Some m -> m @ measures) + m + | _ -> m) Bindings.empty defs + in + let do_exp exp_rec measures (E_aux (e,ann) as exp) = + match e, measures with + | E_loop (loop, _, e1, e2), (Loop (loop',exp))::t when loop = loop' -> + let t,e1 = exp_rec t e1 in + let t,e2 = exp_rec t e2 in + t,E_aux (E_loop (loop, Measure_aux (Measure_some exp, exp_loc exp), e1, e2),ann) + | _ -> exp_rec measures exp + in + let do_funcl (m,acc) (FCL_aux (FCL_Funcl (id, pexp),ann) as fcl) = + match Bindings.find_opt id m with + | Some measures -> + let measures,pexp = foldin_pexp do_exp measures pexp in + Bindings.add id measures m, (FCL_aux (FCL_Funcl (id, pexp),ann))::acc + | None -> m, fcl::acc + in + let unused,rev_defs = + List.fold_left + (fun (m,acc) d -> + match d with + | DEF_loop_measures _ -> m, acc + | DEF_fundef (FD_aux (FD_function (r,t,e,fcls),ann)) -> + let m,rfcls = List.fold_left do_funcl (m,[]) fcls in + m, (DEF_fundef (FD_aux (FD_function (r,t,e,List.rev rfcls),ann)))::acc + | _ -> m, d::acc) + (loop_measures,[]) defs + in let () = Bindings.iter + (fun id -> function + | [] -> () + | _::_ -> + Reporting.print_err (id_loc id) "Warning" + ("unused loop measure for function " ^ string_of_id id)) + unused + in Defs (List.rev rev_defs) + + + let opt_mono_rewrites = ref false let opt_mono_complex_nexps = ref true diff --git a/src/rewrites.mli b/src/rewrites.mli index 330f10b4..e30a4206 100644 --- a/src/rewrites.mli +++ b/src/rewrites.mli @@ -63,6 +63,9 @@ val opt_dmono_continue : bool ref (* Generate a fresh id with the given prefix *) val fresh_id : string -> l -> id +(* Move loop termination measures into loop AST nodes *) +val move_loop_measures : 'a defs -> 'a defs + (* Re-write undefined to functions created by -undefined_gen flag *) val rewrite_undefined : bool -> Env.t -> tannot defs -> tannot defs diff --git a/src/sail.ml b/src/sail.ml index 9c3a3d5c..b8be79f6 100644 --- a/src/sail.ml +++ b/src/sail.ml @@ -333,6 +333,9 @@ let load_files ?check:(check=false) type_envs files = -> Parse_ast.Defs (ast_nodes@later_nodes)) parsed (Parse_ast.Defs []) in let ast = Process_file.preprocess_ast options ast in let ast = Initial_check.process_ast ~generate:(not check) ast in + (* The separate loop measures declarations would be awkward to type check, so + move them into the definitions beforehand. *) + let ast = Rewrites.move_loop_measures ast in Profile.finish "parsing" t; let t = Profile.start () in diff --git a/src/spec_analysis.ml b/src/spec_analysis.ml index 80bff0dd..8afc985d 100644 --- a/src/spec_analysis.ml +++ b/src/spec_analysis.ml @@ -205,7 +205,9 @@ let rec fv_of_exp consider_var bound used set (E_aux (e,(_,tannot))) : (Nameset. | E_for(id,from,to_,by,_,body) -> let _,used,set = list_fv bound used set [from;to_;by] in fv_of_exp consider_var (Nameset.add (string_of_id id) bound) used set body - | E_loop(_, cond, body) -> list_fv bound used set [cond; body] + | E_loop(_, measure, cond, body) -> + let m = match measure with Measure_aux (Measure_some exp,_) -> [exp] | _ -> [] in + list_fv bound used set (m @ [cond; body]) | E_vector_access(v,i) -> list_fv bound used set [v;i] | E_vector_subrange(v,i1,i2) -> list_fv bound used set [v;i1;i2] | E_vector_update(v,i,e) -> list_fv bound used set [v;i;e] @@ -509,6 +511,8 @@ let fv_of_def consider_var consider_scatter_as_one all_defs = function ((fun (_,u,_) -> Nameset.singleton ("measure:"^i),u) (fv_of_pes consider_var mt used mt [Pat_aux(Pat_exp (pat,exp),(Unknown,Type_check.empty_tannot))])) + | DEF_loop_measures(id,_) -> + Reporting.unreachable (id_loc id) __POS__ "Loop termination measures should be rewritten before now" let group_defs consider_scatter_as_one (Ast.Defs defs) = @@ -823,7 +827,7 @@ let nexp_subst_fns substs = | E_tuple es -> re (E_tuple (List.map s_exp es)) | E_if (e1,e2,e3) -> re (E_if (s_exp e1, s_exp e2, s_exp e3)) | E_for (id,e1,e2,e3,ord,e4) -> re (E_for (id,s_exp e1,s_exp e2,s_exp e3,ord,s_exp e4)) - | E_loop (loop,e1,e2) -> re (E_loop (loop,s_exp e1,s_exp e2)) + | E_loop (loop,m,e1,e2) -> re (E_loop (loop,s_measure m,s_exp e1,s_exp e2)) | E_vector es -> re (E_vector (List.map s_exp es)) | E_vector_access (e1,e2) -> re (E_vector_access (s_exp e1,s_exp e2)) | E_vector_subrange (e1,e2,e3) -> re (E_vector_subrange (s_exp e1,s_exp e2,s_exp e3)) @@ -846,6 +850,12 @@ let nexp_subst_fns substs = | E_internal_return e -> re (E_internal_return (s_exp e)) | E_throw e -> re (E_throw (s_exp e)) | E_try (e,cases) -> re (E_try (s_exp e, List.map s_pexp cases)) + and s_measure (Measure_aux (m,l)) = + let m = match m with + | Measure_none -> m + | Measure_some exp -> Measure_some (s_exp exp) + in + Measure_aux (m,l) and s_fexp (FE_aux (FE_Fexp (id,e), (l,annot))) = FE_aux (FE_Fexp (id,s_exp e),(l,s_tannot annot)) and s_pexp = function diff --git a/src/type_check.ml b/src/type_check.ml index d5d42316..53bf02fa 100644 --- a/src/type_check.ml +++ b/src/type_check.ml @@ -3696,10 +3696,15 @@ and infer_exp env (E_aux (exp_aux, (l, ())) as exp) = in try_overload ([], Env.get_overloads f env) | E_app (f, xs) -> infer_funapp l env f xs None - | E_loop (loop_type, cond, body) -> + | E_loop (loop_type, measure, cond, body) -> let checked_cond = crule check_exp env cond bool_typ in + let checked_measure = match measure with + | Measure_aux (Measure_none,l) -> Measure_aux (Measure_none,l) + | Measure_aux (Measure_some exp,l) -> + Measure_aux (Measure_some (crule check_exp env exp int_typ),l) + in let checked_body = crule check_exp (add_opt_constraint (assert_constraint env true checked_cond) env) body unit_typ in - annot_exp (E_loop (loop_type, checked_cond, checked_body)) unit_typ + annot_exp (E_loop (loop_type, checked_measure, checked_cond, checked_body)) unit_typ | E_for (v, f, t, step, ord, body) -> begin let f, t, is_dec = match ord with @@ -4372,10 +4377,18 @@ and propagate_exp_effect_aux = function let p_body = propagate_exp_effect body in E_for (v, p_f, p_t, p_step, ord, p_body), collect_effects [p_f; p_t; p_step; p_body] - | E_loop (loop_type, cond, body) -> + | E_loop (loop_type, measure, cond, body) -> let p_cond = propagate_exp_effect cond in + let () = match measure with + | Measure_aux (Measure_some exp,l) -> + let eff = effect_of (propagate_exp_effect exp) in + if (BESet.is_empty (effect_set eff) || !opt_no_effects) + then () + else typ_error (env_of exp) l ("Loop termination measure with effects " ^ string_of_effect eff) + | _ -> () + in let p_body = propagate_exp_effect body in - E_loop (loop_type, p_cond, p_body), + E_loop (loop_type, measure, p_cond, p_body), union_effects (effect_of p_cond) (effect_of p_body) | E_let (letbind, exp) -> let p_lb, eff = propagate_letbind_effect letbind in @@ -5016,6 +5029,9 @@ and check_def : 'a. Env.t -> 'a def -> (tannot def) list * Env.t = | DEF_reg_dec (DEC_aux (DEC_typ_alias (typ, id, aspec), (l, tannot))) -> cd_err () | DEF_scattered sdef -> check_scattered env sdef | DEF_measure (id, pat, exp) -> [check_termination_measure_decl env (id, pat, exp)], env + | DEF_loop_measures (id, _) -> + Reporting.unreachable (id_loc id) __POS__ + "Loop termination measures should have been rewritten before type checking" and check_defs : 'a. int -> int -> Env.t -> 'a def list -> tannot defs * Env.t = fun n total env defs -> |
