diff options
| author | Jon French | 2019-02-03 17:50:01 +0000 |
|---|---|---|
| committer | Jon French | 2019-02-03 17:50:01 +0000 |
| commit | ab3f3671d4dd682b2aee922d5a05e9455afd5849 (patch) | |
| tree | d951e1beac8fa0af18c71e6c33879925b2707049 /src | |
| parent | bce4ee6000254c368fc83cdf62bdcdb9374b9691 (diff) | |
| parent | 4f45f462333c5494a84886677bc78a49c84da081 (diff) | |
Merge branch 'sail2' into rmem_interpreter
Diffstat (limited to 'src')
| -rw-r--r-- | src/ast_util.ml | 150 | ||||
| -rw-r--r-- | src/ast_util.mli | 2 | ||||
| -rw-r--r-- | src/c_backend.ml | 9 | ||||
| -rw-r--r-- | src/constraint.ml | 19 | ||||
| -rw-r--r-- | src/initial_check.ml | 2 | ||||
| -rw-r--r-- | src/latex.ml | 77 | ||||
| -rw-r--r-- | src/lexer.mll | 1 | ||||
| -rw-r--r-- | src/monomorphise.ml | 195 | ||||
| -rw-r--r-- | src/optimize.ml | 100 | ||||
| -rw-r--r-- | src/parse_ast.ml | 1 | ||||
| -rw-r--r-- | src/parser.mly | 4 | ||||
| -rw-r--r-- | src/pretty_print_coq.ml | 11 | ||||
| -rw-r--r-- | src/pretty_print_sail.ml | 3 | ||||
| -rw-r--r-- | src/process_file.ml | 45 | ||||
| -rw-r--r-- | src/process_file.mli | 4 | ||||
| -rw-r--r-- | src/rewriter.ml | 1 | ||||
| -rw-r--r-- | src/rewrites.ml | 54 | ||||
| -rw-r--r-- | src/sail.ml | 12 | ||||
| -rw-r--r-- | src/sail_lib.ml | 7 | ||||
| -rw-r--r-- | src/spec_analysis.ml | 15 | ||||
| -rw-r--r-- | src/state.ml | 25 | ||||
| -rw-r--r-- | src/type_check.ml | 41 | ||||
| -rw-r--r-- | src/type_check.mli | 8 |
23 files changed, 602 insertions, 184 deletions
diff --git a/src/ast_util.ml b/src/ast_util.ml index 34dfd663..c89d30c1 100644 --- a/src/ast_util.ml +++ b/src/ast_util.ml @@ -128,7 +128,7 @@ let mk_val_spec vs_aux = let kopt_kid (KOpt_aux (KOpt_kind (_, kid), _)) = kid let kopt_kind (KOpt_aux (KOpt_kind (k, _), _)) = k - + let is_nat_kopt = function | KOpt_aux (KOpt_kind (K_aux (K_int, _), _), _) -> true | _ -> false @@ -165,7 +165,7 @@ module Kind = struct | K_type, _ -> 1 | _, K_type -> -1 | K_order, _ -> 1 | _, K_order -> -1 end - + module KOpt = struct type t = kinded_id let compare kopt1 kopt2 = @@ -984,77 +984,88 @@ let lex_ord f g x1 x2 y1 y2 = | 0 -> g y1 y2 | n -> n +let rec nc_compare (NC_aux (nc1,_)) (NC_aux (nc2,_)) = + match nc1, nc2 with + | NC_equal (n1,n2), NC_equal (n3,n4) + | NC_bounded_ge (n1,n2), NC_bounded_ge (n3,n4) + | NC_bounded_le (n1,n2), NC_bounded_le (n3,n4) + | NC_not_equal (n1,n2), NC_not_equal (n3,n4) + -> lex_ord Nexp.compare Nexp.compare n1 n3 n2 n4 + | NC_set (k1,s1), NC_set (k2,s2) -> + lex_ord Kid.compare (Util.compare_list Nat_big_num.compare) k1 k2 s1 s2 + | NC_or (nc1,nc2), NC_or (nc3,nc4) + | NC_and (nc1,nc2), NC_and (nc3,nc4) + -> lex_ord nc_compare nc_compare nc1 nc3 nc2 nc4 + | NC_app (f1,args1), NC_app (f2,args2) + -> lex_ord Id.compare (Util.compare_list typ_arg_compare) f1 f2 args1 args2 + | NC_var v1, NC_var v2 + -> Kid.compare v1 v2 + | NC_true, NC_true + | NC_false, NC_false + -> 0 + | NC_equal _, _ -> -1 | _, NC_equal _ -> 1 + | NC_bounded_ge _, _ -> -1 | _, NC_bounded_ge _ -> 1 + | NC_bounded_le _, _ -> -1 | _, NC_bounded_le _ -> 1 + | NC_not_equal _, _ -> -1 | _, NC_not_equal _ -> 1 + | NC_set _, _ -> -1 | _, NC_set _ -> 1 + | NC_or _, _ -> -1 | _, NC_or _ -> 1 + | NC_and _, _ -> -1 | _, NC_and _ -> 1 + | NC_app _, _ -> -1 | _, NC_app _ -> 1 + | NC_var _, _ -> -1 | _, NC_var _ -> 1 + | NC_true, _ -> -1 | _, NC_true -> 1 + +and typ_compare (Typ_aux (t1,_)) (Typ_aux (t2,_)) = + match t1,t2 with + | Typ_internal_unknown, Typ_internal_unknown -> 0 + | Typ_id id1, Typ_id id2 -> Id.compare id1 id2 + | Typ_var kid1, Typ_var kid2 -> Kid.compare kid1 kid2 + | Typ_fn (ts1,t2,e1), Typ_fn (ts3,t4,e2) -> + (match Util.compare_list typ_compare ts1 ts3 with + | 0 -> (match typ_compare t2 t4 with + | 0 -> effect_compare e1 e2 + | n -> n) + | n -> n) + | Typ_bidir (t1,t2), Typ_bidir (t3,t4) -> + (match typ_compare t1 t3 with + | 0 -> typ_compare t2 t3 + | n -> n) + | Typ_tup ts1, Typ_tup ts2 -> Util.compare_list typ_compare ts1 ts2 + | Typ_exist (ks1,nc1,t1), Typ_exist (ks2,nc2,t2) -> + (match Util.compare_list KOpt.compare ks1 ks2 with + | 0 -> (match nc_compare nc1 nc2 with + | 0 -> typ_compare t1 t2 + | n -> n) + | n -> n) + | Typ_app (id1,ts1), Typ_app (id2,ts2) -> + (match Id.compare id1 id2 with + | 0 -> Util.compare_list typ_arg_compare ts1 ts2 + | n -> n) + | Typ_internal_unknown, _ -> -1 | _, Typ_internal_unknown -> 1 + | Typ_id _, _ -> -1 | _, Typ_id _ -> 1 + | Typ_var _, _ -> -1 | _, Typ_var _ -> 1 + | Typ_fn _, _ -> -1 | _, Typ_fn _ -> 1 + | Typ_bidir _, _ -> -1 | _, Typ_bidir _ -> 1 + | Typ_tup _, _ -> -1 | _, Typ_tup _ -> 1 + | Typ_exist _, _ -> -1 | _, Typ_exist _ -> 1 + +and typ_arg_compare (A_aux (ta1,_)) (A_aux (ta2,_)) = + match ta1, ta2 with + | A_nexp n1, A_nexp n2 -> Nexp.compare n1 n2 + | A_typ t1, A_typ t2 -> typ_compare t1 t2 + | A_order o1, A_order o2 -> order_compare o1 o2 + | A_bool nc1, A_bool nc2 -> nc_compare nc1 nc2 + | A_nexp _, _ -> -1 | _, A_nexp _ -> 1 + | A_typ _, _ -> -1 | _, A_typ _ -> 1 + | A_order _, _ -> -1 | _, A_order _ -> 1 + module NC = struct type t = n_constraint - let rec compare (NC_aux (nc1,_)) (NC_aux (nc2,_)) = - match nc1, nc2 with - | NC_equal (n1,n2), NC_equal (n3,n4) - | NC_bounded_ge (n1,n2), NC_bounded_ge (n3,n4) - | NC_bounded_le (n1,n2), NC_bounded_le (n3,n4) - | NC_not_equal (n1,n2), NC_not_equal (n3,n4) - -> lex_ord Nexp.compare Nexp.compare n1 n3 n2 n4 - | NC_set (k1,s1), NC_set (k2,s2) -> - lex_ord Kid.compare (Util.compare_list Nat_big_num.compare) k1 k2 s1 s2 - | NC_or (nc1,nc2), NC_or (nc3,nc4) - | NC_and (nc1,nc2), NC_and (nc3,nc4) - -> lex_ord compare compare nc1 nc3 nc2 nc4 - | NC_true, NC_true - | NC_false, NC_false - -> 0 - | NC_equal _, _ -> -1 | _, NC_equal _ -> 1 - | NC_bounded_ge _, _ -> -1 | _, NC_bounded_ge _ -> 1 - | NC_bounded_le _, _ -> -1 | _, NC_bounded_le _ -> 1 - | NC_not_equal _, _ -> -1 | _, NC_not_equal _ -> 1 - | NC_set _, _ -> -1 | _, NC_set _ -> 1 - | NC_or _, _ -> -1 | _, NC_or _ -> 1 - | NC_and _, _ -> -1 | _, NC_and _ -> 1 - | NC_true, _ -> -1 | _, NC_true -> 1 + let compare = nc_compare end module Typ = struct type t = typ - let rec compare (Typ_aux (t1,_)) (Typ_aux (t2,_)) = - match t1,t2 with - | Typ_internal_unknown, Typ_internal_unknown -> 0 - | Typ_id id1, Typ_id id2 -> Id.compare id1 id2 - | Typ_var kid1, Typ_var kid2 -> Kid.compare kid1 kid2 - | Typ_fn (ts1,t2,e1), Typ_fn (ts3,t4,e2) -> - (match Util.compare_list compare ts1 ts3 with - | 0 -> (match compare t2 t4 with - | 0 -> effect_compare e1 e2 - | n -> n) - | n -> n) - | Typ_bidir (t1,t2), Typ_bidir (t3,t4) -> - (match compare t1 t3 with - | 0 -> compare t2 t3 - | n -> n) - | Typ_tup ts1, Typ_tup ts2 -> Util.compare_list compare ts1 ts2 - | Typ_exist (ks1,nc1,t1), Typ_exist (ks2,nc2,t2) -> - (match Util.compare_list KOpt.compare ks1 ks2 with - | 0 -> (match NC.compare nc1 nc2 with - | 0 -> compare t1 t2 - | n -> n) - | n -> n) - | Typ_app (id1,ts1), Typ_app (id2,ts2) -> - (match Id.compare id1 id2 with - | 0 -> Util.compare_list arg_compare ts1 ts2 - | n -> n) - | Typ_internal_unknown, _ -> -1 | _, Typ_internal_unknown -> 1 - | Typ_id _, _ -> -1 | _, Typ_id _ -> 1 - | Typ_var _, _ -> -1 | _, Typ_var _ -> 1 - | Typ_fn _, _ -> -1 | _, Typ_fn _ -> 1 - | Typ_bidir _, _ -> -1 | _, Typ_bidir _ -> 1 - | Typ_tup _, _ -> -1 | _, Typ_tup _ -> 1 - | Typ_exist _, _ -> -1 | _, Typ_exist _ -> 1 - and arg_compare (A_aux (ta1,_)) (A_aux (ta2,_)) = - match ta1, ta2 with - | A_nexp n1, A_nexp n2 -> Nexp.compare n1 n2 - | A_typ t1, A_typ t2 -> compare t1 t2 - | A_order o1, A_order o2 -> order_compare o1 o2 - | A_bool nc1, A_bool nc2 -> NC.compare nc1 nc2 - | A_nexp _, _ -> -1 | _, A_nexp _ -> 1 - | A_typ _, _ -> -1 | _, A_typ _ -> 1 - | A_order _, _ -> -1 | _, A_order _ -> 1 + let compare = typ_compare end module TypMap = Map.Make(Typ) @@ -1289,6 +1300,9 @@ let is_fundef id = function | DEF_fundef (FD_aux (FD_function (_, _, _, FCL_aux (FCL_Funcl (id', _), _) :: _), _)) when Id.compare id' id = 0 -> true | _ -> false +let rename_valspec id (VS_aux (VS_val_spec (typschm, _, externs, is_cast), annot)) = + VS_aux (VS_val_spec (typschm, id, externs, is_cast), annot) + let rename_funcl id (FCL_aux (FCL_Funcl (_, pexp), annot)) = FCL_aux (FCL_Funcl (id, pexp), annot) let rename_fundef id (FD_aux (FD_function (ropt, topt, eopt, funcls), annot)) = @@ -1425,7 +1439,7 @@ let locate_id f (Id_aux (name, l)) = Id_aux (name, f l) let locate_kid f (Kid_aux (name, l)) = Kid_aux (name, f l) let locate_kind f (K_aux (kind, l)) = K_aux (kind, f l) - + let locate_kinded_id f (KOpt_aux (KOpt_kind (k, kid), l)) = KOpt_aux (KOpt_kind (locate_kind f k, locate_kid f kid), f l) diff --git a/src/ast_util.mli b/src/ast_util.mli index dc9f8594..65e02d81 100644 --- a/src/ast_util.mli +++ b/src/ast_util.mli @@ -372,6 +372,8 @@ val is_valspec : id -> 'a def -> bool val is_fundef : id -> 'a def -> bool +val rename_valspec : id -> 'a val_spec -> 'a val_spec + val rename_fundef : id -> 'a fundef -> 'a fundef val split_defs : ('a def -> bool) -> 'a defs -> ('a defs * 'a def * 'a defs) option diff --git a/src/c_backend.ml b/src/c_backend.ml index 65702764..7cda4668 100644 --- a/src/c_backend.ml +++ b/src/c_backend.ml @@ -1399,9 +1399,11 @@ let compile_type_def ctx (TD_aux (type_def, _)) = CTD_struct (id, Bindings.bindings ctors), { ctx with records = Bindings.add id ctors ctx.records } - | TD_variant (id, _, _, tus, _) -> + | TD_variant (id, _, typq, tus, _) -> let compile_tu = function - | Tu_aux (Tu_ty_id (typ, id), _) -> ctyp_of_typ ctx typ, id + | Tu_aux (Tu_ty_id (typ, id), _) -> + let ctx = { ctx with local_env = add_typquant (id_loc id) typq ctx.local_env } in + ctyp_of_typ ctx typ, id in let ctus = List.fold_left (fun ctus (ctyp, id) -> Bindings.add id ctyp ctus) Bindings.empty (List.map compile_tu tus) in CTD_variant (id, Bindings.bindings ctus), @@ -1761,6 +1763,9 @@ let rec compile_def ctx = function (* Only the parser and sail pretty printer care about this. *) | DEF_fixity _ -> [], ctx + (* We just ignore any pragmas we don't want to deal with. *) + | DEF_pragma _ -> [], ctx + | DEF_internal_mutrec fundefs -> let defs = List.map (fun fdef -> DEF_fundef fdef) fundefs in List.fold_left (fun (cdefs, ctx) def -> let cdefs', ctx = compile_def ctx def in (cdefs @ cdefs', ctx)) ([], ctx) defs diff --git a/src/constraint.ml b/src/constraint.ml index 7ead0cc8..b7e3cb47 100644 --- a/src/constraint.ml +++ b/src/constraint.ml @@ -131,16 +131,17 @@ let to_smt l vars constr = | _ -> raise (Reporting.err_unreachable l __POS__ "Tried to pass Type or Order kind to SMT function") in - var_decs l vars, smt_constraint constr + var_decs l vars, smt_constraint constr, smt_var -let smtlib_of_constraints ?get_model:(get_model=false) l vars constr : string = - let variables, problem = to_smt l vars constr in +let smtlib_of_constraints ?get_model:(get_model=false) l vars constr : string * (kid -> sexpr) = + let variables, problem, var_map = to_smt l vars constr in "(push)\n" ^ variables ^ "\n" ^ pp_sexpr (sfun "define-fun" [Atom "constraint"; List []; Atom "Bool"; problem]) ^ "\n(assert constraint)\n(check-sat)" ^ (if get_model then "\n(get-model)" else "") - ^ "\n(pop)" + ^ "\n(pop)", + var_map type smt_result = Unknown | Sat | Unsat @@ -183,7 +184,7 @@ let save_digests () = let call_z3' l vars constraints : smt_result = let problems = [constraints] in - let z3_file = smtlib_of_constraints l vars constraints in + let z3_file, _ = smtlib_of_constraints l vars constraints in (* prerr_endline (Printf.sprintf "SMTLIB2 constraints are: \n%s%!" z3_file); *) @@ -230,9 +231,11 @@ let call_z3 l vars constraints = result let rec solve_z3 l vars constraints var = - let z3_file = smtlib_of_constraints ~get_model:true l vars constraints in + let z3_file, smt_var = smtlib_of_constraints ~get_model:true l vars constraints in + let z3_var = pp_sexpr (smt_var var) in - (* prerr_endline (Printf.sprintf "SMTLIB2 constraints are: \n%s%!" z3_file); *) + (* prerr_endline (Printf.sprintf "SMTLIB2 constraints are: \n%s%!" z3_file); + prerr_endline ("Solving for " ^ z3_var); *) let rec input_all chan = try @@ -250,7 +253,7 @@ let rec solve_z3 l vars constraints var = let z3_output = String.concat " " (input_all z3_chan) in let _ = Unix.close_process_in z3_chan in Sys.remove input_file; - let regexp = {|(define-fun v|} ^ Util.zencode_string (string_of_kid var) ^ {| () Int[ ]+\([0-9]+\))|} in + let regexp = {|(define-fun |} ^ z3_var ^ {| () Int[ ]+\([0-9]+\))|} in try let _ = Str.search_forward (Str.regexp regexp) z3_output 0 in let result = Big_int.of_string (Str.matched_group 1 z3_output) in diff --git a/src/initial_check.ml b/src/initial_check.ml index 7de74a93..05d51eb2 100644 --- a/src/initial_check.ml +++ b/src/initial_check.ml @@ -743,6 +743,8 @@ let to_ast_def ctx def : unit def ctx_out = | P.DEF_scattered sdef -> let sdef, ctx = to_ast_scattered ctx sdef in DEF_scattered sdef, ctx + | P.DEF_measure (id, pat, exp) -> + DEF_measure (to_ast_id id, to_ast_pat ctx pat, to_ast_exp ctx exp), ctx let rec remove_mutrec = function | [] -> [] diff --git a/src/latex.ml b/src/latex.ml index 2f578f2c..71e0ba54 100644 --- a/src/latex.ml +++ b/src/latex.ml @@ -57,6 +57,7 @@ module StringSet = Set.Make(String);; let opt_prefix = ref "sail" let opt_directory = ref "sail_latex" +let opt_simple_val = ref true let rec unique_postfix n = if n < 0 then @@ -97,13 +98,13 @@ let rec unique_postfix n = type id_category = | Function | Val - | Overload + | Overload of int | FunclCtor of id * int | FunclNum of int | FunclApp of string + | Type -let replace_numbers str = - let replacements = +let number_replacements = [ ("0", "Zero"); ("1", "One"); ("2", "Two"); @@ -114,16 +115,28 @@ let replace_numbers str = ("7", "Seven"); ("8", "Eight"); ("9", "Nine") ] - in + +(* add to this as needed *) +let other_replacements = + [ ("_", "Underscore") ] + +let char_replace str replacements = List.fold_left (fun str (from, into) -> Str.global_replace (Str.regexp_string from) into str) str replacements +let replace_numbers str = + char_replace str number_replacements + +let replace_others str = + char_replace str other_replacements + let category_name = function | Function -> "fn" | Val -> "val" - | Overload -> "overload" + | Type -> "type" + | Overload n -> "overload" ^ unique_postfix n | FunclNum n -> "fcl" ^ unique_postfix n | FunclCtor (id, n) -> - let str = replace_numbers (Util.zencode_string (string_of_id id)) in + let str = replace_others (replace_numbers (Util.zencode_string (string_of_id id))) in "fcl" ^ String.sub str 1 (String.length str - 1) ^ unique_postfix n | FunclApp str -> "fcl" ^ str @@ -134,7 +147,8 @@ let category_name_val = function let category_name_simple = function | Function -> "fn" | Val -> "val" - | Overload -> "overload" + | Type -> "type" + | Overload n -> "overload" | FunclNum _ -> "fcl" | FunclCtor (_, _) -> "fcl" | FunclApp _ -> "fcl" @@ -162,8 +176,8 @@ let latex_id id = (* If we have any other weird symbols in the id, remove them using Util.zencode_string (removing the z prefix) *) let str = Util.zencode_string str in let str = String.sub str 1 (String.length str - 1) in - (* Latex only allows letters in identifiers, so replace all numbers *) - let str = replace_numbers str in + (* Latex only allows letters in identifiers, so replace all numbers and other characters *) + let str = replace_others (replace_numbers str) in let generated = state.generated_names |> Bindings.bindings |> List.map snd |> StringSet.of_list in @@ -290,10 +304,10 @@ let latex_loc no_loc l = let commands = ref StringSet.empty -let doc_spec_simple (VS_val_spec(ts,id,ext,is_cast)) = - Pretty_print_sail.doc_id id ^^ space - ^^ colon ^^ space - ^^ Pretty_print_sail.doc_typschm ~simple:true ts +let doc_spec_simple (VS_aux (VS_val_spec (ts, id, ext, is_cast), _)) = + Pretty_print_sail.doc_id id ^^ space + ^^ colon ^^ space + ^^ Pretty_print_sail.doc_typschm ~simple:true ts let rec latex_command cat id no_loc ((l, _) as annot) = state.this <- Some id; @@ -309,7 +323,7 @@ let rec latex_command cat id no_loc ((l, _) as annot) = output_string chan (Pretty_print_sail.to_string doc); close_out chan; - ksprintf string "\\newcommand{\\sail%s%s}{\\phantomsection%s\\saildoc%s{" (category_name cat) (latex_id id) labelling (category_name_simple cat) + ksprintf string "\\newcommand{\\%s%s%s}{\\phantomsection%s\\saildoc%s{" !opt_prefix (category_name cat) (latex_id id) labelling (category_name_simple cat) ^^ docstring l ^^ string "}{" ^^ ksprintf string "\\lstinputlisting[language=sail]{%s}}}" (Filename.concat !opt_directory code_file) @@ -381,30 +395,47 @@ let process_pragma l command = Util.warn (Printf.sprintf "Bad latex pragma %s" (Reporting.loc_to_string l)); None +let tdef_id = function + | TD_abbrev (id, _, _) -> id + | TD_record (id, _, _, _, _) -> id + | TD_variant (id, _, _, _, _) -> id + | TD_enum (id, _, _, _) -> id + | TD_bitfield (id, _, _) -> id + let defs (Defs defs) = reset_state state; + let overload_counter = ref 0 in + let valspecs = ref IdSet.empty in let fundefs = ref IdSet.empty in + let typedefs = ref IdSet.empty in let latex_def def = match def with - | DEF_overload (id, ids) -> None - (* + | DEF_overload (id, ids) -> let doc = string (Printf.sprintf "overload %s = {%s}" (string_of_id id) (Util.string_of_list ", " string_of_id ids)) in - Some (latex_command Overload id doc (id_loc id, None)) - *) + incr overload_counter; + Some (latex_command (Overload !overload_counter) id doc (id_loc id, None)) - | DEF_spec (VS_aux (VS_val_spec (_, id, _, _) as vs, annot)) as def -> + | DEF_spec (VS_aux (VS_val_spec (_, id, _, _), annot) as vs) as def -> valspecs := IdSet.add id !valspecs; - Some (latex_command Val id (doc_spec_simple vs) annot) + if !opt_simple_val then + Some (latex_command Val id (doc_spec_simple vs) annot) + else + Some (latex_command Val id (Pretty_print_sail.doc_spec ~comment:false vs) annot) | DEF_fundef (FD_aux (FD_function (_, _, _, [FCL_aux (FCL_Funcl (id, _), _)]), annot)) as def -> fundefs := IdSet.add id !fundefs; Some (latex_command Function id (Pretty_print_sail.doc_def def) annot) + | DEF_type (TD_aux (tdef, annot)) as def -> + let id = tdef_id tdef in + typedefs := IdSet.add id !typedefs; + Some (latex_command Type id (Pretty_print_sail.doc_def def) annot) + | DEF_fundef (FD_aux (FD_function (_, _, _, funcls), annot)) as def -> Some (latex_funcls def funcls) @@ -432,7 +463,7 @@ let defs (Defs defs) = identifiers then outputs the correct mangled command. *) let id_command cat ids = sprintf "\\newcommand{\\%s%s}[1]{\n " !opt_prefix (category_name cat) - ^ Util.string_of_list "%\n " (fun id -> sprintf "\\ifstrequal{#1}{%s}{\\sail%s%s}{}" (string_of_id id) (category_name cat) (latex_id id)) + ^ Util.string_of_list "%\n " (fun id -> sprintf "\\ifstrequal{#1}{%s}{\\%s%s%s}{}" (string_of_id id) !opt_prefix (category_name cat) (latex_id id)) (IdSet.elements ids) ^ "}" |> string @@ -449,5 +480,7 @@ let defs (Defs defs) = ^^ separate (twice hardline) [id_command Val !valspecs; ref_command Val !valspecs; id_command Function !fundefs; - ref_command Function !fundefs] + ref_command Function !fundefs; + id_command Type !typedefs; + ref_command Type !typedefs] ^^ hardline diff --git a/src/lexer.mll b/src/lexer.mll index 57580e7a..1d48b82b 100644 --- a/src/lexer.mll +++ b/src/lexer.mll @@ -182,6 +182,7 @@ let kw_table = ("nondet", (fun x -> Nondet)); ("escape", (fun x -> Escape)); ("configuration", (fun _ -> Configuration)); + ("termination_measure", (fun _ -> TerminationMeasure)); ] diff --git a/src/monomorphise.ml b/src/monomorphise.ml index 4bb1876c..dd0f7afd 100644 --- a/src/monomorphise.ml +++ b/src/monomorphise.ml @@ -3418,7 +3418,7 @@ let rec sets_from_assert e = match e with | E_app (Id_aux (Id "or_bool",_),[e1;e2]) -> aux e1 @ aux e2 - | E_app (Id_aux (Id "eq_atom",_), + | E_app (Id_aux (Id "eq_int",_), [E_aux (E_sizeof (Nexp_aux (Nexp_var kid,_)),_); E_aux (E_lit (L_aux (L_num i,_)),_)]) -> (check_kid kid; [i]) @@ -3930,16 +3930,46 @@ end module BitvectorSizeCasts = struct -let simplify_size_nexp env quant_kids (Nexp_aux (_,l) as nexp) = - match solve env nexp with - | Some n -> Some (nconstant n) - | None -> - let is_equal kid = - prove env (NC_aux (NC_equal (Nexp_aux (Nexp_var kid,Unknown), nexp),Unknown)) - in - match List.find is_equal quant_kids with - | kid -> Some (Nexp_aux (Nexp_var kid,Generated l)) - | exception Not_found -> None +let simplify_size_nexp env quant_kids nexp = + let rec aux (Nexp_aux (ne,l) as nexp) = + match solve env nexp with + | Some n -> Some (nconstant n) + | None -> + let is_equal kid = + prove env (NC_aux (NC_equal (Nexp_aux (Nexp_var kid,Unknown), nexp),Unknown)) + in + match List.find is_equal quant_kids with + | kid -> Some (Nexp_aux (Nexp_var kid,Generated l)) + | exception Not_found -> + (* Normally rewriting of complex nexps in function signatures will + produce a simple constant or variable above, but occasionally it's + useful to work when that rewriting hasn't been applied. In + particular, that rewriting isn't fully working with RISC-V at the + moment. *) + let re f = function + | Some n1, Some n2 -> Some (Nexp_aux (f n1 n2,l)) + | _ -> None + in + match ne with + | Nexp_times(n1,n2) -> + re (fun n1 n2 -> Nexp_times(n1,n2)) (aux n1, aux n2) + | Nexp_sum(n1,n2) -> + re (fun n1 n2 -> Nexp_sum(n1,n2)) (aux n1, aux n2) + | Nexp_minus(n1,n2) -> + re (fun n1 n2 -> Nexp_times(n1,n2)) (aux n1, aux n2) + | Nexp_exp n -> + Util.option_map (fun n -> Nexp_aux (Nexp_exp n,l)) (aux n) + | Nexp_neg n -> + Util.option_map (fun n -> Nexp_aux (Nexp_neg n,l)) (aux n) + | _ -> None + in aux nexp + +let specs_required = ref IdSet.empty +let check_for_spec env name = + let id = mk_id name in + match Env.get_val_spec id env with + | _ -> () + | exception _ -> specs_required := IdSet.add id !specs_required (* These functions add cast functions across case splits, so that when a bitvector size becomes known in sail, the generated Lem code contains a @@ -3969,7 +3999,8 @@ let make_bitvector_cast_fns cast_name env quant_kids src_typ target_typ = [A_aux (A_nexp size',l_size'); t_ord; A_aux (A_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_) as t_bit]) -> begin match simplify_size_nexp env quant_kids size, simplify_size_nexp env quant_kids size' with - | Some size, Some size' when Nexp.compare size size' <> 0 -> + | Some size, Some size' -> + if Nexp.compare size size' <> 0 then let var = fresh () in let tar_typ' = Typ_aux (Typ_app (t_id, [A_aux (A_nexp size',l_size');t_ord;t_bit]), tar_l) in @@ -3980,6 +4011,10 @@ let make_bitvector_cast_fns cast_name env quant_kids src_typ target_typ = E_aux (E_app (Id_aux (Id cast_name, genunk), [E_aux (E_id var, (genunk, src_ann))]), (genunk, tar_ann))), (genunk, tar_ann)) + else + let var = fresh () in + P_aux (P_id var,(Generated src_l,src_ann)), + E_aux (E_id var,(Generated src_l,tar_ann)) | _ -> let var = fresh () in P_aux (P_id var,(Generated src_l,src_ann)), @@ -3995,6 +4030,7 @@ let make_bitvector_cast_fns cast_name env quant_kids src_typ target_typ = let pat, e' = aux src_typ' target_typ' in match !at_least_one with | Some one_target_typ -> begin + check_for_spec env cast_name; let src_ann = mk_tannot env src_typ no_effect in let tar_ann = mk_tannot env target_typ no_effect in match src_typ' with @@ -4028,13 +4064,55 @@ let make_bitvector_env_casts env quant_kids (kid,i) exp = Bindings.fold (fun var (mut,typ) exp -> if mut = Immutable then mk_cast var typ exp else exp) locals exp -let make_bitvector_cast_exp cast_name env quant_kids typ target_typ exp = (snd (make_bitvector_cast_fns cast_name env quant_kids typ target_typ)) exp +let make_bitvector_cast_exp cast_name cast_env quant_kids typ target_typ exp = + let infer_arg_typ env f l typ = + let (typq, ctor_typ) = Env.get_union_id f env in + let quants = quant_items typq in + match Env.expand_synonyms env ctor_typ with + | Typ_aux (Typ_fn ([arg_typ], ret_typ, _), _) -> + begin + let goals = quant_kopts typq |> List.map kopt_kid |> KidSet.of_list in + let unifiers = unify l env goals ret_typ typ in + let arg_typ' = subst_unifiers unifiers arg_typ in + arg_typ' + end + | _ -> typ_error l ("Malformed constructor " ^ string_of_id f ^ " with type " ^ string_of_typ ctor_typ) + + in + (* Push the cast down, including through constructors *) + let rec aux exp (typ, target_typ) = + let exp_env = env_of exp in + match exp with + | E_aux (E_let (lb,exp'),ann) -> + E_aux (E_let (lb,aux exp' (typ, target_typ)),ann) + | E_aux (E_block exps,ann) -> + let exps' = match List.rev exps with + | [] -> [] + | final::l -> aux final (typ, target_typ)::l + in E_aux (E_block (List.rev exps'),ann) + | E_aux (E_tuple exps,(l,ann)) -> begin + match Env.expand_synonyms exp_env typ, Env.expand_synonyms exp_env target_typ with + | Typ_aux (Typ_tup src_typs,_), Typ_aux (Typ_tup tgt_typs,_) -> + E_aux (E_tuple (List.map2 aux exps (List.combine src_typs tgt_typs)),(l,ann)) + | _ -> raise (Reporting.err_unreachable l __POS__ + ("Attempted to insert cast on tuple on non-tuple type: " ^ + string_of_typ typ ^ " to " ^ string_of_typ target_typ)) + end + | E_aux (E_app (f,args),(l,ann)) when Env.is_union_constructor f (env_of exp) -> + let arg = match args with [arg] -> arg | _ -> E_aux (E_tuple args, (l,empty_tannot)) in + let src_arg_typ = infer_arg_typ (env_of exp) f l typ in + let tgt_arg_typ = infer_arg_typ (env_of exp) f l target_typ in + E_aux (E_app (f,[aux arg (src_arg_typ, tgt_arg_typ)]),(l,ann)) + | _ -> + (snd (make_bitvector_cast_fns cast_name cast_env quant_kids typ target_typ)) exp + in + aux exp (typ, target_typ) let rec extract_value_from_guard var (E_aux (e,_)) = match e with | E_app (op, ([E_aux (E_id var',_); E_aux (E_lit (L_aux (L_num i,_)),_)] | [E_aux (E_lit (L_aux (L_num i,_)),_); E_aux (E_id var',_)])) - when string_of_id op = "eq_atom" && Id.compare var var' == 0 -> + when string_of_id op = "eq_int" && Id.compare var var' == 0 -> Some i | E_app (op, [e1;e2]) when string_of_id op = "and_bool" -> (match extract_value_from_guard var e1 with @@ -4047,7 +4125,8 @@ let fill_in_type env typ = let subst = KidSet.fold (fun kid subst -> match Env.get_typ_var kid env with | K_type - | K_order -> subst + | K_order + | K_bool -> subst | K_int -> (match solve env (nvar kid) with | None -> subst @@ -4076,13 +4155,14 @@ let add_bitvector_casts (Defs defs) = let pat,guard,body,ann = destruct_pexp pexp in let body = match pat, guard with | P_aux (P_lit (L_aux (L_num i,_)),_), _ -> - let src_typ = subst_src_typ (KBindings.singleton kid (nconstant i)) result_typ in + (* We used to just substitute kid, but fill_in_type also catches other kids defined by it *) + let src_typ = fill_in_type (Env.add_constraint (nc_eq (nvar kid) (nconstant i)) env) result_typ in make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ result_typ (make_bitvector_env_casts env quant_kids (kid,i) body) | P_aux (P_id var,_), Some guard -> (match extract_value_from_guard var guard with | Some i -> - let src_typ = subst_src_typ (KBindings.singleton kid (nconstant i)) result_typ in + let src_typ = fill_in_type (Env.add_constraint (nc_eq (nvar kid) (nconstant i)) env) result_typ in make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ result_typ (make_bitvector_env_casts env quant_kids (kid,i) body) | None -> body) @@ -4102,10 +4182,17 @@ let add_bitvector_casts (Defs defs) = | E_app (op, ([E_aux (E_sizeof (Nexp_aux (Nexp_var kid,_)),_); y] | [y; E_aux (E_sizeof (Nexp_aux (Nexp_var kid,_)),_)])) - when string_of_id op = "eq_atom" -> + when string_of_id op = "eq_int" -> (match destruct_atom_nexp (env_of y) (typ_of y) with | Some (Nexp_aux (Nexp_constant i,_)) -> [(kid,i)] | _ -> []) + | E_app (op,[x;y]) + when string_of_id op = "eq_int" -> + (match destruct_atom_nexp (env_of x) (typ_of x), destruct_atom_nexp (env_of y) (typ_of y) with + | Some (Nexp_aux (Nexp_var kid,_)), Some (Nexp_aux (Nexp_constant i,_)) + | Some (Nexp_aux (Nexp_constant i,_)), Some (Nexp_aux (Nexp_var kid,_)) + -> [(kid,i)] + | _ -> []) | E_app (op, [x;y]) when string_of_id op = "and_bool" -> extract x @ extract y | _ -> [] @@ -4120,10 +4207,16 @@ let add_bitvector_casts (Defs defs) = E_aux (E_if (e1,e2',e3), ann) | E_return e' -> E_aux (E_return (make_bitvector_cast_exp "bitvector_cast_out" top_env quant_kids (fill_in_type (env_of e') (typ_of e')) ret_typ e'),ann) - | E_assign (LEXP_aux (lexp,lexp_annot),e') -> - E_aux (E_assign (LEXP_aux (lexp,lexp_annot), - make_bitvector_cast_exp "bitvector_cast_out" top_env quant_kids (fill_in_type (env_of e') (typ_of e')) - (typ_of_annot lexp_annot) e'),ann) + | E_assign (LEXP_aux (_,lexp_annot) as lexp,e') -> begin + (* The type in the lexp_annot might come from e' rather than being the + type of the storage, so ask the type checker what it really is. *) + match infer_lexp (env_of_annot lexp_annot) (strip_lexp lexp) with + | LEXP_aux (_,lexp_annot') -> + E_aux (E_assign (lexp, + make_bitvector_cast_exp "bitvector_cast_out" top_env quant_kids (fill_in_type (env_of e') (typ_of e')) + (typ_of_annot lexp_annot') e'),ann) + | exception _ -> E_aux (e,ann) + end | E_id id -> begin let env = env_of_annot ann in match Env.lookup_id id env with @@ -4167,7 +4260,23 @@ let add_bitvector_casts (Defs defs) = | DEF_fundef (FD_aux (FD_function (r,t,e,fcls),fd_ann)) -> DEF_fundef (FD_aux (FD_function (r,t,e,List.map rewrite_funcl fcls),fd_ann)) | d -> d - in Defs (List.map rewrite_def defs) + in + specs_required := IdSet.empty; + let defs = List.map rewrite_def defs in + let l = Generated Unknown in + let Defs cast_specs,_ = + (* TODO: use default/relevant order *) + let kid = mk_kid "n" in + let bitsn = vector_typ (nvar kid) dec_ord bit_typ in + let ts = mk_typschm (mk_typquant [mk_qi_id K_int kid]) + (function_typ [bitsn] bitsn no_effect) in + let extfn _ = Some "zeroExtend" in + let mkfn name = + mk_val_spec (VS_val_spec (ts,name,extfn,false)) + in + let defs = List.map mkfn (IdSet.elements !specs_required) in + check Env.empty (Defs defs) + in Defs (cast_specs @ defs) end let replace_nexp_in_typ env typ orig new_nexp = @@ -4243,14 +4352,13 @@ let rewrite_toplevel_nexps (Defs defs) = let nexp_map, typ = rewrite_typ_in_spec env nexp_map typ in (nexp_map, typ::t)) typs (nexp_map,[]) in nexp_map, Typ_aux (Typ_tup typs,ann) - | _ -> - let typ' = Env.base_typ_of env typ_full in + | _ when is_number typ_full || is_bitvector_typ typ_full -> begin let nexp_opt = - match destruct_atom_nexp env typ' with + match destruct_atom_nexp env typ_full with | Some nexp -> Some nexp | None -> - if is_bitvector_typ typ' then - let (size,_,_) = vector_typ_args_of typ' in + if is_bitvector_typ typ_full then + let (size,_,_) = vector_typ_args_of typ_full in Some size else None in match nexp_opt with @@ -4266,10 +4374,27 @@ let rewrite_toplevel_nexps (Defs defs) = (kid, nexp)::nexp_map, kid in let new_nexp = nvar kid in - (* Try to avoid expanding the original type *) - let changed, typ = replace_nexp_in_typ env typ_full nexp new_nexp in - if changed then nexp_map, typ - else nexp_map, snd (replace_nexp_in_typ env typ' nexp new_nexp) + nexp_map, snd (replace_nexp_in_typ env typ_full nexp new_nexp) + end + | _ -> + let typ' = Env.base_typ_of env typ_full in + if Typ.compare typ_full typ' == 0 then + match t with + | Typ_app (f,args) -> + let in_arg nexp_map (A_aux (arg,l) as arg_full) = + match arg with + | A_typ typ -> + let nexp_map, typ' = rewrite_typ_in_spec env nexp_map typ in + nexp_map, A_aux (A_typ typ',l) + | A_bool _ | A_nexp _ | A_order _ -> nexp_map, arg_full + in + let nexp_map, args = + List.fold_right (fun arg (nexp_map,args) -> + let nexp_map, arg = in_arg nexp_map arg in + (nexp_map, arg::args)) args (nexp_map,[]) + in nexp_map, Typ_aux (Typ_app (f,args),ann) + | _ -> nexp_map, typ_full + else rewrite_typ_in_spec env nexp_map typ' in let rewrite_valspec (VS_aux (VS_val_spec (TypSchm_aux (TypSchm_ts (tqs,typ),ts_l),id,ext_opt,is_cast),ann)) = match tqs with @@ -4337,7 +4462,11 @@ let rewrite_toplevel_nexps (Defs defs) = | DEF_spec vs -> (match rewrite_valspec vs with | None -> spec_map, def | Some (id, nexp_map, vs) -> Bindings.add id nexp_map spec_map, DEF_spec vs) - | DEF_fundef (FD_aux (FD_function (recopt,tann,eff,funcls),ann)) -> + | DEF_fundef (FD_aux (FD_function (recopt,_,eff,funcls),ann)) -> + (* Type annotations on function definitions will have been turned into + valspecs by type checking, so it should be safe to drop them rather + than updating them. *) + let tann = Typ_annot_opt_aux (Typ_annot_opt_none,Generated Unknown) in spec_map, DEF_fundef (FD_aux (FD_function (recopt,tann,eff,List.map (rewrite_funcl spec_map) funcls),ann)) | _ -> spec_map, def diff --git a/src/optimize.ml b/src/optimize.ml new file mode 100644 index 00000000..a372abf4 --- /dev/null +++ b/src/optimize.ml @@ -0,0 +1,100 @@ +(**************************************************************************) +(* Sail *) +(* *) +(* Copyright (c) 2013-2017 *) +(* Kathyrn Gray *) +(* Shaked Flur *) +(* Stephen Kell *) +(* Gabriel Kerneis *) +(* Robert Norton-Wright *) +(* Christopher Pulte *) +(* Peter Sewell *) +(* Alasdair Armstrong *) +(* Brian Campbell *) +(* Thomas Bauereiss *) +(* Anthony Fox *) +(* Jon French *) +(* Dominic Mulligan *) +(* Stephen Kell *) +(* Mark Wassell *) +(* *) +(* All rights reserved. *) +(* *) +(* This software was developed by the University of Cambridge Computer *) +(* Laboratory as part of the Rigorous Engineering of Mainstream Systems *) +(* (REMS) project, funded by EPSRC grant EP/K008528/1. *) +(* *) +(* Redistribution and use in source and binary forms, with or without *) +(* modification, are permitted provided that the following conditions *) +(* are met: *) +(* 1. Redistributions of source code must retain the above copyright *) +(* notice, this list of conditions and the following disclaimer. *) +(* 2. Redistributions in binary form must reproduce the above copyright *) +(* notice, this list of conditions and the following disclaimer in *) +(* the documentation and/or other materials provided with the *) +(* distribution. *) +(* *) +(* THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' *) +(* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED *) +(* TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A *) +(* PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR *) +(* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, *) +(* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT *) +(* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF *) +(* USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND *) +(* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, *) +(* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT *) +(* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF *) +(* SUCH DAMAGE. *) +(**************************************************************************) + +open Ast +open Ast_util +open Rewriter + +let recheck (Defs defs) = + let defs = Type_check.check_with_envs Type_check.initial_env defs in + + let rec find_optimizations = function + | ([DEF_pragma ("optimize", pragma, p_l)], env) + :: ([DEF_spec vs as def1], _) + :: ([DEF_fundef fdef as def2], _) + :: defs -> + let id = id_of_val_spec vs in + let args = Str.split (Str.regexp " +") (String.trim pragma) in + begin match args with + | ["unroll"; n]-> + let n = int_of_string n in + + let rw_app subst (fn, args) = + if Id.compare id fn = 0 then E_app (subst, args) else E_app (fn, args) + in + let rw_exp subst = { id_exp_alg with e_app = rw_app subst } in + let rw_defs subst = { rewriters_base with rewrite_exp = (fun _ -> fold_exp (rw_exp subst)) } in + + let specs = ref [def1] in + let bodies = ref [rewrite_def (rw_defs (append_id id "_unroll_1")) def2] in + + for i = 1 to n do + let current_id = append_id id ("_unroll_" ^ string_of_int i) in + let next_id = if i = n then current_id else append_id id ("_unroll_" ^ string_of_int (i + 1)) in + (* Create a valspec for the new unrolled function *) + specs := !specs @ [DEF_spec (rename_valspec current_id vs)]; + (* Then duplicate it's function body and make it call the next unrolled function *) + bodies := !bodies @ [rewrite_def (rw_defs next_id) (DEF_fundef (rename_fundef current_id fdef))] + done; + + !specs @ !bodies @ find_optimizations defs + + | _ -> + Util.warn ("Unrecognised optimize pragma in this context: " ^ pragma); + def1 :: def2 :: find_optimizations defs + end + + | (defs, _) :: defs' -> + defs @ find_optimizations defs' + + | [] -> [] + in + + Defs (find_optimizations defs) diff --git a/src/parse_ast.ml b/src/parse_ast.ml index f3bb28db..c47ca931 100644 --- a/src/parse_ast.ml +++ b/src/parse_ast.ml @@ -531,6 +531,7 @@ def = (* Top-level definition *) | DEF_spec of val_spec (* top-level type constraint *) | DEF_default of default_typing_spec (* default kind and type assumptions *) | DEF_scattered of scattered_def (* scattered definition *) + | DEF_measure of id * pat * exp (* separate termination measure declaration *) | DEF_reg_dec of dec_spec (* register declaration *) | DEF_pragma of string * string * l | DEF_internal_mutrec of fundef list diff --git a/src/parser.mly b/src/parser.mly index 9fdf27b7..abf533c3 100644 --- a/src/parser.mly +++ b/src/parser.mly @@ -183,7 +183,7 @@ let rec desugar_rchain chain s e = %token Pure Register Return Scattered Sizeof Struct Then True TwoCaret TYPE Typedef %token Undefined Union Newtype With Val Constant Constraint Throw Try Catch Exit Bitfield %token Barr Depend Rreg Wreg Rmem Rmemt Wmem Wmv Wmvt Eamem Exmem Undef Unspec Nondet Escape -%token Repeat Until While Do Mutual Var Ref Configuration +%token Repeat Until While Do Mutual Var Ref Configuration TerminationMeasure %nonassoc Then %nonassoc Else @@ -1430,6 +1430,8 @@ def: { DEF_internal_mutrec $3 } | Pragma { DEF_pragma (fst $1, snd $1, loc $startpos $endpos) } + | TerminationMeasure id pat Eq exp + { DEF_measure ($2, $3, $5) } defs_list: | def diff --git a/src/pretty_print_coq.ml b/src/pretty_print_coq.ml index 279a8182..b5d72807 100644 --- a/src/pretty_print_coq.ml +++ b/src/pretty_print_coq.ml @@ -239,9 +239,9 @@ let doc_nexp ctx ?(skip_vars=KidSet.empty) nexp = and app (Nexp_aux (n,l) as nexp) = match n with | Nexp_app (Id_aux (Id "div",_), [n1;n2]) - -> separate space [string "Z.quot"; atomic n1; atomic n2] + -> separate space [string "ZEuclid.div"; atomic n1; atomic n2] | Nexp_app (Id_aux (Id "mod",_), [n1;n2]) - -> separate space [string "Z.rem"; atomic n1; atomic n2] + -> separate space [string "ZEuclid.modulo"; atomic n1; atomic n2] | Nexp_app (Id_aux (Id "abs_atom",_), [n1]) -> separate space [string "Z.abs"; atomic n1] | _ -> atomic nexp @@ -585,8 +585,9 @@ let doc_lit (L_aux(lit,l)) = | L_false -> utf8string "false" | L_true -> utf8string "true" | L_num i -> - let ipp = Big_int.to_string i in - utf8string ipp + let s = Big_int.to_string i in + let ipp = utf8string s in + if Big_int.less i Big_int.zero then parens ipp else ipp | L_hex n -> failwith "Shouldn't happen" (*"(num_to_vec " ^ ("0x" ^ n) ^ ")" (*shouldn't happen*)*) | L_bin n -> failwith "Shouldn't happen" (*"(num_to_vec " ^ ("0b" ^ n) ^ ")" (*shouldn't happen*)*) | L_undef -> @@ -1296,7 +1297,7 @@ let doc_exp, doc_let = [parens (string "_limit_reduces _acc")] else match f with | Id_aux (Id x,_) when is_prefix "#rec#" x -> - main_call @ [parens (string "Zwf_well_founded _ _")] + main_call @ [parens (string "Zwf_guarded _")] | _ -> main_call in hang 2 (flow (break 1) all) in diff --git a/src/pretty_print_sail.ml b/src/pretty_print_sail.ml index 3d4f77e6..16c338bd 100644 --- a/src/pretty_print_sail.ml +++ b/src/pretty_print_sail.ml @@ -660,6 +660,9 @@ let rec doc_def def = group (match def with ^^ hardline ^^ string "}" | DEF_reg_dec dec -> doc_dec dec | DEF_scattered sdef -> doc_scattered sdef + | DEF_measure (id,pat,exp) -> + string "termination_measure" ^^ space ^^ doc_id id ^/^ doc_pat pat ^^ + space ^^ equals ^/^ doc_exp exp | DEF_pragma (pragma, arg, l) -> string ("$" ^ pragma ^ " " ^ arg) | DEF_fixity (prec, n, id) -> diff --git a/src/process_file.ml b/src/process_file.ml index 87acd83a..e8bb5fc1 100644 --- a/src/process_file.ml +++ b/src/process_file.ml @@ -51,6 +51,10 @@ open PPrint open Pretty_print_common +let opt_lem_output_dir = ref None +let opt_isa_output_dir = ref None +let opt_coq_output_dir = ref None + type out_type = | Lem_out of string list | Coq_out of string list @@ -254,19 +258,24 @@ let check_ast (env : Type_check.Env.t) (defs : unit Ast.defs) : Type_check.tanno (ast, env) -let open_output_with_check file_name = +let open_output_with_check opt_dir file_name = let (temp_file_name, o) = Filename.open_temp_file "ll_temp" "" in let o' = Format.formatter_of_out_channel o in - (o', (o, temp_file_name, file_name)) + (o', (o, temp_file_name, opt_dir, file_name)) -let open_output_with_check_unformatted file_name = +let open_output_with_check_unformatted opt_dir file_name = let (temp_file_name, o) = Filename.open_temp_file "ll_temp" "" in - (o, temp_file_name, file_name) + (o, temp_file_name, opt_dir, file_name) let always_replace_files = ref true -let close_output_with_check (o, temp_file_name, file_name) = +let close_output_with_check (o, temp_file_name, opt_dir, file_name) = let _ = close_out o in + let file_name = match opt_dir with + | None -> file_name + | Some dir -> if Sys.file_exists dir then () + else Unix.mkdir dir 0o775; + Filename.concat dir file_name in let do_replace = !always_replace_files || (not (Util.same_content_files temp_file_name file_name)) in let _ = if (not do_replace) then Sys.remove temp_file_name else Util.move_file temp_file_name file_name in @@ -308,22 +317,22 @@ let output_lem filename libs defs = string "end" ] ^^ hardline in - let ((ot,_, _) as ext_ot) = - open_output_with_check_unformatted (filename ^ "_types" ^ ".lem") in - let ((o,_, _) as ext_o) = - open_output_with_check_unformatted (filename ^ ".lem") in + let ((ot,_,_,_) as ext_ot) = + open_output_with_check_unformatted !opt_lem_output_dir (filename ^ "_types" ^ ".lem") in + let ((o,_,_,_) as ext_o) = + open_output_with_check_unformatted !opt_lem_output_dir (filename ^ ".lem") in (Pretty_print.pp_defs_lem (ot, base_imports) (o, base_imports @ (String.capitalize_ascii types_module :: libs)) defs generated_line); close_output_with_check ext_ot; close_output_with_check ext_o; - let ((ol, _, _) as ext_ol) = - open_output_with_check_unformatted (isa_thy_name ^ ".thy") in + let ((ol,_,_,_) as ext_ol) = + open_output_with_check_unformatted !opt_isa_output_dir (isa_thy_name ^ ".thy") in print ol isa_lemmas; close_output_with_check ext_ol -let output_coq filename libs defs = +let output_coq opt_dir filename libs defs = let generated_line = generated_line filename in let types_module = (filename ^ "_types") in let monad_modules = ["Sail2_prompt_monad"; "Sail2_prompt"; "Sail2_state"] in @@ -336,10 +345,10 @@ let output_coq filename libs defs = operators_module ] @ monad_modules in - let ((ot,_, _) as ext_ot) = - open_output_with_check_unformatted (filename ^ "_types" ^ ".v") in - let ((o,_, _) as ext_o) = - open_output_with_check_unformatted (filename ^ ".v") in + let ((ot,_,_,_) as ext_ot) = + open_output_with_check_unformatted opt_dir (filename ^ "_types" ^ ".v") in + let ((o,_,_,_) as ext_o) = + open_output_with_check_unformatted opt_dir (filename ^ ".v") in (Pretty_print_coq.pp_defs_coq (ot, base_imports) (o, base_imports @ (types_module :: libs)) @@ -357,7 +366,7 @@ let output1 libpath out_arg filename defs = | Lem_out libs -> output_lem f' libs defs | Coq_out libs -> - output_coq f' libs defs + output_coq !opt_coq_output_dir f' libs defs let output libpath out_arg files = List.iter @@ -374,7 +383,7 @@ let rewrite_step defs (name, rewriter) = begin let filename = f ^ "_rewrite_" ^ string_of_int i ^ "_" ^ name ^ ".sail" in (* output "" Lem_ast_out [filename, defs]; *) - let ((ot,_, _) as ext_ot) = open_output_with_check_unformatted filename in + let ((ot,_,_,_) as ext_ot) = open_output_with_check_unformatted None filename in Pretty_print_sail.pp_defs ot defs; close_output_with_check ext_ot; opt_ddump_rewrite_ast := Some (f, i + 1) diff --git a/src/process_file.mli b/src/process_file.mli index 7b860a73..7371b299 100644 --- a/src/process_file.mli +++ b/src/process_file.mli @@ -71,6 +71,10 @@ val opt_ddump_tc_ast : bool ref val opt_ddump_rewrite_ast : ((string * int) option) ref val opt_dno_cast : bool ref +val opt_lem_output_dir : (string option) ref +val opt_isa_output_dir : (string option) ref +val opt_coq_output_dir : (string option) ref + type out_type = | Lem_out of string list (* If present, the strings are files to open in the lem backend*) | Coq_out of string list (* If present, the strings are files to open in the coq backend*) diff --git a/src/rewriter.ml b/src/rewriter.ml index a70f6fab..21310b91 100644 --- a/src/rewriter.ml +++ b/src/rewriter.ml @@ -364,6 +364,7 @@ let rewrite_def rewriters d = match d with | DEF_val letbind -> DEF_val (rewriters.rewrite_let rewriters letbind) | DEF_pragma (pragma, arg, l) -> DEF_pragma (pragma, arg, l) | DEF_scattered _ -> raise (Reporting.err_unreachable Parse_ast.Unknown __POS__ "DEF_scattered survived to rewritter") + | DEF_measure (id,pat,exp) -> DEF_measure (id,rewriters.rewrite_pat rewriters pat, rewriters.rewrite_exp rewriters exp) let rewrite_defs_base rewriters (Defs defs) = let rec rewrite ds = match ds with diff --git a/src/rewrites.ml b/src/rewrites.ml index 10bc4f44..284c7d67 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -4763,6 +4763,35 @@ let minimise_recursive_functions (Defs defs) = | d -> d in Defs (List.map rewrite_def defs) +let move_termination_measures (Defs defs) = + let scan_for id defs = + let rec aux = function + | [] -> None + | (DEF_measure (id',pat,exp))::t -> + if Id.compare id id' == 0 then Some (pat,exp) else aux t + | (DEF_fundef (FD_aux (FD_function (_,_,_,FCL_aux (FCL_Funcl (id',_),_)::_),_)))::_ + | (DEF_spec (VS_aux (VS_val_spec (_,id',_,_),_))::_) + when Id.compare id id' == 0 -> None + | _::t -> aux t + in aux defs + in + let rec aux acc = function + | [] -> List.rev acc + | (DEF_fundef (FD_aux (FD_function (r,ty,e,fs),(l,f_ann))) as d)::t -> begin + let id = match fs with + | [] -> assert false (* TODO *) + | (FCL_aux (FCL_Funcl (id,_),_))::_ -> id + in + match scan_for id t with + | None -> aux (d::acc) t + | Some (pat,exp) -> + let r = Rec_aux (Rec_measure (pat,exp), Generated l) in + aux (DEF_fundef (FD_aux (FD_function (r,ty,e,fs),(l,f_ann)))::acc) t + end + | (DEF_measure _)::t -> aux acc t + | h::t -> aux (h::acc) t + in Defs (aux [] defs) + (* Make recursive functions with a measure use the measure as an explicit recursion limit, enforced by an assertion. *) let rewrite_explicit_measure (Defs defs) = @@ -4806,8 +4835,8 @@ let rewrite_explicit_measure (Defs defs) = | exception Not_found -> [vs] in (* Add extra argument and assertion to each funcl, and rewrite recursive calls *) - let rewrite_funcl (FCL_aux (FCL_Funcl (id,pexp),ann) as fcl) = - let loc = Parse_ast.Generated (fst ann) in + let rewrite_funcl (FCL_aux (FCL_Funcl (id,pexp),fcl_ann) as fcl) = + let loc = Parse_ast.Generated (fst fcl_ann) in let P_aux (pat,pann),guard,body,ann = destruct_pexp pexp in let extra_pat = P_aux (P_id limit,(loc,empty_tannot)) in let pat = match pat with @@ -4839,7 +4868,7 @@ let rewrite_explicit_measure (Defs defs) = } body in let body = E_aux (E_block [assert_exp; body],(loc,empty_tannot)) in - FCL_aux (FCL_Funcl (rec_id id, construct_pexp (P_aux (pat,pann),guard,body,ann)),ann) + FCL_aux (FCL_Funcl (rec_id id, construct_pexp (P_aux (pat,pann),guard,body,ann)),fcl_ann) in let rewrite_function (FD_aux (FD_function (r,t,e,fcls),ann) as fd) = let loc = Parse_ast.Generated (fst ann) in @@ -4881,6 +4910,7 @@ let rewrite_explicit_measure (Defs defs) = | [wpat] -> wpat | _ -> P_aux (P_tup wpats,(loc,empty_tannot)) in + let measure_exp = E_aux (E_cast (int_typ, measure_exp),(loc,empty_tannot)) in let wbody = E_aux (E_app (rec_id id,wexps@[measure_exp]),(loc,empty_tannot)) in let wrapper = FCL_aux (FCL_Funcl (id, Pat_aux (Pat_exp (wpat,wbody),(loc,empty_tannot))),(loc,empty_tannot)) @@ -4951,6 +4981,10 @@ let if_mono f defs = | [], false -> defs | _, _ -> f defs +(* Also turn mwords stages on when we're just trying out mono *) +let if_mwords f defs = + if !Pretty_print_lem.opt_mwords then f defs else if_mono f defs + let rewrite_defs_lem = [ ("realise_mappings", rewrite_defs_realise_mappings); ("remove_mapping_valspecs", remove_mapping_valspecs); @@ -4961,10 +4995,10 @@ let rewrite_defs_lem = [ ("recheck_defs", if_mono recheck_defs); ("rewrite_toplevel_nexps", if_mono rewrite_toplevel_nexps); ("monomorphise", if_mono monomorphise); - ("recheck_defs", if_mono recheck_defs); - ("add_bitvector_casts", if_mono Monomorphise.add_bitvector_casts); + ("recheck_defs", if_mwords recheck_defs); + ("add_bitvector_casts", if_mwords Monomorphise.add_bitvector_casts); ("rewrite_atoms_to_singletons", if_mono Monomorphise.rewrite_atoms_to_singletons); - ("recheck_defs", if_mono recheck_defs); + ("recheck_defs", if_mwords recheck_defs); ("rewrite_undefined", rewrite_undefined_if_gen false); ("rewrite_defs_vector_string_pats_to_bit_list", rewrite_defs_vector_string_pats_to_bit_list); ("remove_not_pats", rewrite_defs_not_pats); @@ -5029,10 +5063,15 @@ let rewrite_defs_coq = [ ("exp_lift_assign", rewrite_defs_exp_lift_assign); (* ("constraint", rewrite_constraint); *) (* ("remove_assert", rewrite_defs_remove_assert); *) + ("move_termination_measures", move_termination_measures); ("top_sort_defs", top_sort_defs); ("trivial_sizeof", rewrite_trivial_sizeof); ("sizeof", rewrite_sizeof); ("early_return", rewrite_defs_early_return); + (* merge funcls before adding the measure argument so that it doesn't + disappear into an internal pattern match *) + ("merge function clauses", merge_funcls); + ("recheck_defs_without_effects", recheck_defs_without_effects); ("make_cases_exhaustive", MakeExhaustive.rewrite); ("rewrite_explicit_measure", rewrite_explicit_measure); ("recheck_defs_without_effects", recheck_defs_without_effects); @@ -5043,7 +5082,6 @@ let rewrite_defs_coq = [ ("internal_lets", rewrite_defs_internal_lets); ("remove_superfluous_letbinds", rewrite_defs_remove_superfluous_letbinds); ("remove_superfluous_returns", rewrite_defs_remove_superfluous_returns); - ("merge function clauses", merge_funcls); ("recheck_defs", recheck_defs) ] @@ -5095,7 +5133,7 @@ let rewrite_defs_c = [ ("trivial_sizeof", rewrite_trivial_sizeof); ("sizeof", rewrite_sizeof); ("merge_function_clauses", merge_funcls); - ("recheck_defs", recheck_defs) + ("recheck_defs", Optimize.recheck) ] let rewrite_defs_interpreter = [ diff --git a/src/sail.ml b/src/sail.ml index 9f2c7310..c5d69aa5 100644 --- a/src/sail.ml +++ b/src/sail.ml @@ -116,6 +116,9 @@ let options = Arg.align ([ ( "-marshal", Arg.Set opt_marshal_defs, " OCaml-marshal out the rewritten AST to a file"); + ( "-latex_full_valspecs", + Arg.Clear Latex.opt_simple_val, + " print full valspecs in latex output latex"); ( "-c", Arg.Tuple [Arg.Set opt_print_c; Arg.Set Initial_check.opt_undefined_gen], " output a C translated version of the input"); @@ -154,6 +157,12 @@ let options = Arg.align ([ ( "-lem", Arg.Set opt_print_lem, " output a Lem translated version of the input"); + ( "-lem_output_dir", + Arg.String (fun dir -> Process_file.opt_lem_output_dir := Some dir), + " set a custom directory to output generated Lem"); + ( "-isa_output_dir", + Arg.String (fun dir -> Process_file.opt_isa_output_dir := Some dir), + " set a custom directory to output generated Isabelle auxiliary theories"); ( "-lem_lib", Arg.String (fun l -> opt_libs_lem := l::!opt_libs_lem), "<filename> provide additional library to open in Lem output"); @@ -166,6 +175,9 @@ let options = Arg.align ([ ( "-coq", Arg.Set opt_print_coq, " output a Coq translated version of the input"); + ( "-coq_output_dir", + Arg.String (fun dir -> Process_file.opt_coq_output_dir := Some dir), + " set a custom directory to output generated Coq"); ( "-coq_lib", Arg.String (fun l -> opt_libs_coq := l::!opt_libs_coq), "<filename> provide additional library to open in Coq output"); diff --git a/src/sail_lib.ml b/src/sail_lib.ml index c0bf80fa..d1a21b73 100644 --- a/src/sail_lib.ml +++ b/src/sail_lib.ml @@ -508,6 +508,13 @@ let read_ram (addr_size, data_size, hex_ram, addr) = Bytes.iter (fun byte -> vector := (byte_of_int (int_of_char byte)) @ !vector) bytes; !vector +let fast_read_ram (data_size, addr) = + let addr = uint addr in + let bytes = read_mem_bytes addr (Big_int.to_int data_size) in + let vector = ref [] in + Bytes.iter (fun byte -> vector := (byte_of_int (int_of_char byte)) @ !vector) bytes; + !vector + let tag_ram = (ref Mem.empty : (bool Mem.t) ref);; let write_tag_bool (addr, tag) = diff --git a/src/spec_analysis.ml b/src/spec_analysis.ml index 940fbfe5..398f20b5 100644 --- a/src/spec_analysis.ml +++ b/src/spec_analysis.ml @@ -356,9 +356,10 @@ let fv_of_fun consider_var (FD_aux (FD_function(rec_opt,tannot_opt,_,funcls),_) | [] -> failwith "fv_of_fun fell off the end looking for the function name" | FCL_aux(FCL_Funcl(id,_),_)::_ -> string_of_id id in let base_bounds = match rec_opt with - (* Current Sail does not have syntax for declaring functions as recursive, + (* Current Sail does not require syntax for declaring functions as recursive, and type checker does not check whether functions are recursive, so - just always add a self-dependency of functions on themselves + just always add a self-dependency of functions on themselves, as well as + adding dependencies from any specified termination measure further below | Rec_aux(Ast.Rec_rec,_) -> init_env fun_name | _ -> mt*) | _ -> init_env fun_name in @@ -369,6 +370,13 @@ let fv_of_fun consider_var (FD_aux (FD_function(rec_opt,tannot_opt,_,funcls),_) bound, fv_of_typ consider_var bound mt typ | Typ_annot_opt_aux(Typ_annot_opt_none, _) -> base_bounds, mt in + let ns_measure = match rec_opt with + | Rec_aux(Rec_measure (pat,exp),_) -> + let pat_bs,pat_ns = pat_bindings consider_var base_bounds mt pat in + let _, exp_ns,_ = fv_of_exp consider_var pat_bs pat_ns Nameset.empty exp in + exp_ns + | _ -> mt + in let ns = List.fold_right (fun (FCL_aux(FCL_Funcl(_,pexp),_)) ns -> match pexp with | Pat_aux(Pat_exp (pat,exp),_) -> @@ -383,7 +391,7 @@ let fv_of_fun consider_var (FD_aux (FD_function(rec_opt,tannot_opt,_,funcls),_) ) funcls mt in let ns_vs = init_env ("val:" ^ (string_of_id (id_of_fundef fd))) in (* let _ = Printf.eprintf "Function %s uses %s\n" fun_name (set_to_string (Nameset.union ns ns_r)) in *) - init_env fun_name, Nameset.union ns_vs (Nameset.union ns ns_r) + init_env fun_name, Nameset.union ns_vs (Nameset.union ns (Nameset.union ns_r ns_measure)) let fv_of_vspec consider_var (VS_aux(vspec,_)) = match vspec with | VS_val_spec(ts,id,_,_) -> @@ -499,6 +507,7 @@ let fv_of_def consider_var consider_scatter_as_one all_defs = function | DEF_scattered sdef -> fv_of_scattered consider_var consider_scatter_as_one all_defs sdef | DEF_reg_dec rdec -> fv_of_rd consider_var rdec | DEF_pragma _ -> mt,mt + | DEF_measure _ -> mt,mt (* currently removed beforehand *) let group_defs consider_scatter_as_one (Ast.Defs defs) = List.map (fun d -> (fv_of_def false consider_scatter_as_one defs d,d)) defs diff --git a/src/state.ml b/src/state.ml index c9a47b06..fe1cebe7 100644 --- a/src/state.ml +++ b/src/state.ml @@ -135,20 +135,20 @@ let generate_initial_regstate defs = let typ_subst_typquant tq args typ = List.fold_left2 typ_subst_quant_item typ (quant_items tq) args in - let add_typ_init_val vals = function + let add_typ_init_val (defs', vals) = function | TD_enum (id, _, id1 :: _, _) -> (* Choose the first value of an enumeration type as default *) - Bindings.add id (fun _ -> string_of_id id1) vals + (defs', Bindings.add id (fun _ -> string_of_id id1) vals) | TD_variant (id, _, tq, (Tu_aux (Tu_ty_id (typ1, id1), _)) :: _, _) -> (* Choose the first variant of a union type as default *) let init_val args = let typ1 = typ_subst_typquant tq args typ1 in string_of_id id1 ^ " (" ^ lookup_init_val vals typ1 ^ ")" in - Bindings.add id init_val vals + (defs', Bindings.add id init_val vals) | TD_abbrev (id, tq, A_aux (A_typ typ, _)) -> let init_val args = lookup_init_val vals (typ_subst_typquant tq args typ) in - Bindings.add id init_val vals + (defs', Bindings.add id init_val vals) | TD_record (id, _, tq, fields, _) -> let init_val args = let init_field (typ, id) = @@ -157,16 +157,21 @@ let generate_initial_regstate defs = in "struct { " ^ (String.concat ", " (List.map init_field fields)) ^ " }" in - Bindings.add id init_val vals + let def_name = "initial_" ^ string_of_id id in + if quant_items tq = [] && not (is_defined defs def_name) then + (defs' @ ["let " ^ def_name ^ " : " ^ string_of_id id ^ " = " ^ init_val []], + Bindings.add id (fun _ -> def_name) vals) + else (defs', Bindings.add id init_val vals) | TD_bitfield (id, typ, _) -> - Bindings.add id (fun _ -> lookup_init_val vals typ) vals - | _ -> vals + (defs', Bindings.add id (fun _ -> lookup_init_val vals typ) vals) + | _ -> (defs', vals) in - let init_vals = List.fold_left (fun vals def -> match def with - | DEF_type (TD_aux (td, _)) -> add_typ_init_val vals td - | _ -> vals) Bindings.empty defs + let (init_defs, init_vals) = List.fold_left (fun inits def -> match def with + | DEF_type (TD_aux (td, _)) -> add_typ_init_val inits td + | _ -> inits) ([], Bindings.empty) defs in let init_reg (typ, id) = string_of_id id ^ " = " ^ lookup_init_val init_vals typ in + init_defs @ ["let initial_regstate : regstate = struct { " ^ (String.concat ", " (List.map init_reg registers)) ^ " }"] with | _ -> [] (* Do not generate an initial register state if anything goes wrong *) diff --git a/src/type_check.ml b/src/type_check.ml index 63f03c81..63cb4829 100644 --- a/src/type_check.ml +++ b/src/type_check.ml @@ -1226,7 +1226,7 @@ let prove_z3 env (NC_aux (_, l) as nc) = | Constraint.Sat -> typ_debug (lazy "sat"); false | Constraint.Unknown -> typ_debug (lazy "unknown"); false -let solve env (Nexp_aux (_, l) as nexp) = +let solve env (Nexp_aux (_, l) as nexp) = typ_print (lazy (Util.("Solve " |> red |> clear) ^ string_of_list ", " string_of_n_constraint (Env.get_constraints env) ^ " |- " ^ string_of_nexp nexp ^ " = ?")); match nexp with @@ -1238,6 +1238,8 @@ let solve env (Nexp_aux (_, l) as nexp) = let constr = List.fold_left nc_and (nc_eq (nvar (mk_kid "solve#")) nexp) (Env.get_constraints env) in Constraint.solve_z3 l vars constr (mk_kid "solve#") + + let prove env nc = typ_print (lazy (Util.("Prove " |> red |> clear) ^ string_of_list ", " string_of_n_constraint (Env.get_constraints env) ^ " |- " ^ string_of_n_constraint nc)); let (NC_aux (nc_aux, _) as nc) = Env.expand_constraint_synonyms env nc in @@ -4141,6 +4143,22 @@ let check_tannotopt env typq ret_typ = function then () else typ_error l (string_of_bind (typq, ret_typ) ^ " and " ^ string_of_bind (annot_typq, annot_ret_typ) ^ " do not match between function and val spec") +let check_termination_measure env arg_typs pat exp = + let typ = match arg_typs with [x] -> x | _ -> Typ_aux (Typ_tup arg_typs,Unknown) in + let tpat, env = bind_pat_no_guard env (strip_pat pat) typ in + let texp = check_exp env (strip_exp exp) int_typ in + tpat, texp + +let check_termination_measure_decl env (id, pat, exp) = + let quant, typ = Env.get_val_spec id env in + let arg_typs, l = match typ with + | Typ_aux (Typ_fn (arg_typs, _ ,_),l) -> arg_typs,l + | _ -> typ_error (id_loc id) "Function val spec is not a function type" + in + let env = add_typquant l quant env in + let tpat, texp = check_termination_measure env arg_typs pat exp in + DEF_measure (id, tpat, texp) + let check_fundef env (FD_aux (FD_function (recopt, tannotopt, effectopt, funcls), (l, _)) as fd_aux) = let id = match (List.fold_right @@ -4173,9 +4191,7 @@ let check_fundef env (FD_aux (FD_function (recopt, tannotopt, effectopt, funcls) | Rec_aux (Rec_nonrec, l) -> Rec_aux (Rec_nonrec, l) | Rec_aux (Rec_rec, l) -> Rec_aux (Rec_rec, l) | Rec_aux (Rec_measure (measure_p, measure_e), l) -> - let typ = match vtyp_args with [x] -> x | _ -> Typ_aux (Typ_tup vtyp_args,Unknown) in - let tpat, env = bind_pat_no_guard funcl_env (strip_pat measure_p) typ in - let texp = check_exp env (strip_exp measure_e) int_typ in + let tpat, texp = check_termination_measure funcl_env vtyp_args measure_p measure_e in Rec_aux (Rec_measure (tpat, texp), l) in let funcls = List.map (fun funcl -> check_funcl funcl_env funcl typ) funcls in @@ -4241,14 +4257,18 @@ let check_val_spec env (VS_aux (vs, (l, _))) = typ_print (lazy (Util.("Check val spec " |> cyan |> clear) ^ string_of_id id ^ " : " ^ string_of_typschm typschm)); let env = Env.add_extern id exts env in let env = if is_cast then Env.add_cast id env else env in + let typq', typ' = expand_bind_synonyms ts_l env (typq, typ) in + (* !opt_expand_valspec controls whether the actual valspec in + the AST is expanded, the val_spec type stored in the + environment is always expanded and uses typq' and typ' *) let typq, typ = if !opt_expand_valspec then - expand_bind_synonyms ts_l env (typq, typ) + (typq', typ') else (typq, typ) in let vs = VS_val_spec (TypSchm_aux (TypSchm_ts (typq, typ), ts_l), id, exts, is_cast) in - (vs, id, typq, typ, env) + (vs, id, typq', typ', env) in let eff = match typ with @@ -4422,6 +4442,7 @@ and check_def : 'a. Env.t -> 'a def -> (tannot def) list * Env.t = | DEF_reg_dec (DEC_aux (DEC_alias (id, aspec), (l, annot))) -> cd_err () | DEF_reg_dec (DEC_aux (DEC_typ_alias (typ, id, aspec), (l, tannot))) -> cd_err () | DEF_scattered sdef -> check_scattered env sdef + | DEF_measure (id, pat, exp) -> [check_termination_measure_decl env (id, pat, exp)], env and check : 'a. Env.t -> 'a defs -> tannot defs * Env.t = fun env (Defs defs) -> @@ -4432,6 +4453,14 @@ and check : 'a. Env.t -> 'a defs -> tannot defs * Env.t = let (Defs defs, env) = check env (Defs defs) in (Defs (def @ defs)), env +and check_with_envs : 'a. Env.t -> 'a def list -> (tannot def list * Env.t) list = + fun env defs -> + match defs with + | [] -> [] + | def :: defs -> + let def, env = check_def env def in + (def, env) :: check_with_envs env defs + let initial_env = Env.empty |> Env.add_prover prove diff --git a/src/type_check.mli b/src/type_check.mli index 501a0d7d..c17d5e0b 100644 --- a/src/type_check.mli +++ b/src/type_check.mli @@ -202,6 +202,8 @@ module Env : sig val pattern_completeness_ctx : t -> Pattern_completeness.ctx val builtin_typs : typquant Bindings.t + + val get_union_id : id -> t -> typquant * typ end (** Push all the type variables and constraints from a typquant into @@ -295,6 +297,8 @@ val infer_exp : Env.t -> unit exp -> tannot exp val infer_pat : Env.t -> unit pat -> tannot pat * Env.t * unit exp list +val infer_lexp : Env.t -> unit lexp -> tannot lexp + val check_case : Env.t -> typ -> unit pexp -> typ -> tannot pexp val check_fundef : Env.t -> 'a fundef -> tannot def list * Env.t @@ -413,5 +417,9 @@ Some invariants that will hold of a fully checked AST are: Type_error.check *) val check : Env.t -> 'a defs -> tannot defs * Env.t +(** The same as [check], but exposes the intermediate type-checking + environments so we don't have to always re-check the entire AST *) +val check_with_envs : Env.t -> 'a def list -> (tannot def list * Env.t) list + (** The initial type checking environment *) val initial_env : Env.t |
