diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/c_backend.ml | 185 | ||||
| -rw-r--r-- | src/constant_fold.ml | 101 | ||||
| -rw-r--r-- | src/gen_lib/sail2_operators_bitlists.lem | 68 | ||||
| -rw-r--r-- | src/gen_lib/sail2_operators_mwords.lem | 74 | ||||
| -rw-r--r-- | src/gen_lib/sail2_prompt.lem | 32 | ||||
| -rw-r--r-- | src/gen_lib/sail2_prompt_monad.lem | 1 | ||||
| -rw-r--r-- | src/gen_lib/sail2_state.lem | 32 | ||||
| -rw-r--r-- | src/gen_lib/sail2_state_monad.lem | 18 | ||||
| -rw-r--r-- | src/gen_lib/sail2_string.lem | 2 | ||||
| -rw-r--r-- | src/gen_lib/sail2_values.lem | 4 | ||||
| -rw-r--r-- | src/interpreter.ml | 2 | ||||
| -rw-r--r-- | src/isail.ml | 25 | ||||
| -rw-r--r-- | src/lem_interp/sail2_instr_kinds.lem | 18 | ||||
| -rw-r--r-- | src/ocaml_backend.ml | 15 | ||||
| -rw-r--r-- | src/pretty_print_coq.ml | 361 | ||||
| -rw-r--r-- | src/pretty_print_lem.ml | 16 | ||||
| -rw-r--r-- | src/process_file.ml | 4 | ||||
| -rw-r--r-- | src/rewrites.ml | 352 | ||||
| -rw-r--r-- | src/rewrites.mli | 7 | ||||
| -rw-r--r-- | src/sail.ml | 6 | ||||
| -rw-r--r-- | src/sail_lib.ml | 4 | ||||
| -rw-r--r-- | src/state.ml | 14 | ||||
| -rw-r--r-- | src/test/lib/run_test_interp.ml | 51 | ||||
| -rw-r--r-- | src/type_check.ml | 26 | ||||
| -rw-r--r-- | src/type_check.mli | 8 | ||||
| -rw-r--r-- | src/value.ml | 5 | ||||
| -rw-r--r-- | src/value2.lem | 24 |
27 files changed, 1036 insertions, 419 deletions
diff --git a/src/c_backend.ml b/src/c_backend.ml index 85c2eeb3..433e3d85 100644 --- a/src/c_backend.ml +++ b/src/c_backend.ml @@ -63,6 +63,7 @@ module Big_int = Nat_big_num let c_verbosity = ref 0 let opt_ddump_flow_graphs = ref false let opt_trace = ref false +let opt_static = ref false (* Optimization flags *) let optimize_primops = ref false @@ -222,7 +223,7 @@ let rec is_stack_ctyp ctyp = match ctyp with let is_stack_typ ctx typ = is_stack_ctyp (ctyp_of_typ ctx typ) let ctor_bindings = List.fold_left (fun map (id, ctyp) -> Bindings.add id ctyp map) Bindings.empty - + (**************************************************************************) (* 3. Optimization of primitives and literals *) (**************************************************************************) @@ -778,7 +779,7 @@ let rec compile_aval ctx = function in [idecl vector_ctyp gs; iextern (CL_id (gs, vector_ctyp)) (mk_id "internal_vector_init") [(F_lit (V_int (Big_int.of_int len)), CT_int64)]] - @ List.concat (List.mapi aval_set avals), + @ List.concat (List.mapi aval_set (if direction then List.rev avals else avals)), (F_id gs, vector_ctyp), [iclear vector_ctyp gs] @@ -927,7 +928,10 @@ let label str = let pointer_assign ctyp1 ctyp2 = match ctyp1 with | CT_ref ctyp1 when ctyp_equal ctyp1 ctyp2 -> true - | CT_ref ctyp1 -> c_error "Incompatible type in pointer assignment" + | CT_ref ctyp1 -> + c_error (Printf.sprintf "Incompatible type in pointer assignment between %s and %s" + (string_of_ctyp ctyp1) + (string_of_ctyp ctyp2)) | _ -> false let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) = @@ -1050,11 +1054,16 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) = (* FIXME: AE_record_update could be AV_record_update - would reduce some copying. *) | AE_record_update (aval, fields, typ) -> let ctyp = ctyp_of_typ ctx typ in + let ctors = match ctyp with + | CT_struct (_, ctors) -> List.fold_left (fun m (k, v) -> Bindings.add k v m) Bindings.empty ctors + | _ -> c_error "Cannot perform record update for non-record type" + in let gs = gensym () in let compile_fields (id, aval) = let field_setup, cval, field_cleanup = compile_aval ctx aval in field_setup - @ [icopy (CL_field (gs, string_of_id id, cval_ctyp cval)) cval] + @ [icomment (string_of_ctyp ctyp)] + @ [icopy (CL_field (gs, string_of_id id, Bindings.find id ctors)) cval] @ field_cleanup in let setup, cval, cleanup = compile_aval ctx aval in @@ -1227,6 +1236,7 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) = in let loop_start_label = label "for_start_" in + let loop_end_label = label "for_end_" in let body_setup, body_call, body_cleanup = compile_aexp ctx body in let body_gs = gensym () in @@ -1235,17 +1245,16 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) = @ variable_init step_gs step_setup step_call step_cleanup @ [iblock ([idecl CT_int64 loop_var; icopy (CL_id (loop_var, CT_int64)) (F_id from_gs, CT_int64); - ilabel loop_start_label; idecl CT_unit body_gs; - iblock (body_setup + iblock ([ilabel loop_start_label] + @ [ijump (F_op (F_id loop_var, (if is_inc then ">" else "<"), F_id to_gs), CT_bool) loop_end_label] + @ body_setup @ [body_call (CL_id (body_gs, CT_unit))] @ body_cleanup - @ if is_inc then - [icopy (CL_id (loop_var, CT_int64)) (F_op (F_id loop_var, "+", F_id step_gs), CT_int64); - ijump (F_op (F_id loop_var, "<=", F_id to_gs), CT_bool) loop_start_label] - else - [icopy (CL_id (loop_var, CT_int64)) (F_op (F_id loop_var, "-", F_id step_gs), CT_int64); - ijump (F_op (F_id loop_var, ">=", F_id to_gs), CT_bool) loop_start_label])])], + @ [icopy (CL_id (loop_var, CT_int64)) + (F_op (F_id loop_var, (if is_inc then "+" else "-"), F_id step_gs), CT_int64)] + @ [igoto loop_start_label]); + ilabel loop_end_label])], (fun clexp -> icopy clexp unit_fragment), [] @@ -1905,7 +1914,7 @@ let rec sgen_ctyp_name = function | CT_string -> "sail_string" | CT_real -> "real" | CT_ref ctyp -> "ref_" ^ sgen_ctyp_name ctyp - + let sgen_cval_param (frag, ctyp) = match ctyp with | CT_bits direction -> @@ -1949,10 +1958,17 @@ let rec codegen_instr fid ctx (I_aux (instr, (_, l))) = string (Printf.sprintf " %s = %s;" (sgen_clexp_pure clexp) (sgen_cval cval)) else string (Printf.sprintf " COPY(%s)(%s, %s);" (sgen_ctyp_name lctyp) (sgen_clexp clexp) (sgen_cval cval)) + else if pointer_assign lctyp rctyp then + let lctyp = match lctyp with + | CT_ref lctyp -> lctyp + | _ -> assert false + in + if is_stack_ctyp lctyp then + string (Printf.sprintf " *(%s) = %s;" (sgen_clexp_pure clexp) (sgen_cval cval)) + else + string (Printf.sprintf " COPY(%s)(*(%s), %s);" (sgen_ctyp_name lctyp) (sgen_clexp clexp) (sgen_cval cval)) else - if pointer_assign lctyp rctyp then - string (Printf.sprintf " %s = &%s;" (sgen_clexp_pure clexp) (sgen_cval cval)) - else if is_stack_ctyp lctyp then + if is_stack_ctyp lctyp then string (Printf.sprintf " %s = CONVERT_OF(%s, %s)(%s);" (sgen_clexp_pure clexp) (sgen_ctyp_name lctyp) (sgen_ctyp_name rctyp) (sgen_cval cval)) else @@ -2125,7 +2141,7 @@ let codegen_type_def ctx = function | CTD_enum (id, ((first_id :: _) as ids)) -> let codegen_eq = let name = sgen_id id in - string (Printf.sprintf "bool eq_%s(enum %s op1, enum %s op2) { return op1 == op2; }" name name name) + string (Printf.sprintf "static bool eq_%s(enum %s op1, enum %s op2) { return op1 == op2; }" name name name) in let codegen_undefined = let name = sgen_id id in @@ -2150,7 +2166,7 @@ let codegen_type_def ctx = function string (Printf.sprintf "COPY(%s)(&rop->%s, op.%s);" (sgen_ctyp_name ctyp) (sgen_id id) (sgen_id id)) in let codegen_setter id ctors = - string (let n = sgen_id id in Printf.sprintf "void COPY(%s)(struct %s *rop, const struct %s op)" n n n) ^^ space + string (let n = sgen_id id in Printf.sprintf "static void COPY(%s)(struct %s *rop, const struct %s op)" n n n) ^^ space ^^ surround 2 0 lbrace (separate_map hardline codegen_set (Bindings.bindings ctors)) rbrace @@ -2162,16 +2178,23 @@ let codegen_type_def ctx = function else [] in let codegen_init f id ctors = - string (let n = sgen_id id in Printf.sprintf "void %s(%s)(struct %s *op)" f n n) ^^ space + string (let n = sgen_id id in Printf.sprintf "static void %s(%s)(struct %s *op)" f n n) ^^ space ^^ surround 2 0 lbrace (separate hardline (Bindings.bindings ctors |> List.map (codegen_field_init f) |> List.concat)) rbrace in - (* let codegen_eq = - string (Printf.sprintf "bool eq_%s(struct %s op1, struct %s op2) { return true; }" (sgen_id id) (sgen_id id) (sgen_id id)) + let codegen_eq_test (id, ctyp) = + string (Printf.sprintf "EQUAL(%s)(op1.%s, op2.%s)" (sgen_ctyp_name ctyp) (sgen_id id) (sgen_id id)) + in + string (Printf.sprintf "static bool EQUAL(%s)(struct %s op1, struct %s op2)" (sgen_id id) (sgen_id id) (sgen_id id)) + ^^ space + ^^ surround 2 0 lbrace + (string "return" ^^ space + ^^ separate_map (string " && ") codegen_eq_test ctors + ^^ string ";") + rbrace in - *) (* Generate the struct and add the generated functions *) let codegen_ctor (id, ctyp) = string (sgen_ctyp ctyp) ^^ space ^^ codegen_id id @@ -2183,18 +2206,16 @@ let codegen_type_def ctx = function rbrace ^^ semi ^^ twice hardline ^^ codegen_setter id (ctor_bindings ctors) - ^^ if not (is_stack_ctyp struct_ctyp) then - twice hardline - ^^ codegen_init "CREATE" id (ctor_bindings ctors) - ^^ twice hardline - ^^ codegen_init "RECREATE" id (ctor_bindings ctors) - ^^ twice hardline - ^^ codegen_init "KILL" id (ctor_bindings ctors) - else empty - (* + ^^ (if not (is_stack_ctyp struct_ctyp) then + twice hardline + ^^ codegen_init "CREATE" id (ctor_bindings ctors) + ^^ twice hardline + ^^ codegen_init "RECREATE" id (ctor_bindings ctors) + ^^ twice hardline + ^^ codegen_init "KILL" id (ctor_bindings ctors) + else empty) ^^ twice hardline ^^ codegen_eq - *) | CTD_variant (id, tus) -> let codegen_tu (ctor_id, ctyp) = @@ -2215,7 +2236,7 @@ let codegen_type_def ctx = function let codegen_init = let n = sgen_id id in let ctor_id, ctyp = List.hd tus in - string (Printf.sprintf "void CREATE(%s)(struct %s *op)" n n) + string (Printf.sprintf "static void CREATE(%s)(struct %s *op)" n n) ^^ hardline ^^ surround 2 0 lbrace (string (Printf.sprintf "op->kind = Kind_%s;" (sgen_id ctor_id)) ^^ hardline @@ -2226,7 +2247,7 @@ let codegen_type_def ctx = function in let codegen_reinit = let n = sgen_id id in - string (Printf.sprintf "void RECREATE(%s)(struct %s *op) {}" n n) + string (Printf.sprintf "static void RECREATE(%s)(struct %s *op) {}" n n) in let clear_field v ctor_id ctyp = if is_stack_ctyp ctyp then @@ -2236,7 +2257,7 @@ let codegen_type_def ctx = function in let codegen_clear = let n = sgen_id id in - string (Printf.sprintf "void KILL(%s)(struct %s *op)" n n) ^^ hardline + string (Printf.sprintf "static void KILL(%s)(struct %s *op)" n n) ^^ hardline ^^ surround 2 0 lbrace (each_ctor "op->" (clear_field "op") tus ^^ semi) rbrace @@ -2260,7 +2281,7 @@ let codegen_type_def ctx = function ^^ separate hardline (List.mapi tuple_set ctyps) ^^ hardline | ctyp -> Printf.sprintf "%s op" (sgen_ctyp ctyp), empty in - string (Printf.sprintf "void %s(struct %s *rop, %s)" (sgen_id ctor_id) (sgen_id id) ctor_args) ^^ hardline + string (Printf.sprintf "static void %s(struct %s *rop, %s)" (sgen_id ctor_id) (sgen_id id) ctor_args) ^^ hardline ^^ surround 2 0 lbrace (tuple ^^ each_ctor "rop->" (clear_field "rop") tus ^^ hardline @@ -2281,7 +2302,7 @@ let codegen_type_def ctx = function string (Printf.sprintf "CREATE(%s)(&rop->%s);" (sgen_ctyp_name ctyp) (sgen_id ctor_id)) ^^ string (Printf.sprintf " COPY(%s)(&rop->%s, op.%s);" (sgen_ctyp_name ctyp) (sgen_id ctor_id) (sgen_id ctor_id)) in - string (Printf.sprintf "void COPY(%s)(struct %s *rop, struct %s op)" n n n) ^^ hardline + string (Printf.sprintf "static void COPY(%s)(struct %s *rop, struct %s op)" n n n) ^^ hardline ^^ surround 2 0 lbrace (each_ctor "rop->" (clear_field "rop") tus ^^ semi ^^ hardline @@ -2290,6 +2311,21 @@ let codegen_type_def ctx = function ^^ each_ctor "op." set_field tus) rbrace in + let codegen_eq = + let codegen_eq_test ctor_id ctyp = + string (Printf.sprintf "return EQUAL(%s)(op1.%s, op2.%s);" (sgen_ctyp_name ctyp) (sgen_id ctor_id) (sgen_id ctor_id)) + in + let rec codegen_eq_tests = function + | [] -> string "return false;" + | (ctor_id, ctyp) :: ctors -> + string (Printf.sprintf "if (op1.kind == Kind_%s && op2.kind == Kind_%s) " (sgen_id ctor_id) (sgen_id ctor_id)) ^^ lbrace ^^ hardline + ^^ jump 0 2 (codegen_eq_test ctor_id ctyp) + ^^ hardline ^^ rbrace ^^ string " else " ^^ codegen_eq_tests ctors + in + let n = sgen_id id in + string (Printf.sprintf "static bool EQUAL(%s)(struct %s op1, struct %s op2) " n n n) + ^^ surround 2 0 lbrace (codegen_eq_tests tus) rbrace + in string (Printf.sprintf "// union %s" (string_of_id id)) ^^ hardline ^^ string "enum" ^^ space ^^ string ("kind_" ^ sgen_id id) ^^ space @@ -2317,6 +2353,8 @@ let codegen_type_def ctx = function ^^ twice hardline ^^ codegen_setter ^^ twice hardline + ^^ codegen_eq + ^^ twice hardline ^^ separate_map (twice hardline) codegen_ctor tus (* If this is the exception type, then we setup up some global variables to deal with exceptions. *) ^^ if string_of_id id = "exception" then @@ -2363,10 +2401,10 @@ let codegen_node id ctyp = ^^ string (Printf.sprintf "typedef struct node_%s *%s;" (sgen_id id) (sgen_id id)) let codegen_list_init id = - string (Printf.sprintf "void CREATE(%s)(%s *rop) { *rop = NULL; }" (sgen_id id) (sgen_id id)) + string (Printf.sprintf "static void CREATE(%s)(%s *rop) { *rop = NULL; }" (sgen_id id) (sgen_id id)) let codegen_list_clear id ctyp = - string (Printf.sprintf "void KILL(%s)(%s *rop) {\n" (sgen_id id) (sgen_id id)) + string (Printf.sprintf "static void KILL(%s)(%s *rop) {\n" (sgen_id id) (sgen_id id)) ^^ string (Printf.sprintf " if (*rop == NULL) return;") ^^ (if is_stack_ctyp ctyp then empty else string (Printf.sprintf " KILL(%s)(&(*rop)->hd);\n" (sgen_ctyp_name ctyp))) @@ -2375,7 +2413,7 @@ let codegen_list_clear id ctyp = ^^ string "}" let codegen_list_set id ctyp = - string (Printf.sprintf "void internal_set_%s(%s *rop, const %s op) {\n" (sgen_id id) (sgen_id id) (sgen_id id)) + string (Printf.sprintf "static void internal_set_%s(%s *rop, const %s op) {\n" (sgen_id id) (sgen_id id) (sgen_id id)) ^^ string " if (op == NULL) { *rop = NULL; return; };\n" ^^ string (Printf.sprintf " *rop = malloc(sizeof(struct node_%s));\n" (sgen_id id)) ^^ (if is_stack_ctyp ctyp then @@ -2386,14 +2424,14 @@ let codegen_list_set id ctyp = ^^ string (Printf.sprintf " internal_set_%s(&(*rop)->tl, op->tl);\n" (sgen_id id)) ^^ string "}" ^^ twice hardline - ^^ string (Printf.sprintf "void COPY(%s)(%s *rop, const %s op) {\n" (sgen_id id) (sgen_id id) (sgen_id id)) + ^^ string (Printf.sprintf "static void COPY(%s)(%s *rop, const %s op) {\n" (sgen_id id) (sgen_id id) (sgen_id id)) ^^ string (Printf.sprintf " KILL(%s)(rop);\n" (sgen_id id)) ^^ string (Printf.sprintf " internal_set_%s(rop, op);\n" (sgen_id id)) ^^ string "}" let codegen_cons id ctyp = let cons_id = mk_id ("cons#" ^ string_of_ctyp ctyp) in - string (Printf.sprintf "void %s(%s *rop, const %s x, const %s xs) {\n" (sgen_id cons_id) (sgen_id id) (sgen_ctyp ctyp) (sgen_id id)) + string (Printf.sprintf "static void %s(%s *rop, const %s x, const %s xs) {\n" (sgen_id cons_id) (sgen_id id) (sgen_ctyp ctyp) (sgen_id id)) ^^ string (Printf.sprintf " *rop = malloc(sizeof(struct node_%s));\n" (sgen_id id)) ^^ (if is_stack_ctyp ctyp then string " (*rop)->hd = x;\n" @@ -2405,9 +2443,9 @@ let codegen_cons id ctyp = let codegen_pick id ctyp = if is_stack_ctyp ctyp then - string (Printf.sprintf "%s pick_%s(const %s xs) { return xs->hd; }" (sgen_ctyp ctyp) (sgen_ctyp_name ctyp) (sgen_id id)) + string (Printf.sprintf "static %s pick_%s(const %s xs) { return xs->hd; }" (sgen_ctyp ctyp) (sgen_ctyp_name ctyp) (sgen_id id)) else - string (Printf.sprintf "void pick_%s(%s *x, const %s xs) { COPY(%s)(x, xs->hd); }" (sgen_ctyp_name ctyp) (sgen_ctyp ctyp) (sgen_id id) (sgen_ctyp_name ctyp)) + string (Printf.sprintf "static void pick_%s(%s *x, const %s xs) { COPY(%s)(x, xs->hd); }" (sgen_ctyp_name ctyp) (sgen_ctyp ctyp) (sgen_id id) (sgen_ctyp_name ctyp)) let codegen_list ctx ctyp = let id = mk_id (string_of_ctyp (CT_list ctyp)) in @@ -2435,10 +2473,10 @@ let codegen_vector ctx (direction, ctyp) = ^^ string (Printf.sprintf "typedef struct %s %s;" (sgen_id id) (sgen_id id)) in let vector_init = - string (Printf.sprintf "void CREATE(%s)(%s *rop) {\n rop->len = 0;\n rop->data = NULL;\n}" (sgen_id id) (sgen_id id)) + string (Printf.sprintf "static void CREATE(%s)(%s *rop) {\n rop->len = 0;\n rop->data = NULL;\n}" (sgen_id id) (sgen_id id)) in let vector_set = - string (Printf.sprintf "void COPY(%s)(%s *rop, %s op) {\n" (sgen_id id) (sgen_id id) (sgen_id id)) + string (Printf.sprintf "static void COPY(%s)(%s *rop, %s op) {\n" (sgen_id id) (sgen_id id) (sgen_id id)) ^^ string (Printf.sprintf " KILL(%s)(rop);\n" (sgen_id id)) ^^ string " rop->len = op.len;\n" ^^ string (Printf.sprintf " rop->data = malloc((rop->len) * sizeof(%s));\n" (sgen_ctyp ctyp)) @@ -2451,7 +2489,7 @@ let codegen_vector ctx (direction, ctyp) = ^^ string "}" in let vector_clear = - string (Printf.sprintf "void KILL(%s)(%s *rop) {\n" (sgen_id id) (sgen_id id)) + string (Printf.sprintf "static void KILL(%s)(%s *rop) {\n" (sgen_id id) (sgen_id id)) ^^ (if is_stack_ctyp ctyp then empty else string " for (int i = 0; i < (rop->len); i++) {\n" @@ -2461,7 +2499,7 @@ let codegen_vector ctx (direction, ctyp) = ^^ string "}" in let vector_update = - string (Printf.sprintf "void vector_update_%s(%s *rop, %s op, mpz_t n, %s elem) {\n" (sgen_id id) (sgen_id id) (sgen_id id) (sgen_ctyp ctyp)) + string (Printf.sprintf "static void vector_update_%s(%s *rop, %s op, mpz_t n, %s elem) {\n" (sgen_id id) (sgen_id id) (sgen_id id) (sgen_ctyp ctyp)) ^^ string " int m = mpz_get_ui(n);\n" ^^ string " if (rop->data == op.data) {\n" ^^ string (if is_stack_ctyp ctyp then @@ -2478,7 +2516,7 @@ let codegen_vector ctx (direction, ctyp) = ^^ string "}" in let internal_vector_update = - string (Printf.sprintf "void internal_vector_update_%s(%s *rop, %s op, const int64_t n, %s elem) {\n" (sgen_id id) (sgen_id id) (sgen_id id) (sgen_ctyp ctyp)) + string (Printf.sprintf "static void internal_vector_update_%s(%s *rop, %s op, const int64_t n, %s elem) {\n" (sgen_id id) (sgen_id id) (sgen_id id) (sgen_ctyp ctyp)) ^^ string (if is_stack_ctyp ctyp then " rop->data[n] = elem;\n" else @@ -2487,24 +2525,24 @@ let codegen_vector ctx (direction, ctyp) = in let vector_access = if is_stack_ctyp ctyp then - string (Printf.sprintf "%s vector_access_%s(%s op, mpz_t n) {\n" (sgen_ctyp ctyp) (sgen_id id) (sgen_id id)) + string (Printf.sprintf "static %s vector_access_%s(%s op, mpz_t n) {\n" (sgen_ctyp ctyp) (sgen_id id) (sgen_id id)) ^^ string " int m = mpz_get_ui(n);\n" ^^ string " return op.data[m];\n" ^^ string "}" else - string (Printf.sprintf "void vector_access_%s(%s *rop, %s op, mpz_t n) {\n" (sgen_id id) (sgen_ctyp ctyp) (sgen_id id)) + string (Printf.sprintf "static void vector_access_%s(%s *rop, %s op, mpz_t n) {\n" (sgen_id id) (sgen_ctyp ctyp) (sgen_id id)) ^^ string " int m = mpz_get_ui(n);\n" ^^ string (Printf.sprintf " COPY(%s)(rop, op.data[m]);\n" (sgen_ctyp_name ctyp)) ^^ string "}" in let internal_vector_init = - string (Printf.sprintf "void internal_vector_init_%s(%s *rop, const int64_t len) {\n" (sgen_id id) (sgen_id id)) + string (Printf.sprintf "static void internal_vector_init_%s(%s *rop, const int64_t len) {\n" (sgen_id id) (sgen_id id)) ^^ string " rop->len = len;\n" ^^ string (Printf.sprintf " rop->data = malloc(len * sizeof(%s));\n" (sgen_ctyp ctyp)) ^^ string "}" in let vector_undefined = - string (Printf.sprintf "void undefined_vector_%s(%s *rop, mpz_t len, %s elem) {\n" (sgen_id id) (sgen_id id) (sgen_ctyp ctyp)) + string (Printf.sprintf "static void undefined_vector_%s(%s *rop, mpz_t len, %s elem) {\n" (sgen_id id) (sgen_id id) (sgen_ctyp ctyp)) ^^ string (Printf.sprintf " rop->len = mpz_get_ui(len);\n") ^^ string (Printf.sprintf " rop->data = malloc((rop->len) * sizeof(%s));\n" (sgen_ctyp ctyp)) ^^ string " for (int i = 0; i < (rop->len); i++) {\n" @@ -2549,12 +2587,13 @@ let codegen_def' ctx = function ^^ string (Printf.sprintf "%s %s;" (sgen_ctyp ctyp) (sgen_id id)) | CDEF_spec (id, arg_ctyps, ret_ctyp) -> + let static = if !opt_static then "static " else "" in if Env.is_extern id ctx.tc_env "c" then empty else if is_stack_ctyp ret_ctyp then - string (Printf.sprintf "%s %s(%s);" (sgen_ctyp ret_ctyp) (sgen_id id) (Util.string_of_list ", " sgen_ctyp arg_ctyps)) + string (Printf.sprintf "%s%s %s(%s);" static (sgen_ctyp ret_ctyp) (sgen_id id) (Util.string_of_list ", " sgen_ctyp arg_ctyps)) else - string (Printf.sprintf "void %s(%s *rop, %s);" (sgen_id id) (sgen_ctyp ret_ctyp) (Util.string_of_list ", " sgen_ctyp arg_ctyps)) + string (Printf.sprintf "%svoid %s(%s *rop, %s);" static (sgen_id id) (sgen_ctyp ret_ctyp) (Util.string_of_list ", " sgen_ctyp arg_ctyps)) | CDEF_fundef (id, ret_arg, args, instrs) as def -> if !opt_ddump_flow_graphs then make_dot id (instrs_graph instrs) else (); @@ -2578,10 +2617,12 @@ let codegen_def' ctx = function match ret_arg with | None -> assert (is_stack_ctyp ret_ctyp); - string (sgen_ctyp ret_ctyp) ^^ space ^^ codegen_id id ^^ parens (string args) ^^ hardline + (if !opt_static then string "static " else empty) + ^^ string (sgen_ctyp ret_ctyp) ^^ space ^^ codegen_id id ^^ parens (string args) ^^ hardline | Some gs -> assert (not (is_stack_ctyp ret_ctyp)); - string "void" ^^ space ^^ codegen_id id + (if !opt_static then string "static " else empty) + ^^ string "void" ^^ space ^^ codegen_id id ^^ parens (string (sgen_ctyp ret_ctyp ^ " *" ^ sgen_id gs ^ ", ") ^^ string args) ^^ hardline in @@ -2594,7 +2635,8 @@ let codegen_def' ctx = function codegen_type_def ctx ctype_def | CDEF_startup (id, instrs) -> - let startup_header = string (Printf.sprintf "void startup_%s(void)" (sgen_id id)) in + let static = if !opt_static then "static " else "" in + let startup_header = string (Printf.sprintf "%svoid startup_%s(void)" static (sgen_id id)) in separate_map hardline codegen_decl instrs ^^ twice hardline ^^ startup_header ^^ hardline @@ -2603,7 +2645,8 @@ let codegen_def' ctx = function ^^ string "}" | CDEF_finish (id, instrs) -> - let finish_header = string (Printf.sprintf "void finish_%s(void)" (sgen_id id)) in + let static = if !opt_static then "static " else "" in + let finish_header = string (Printf.sprintf "%svoid finish_%s(void)" static (sgen_id id)) in separate_map hardline codegen_decl (List.filter is_decl instrs) ^^ twice hardline ^^ finish_header ^^ hardline @@ -2620,12 +2663,12 @@ let codegen_def' ctx = function List.concat (List.map (fun (id, ctyp) -> [iclear ctyp id]) bindings) in separate_map hardline (fun (id, ctyp) -> string (Printf.sprintf "%s %s;" (sgen_ctyp ctyp) (sgen_id id))) bindings - ^^ hardline ^^ string (Printf.sprintf "void create_letbind_%d(void) " number) + ^^ hardline ^^ string (Printf.sprintf "static void create_letbind_%d(void) " number) ^^ string "{" ^^ jump 0 2 (separate_map hardline codegen_alloc setup) ^^ hardline ^^ jump 0 2 (separate_map hardline (codegen_instr (mk_id "let") { ctx with no_raw = true }) instrs) ^^ hardline ^^ string "}" - ^^ hardline ^^ string (Printf.sprintf "void kill_letbind_%d(void) " number) + ^^ hardline ^^ string (Printf.sprintf "static void kill_letbind_%d(void) " number) ^^ string "{" ^^ jump 0 2 (separate_map hardline (codegen_instr (mk_id "let") ctx) cleanup) ^^ hardline ^^ string "}" @@ -2742,6 +2785,26 @@ let rec get_recursive_functions (Defs defs) = match defs with | DEF_internal_mutrec fundefs :: defs -> IdSet.union (List.map id_of_fundef fundefs |> IdSet.of_list) (get_recursive_functions (Defs defs)) + + | (DEF_fundef fdef as def) :: defs -> + let open Rewriter in + let ids = ref IdSet.empty in + let collect_funcalls e_aux annot = + match e_aux with + | E_app (id, args) -> (ids := IdSet.add id !ids; E_aux (e_aux, annot)) + | _ -> E_aux (e_aux, annot) + in + let map_exp = { + id_exp_alg with + e_aux = (fun (e_aux, annot) -> collect_funcalls e_aux annot) + } in + let map_defs = { rewriters_base with rewrite_exp = (fun _ -> fold_exp map_exp) } in + let _ = rewrite_def map_defs def in + if IdSet.mem (id_of_fundef fdef) !ids then + IdSet.add (id_of_fundef fdef) (get_recursive_functions (Defs defs)) + else + get_recursive_functions (Defs defs) + | _ :: defs -> get_recursive_functions (Defs defs) | [] -> IdSet.empty diff --git a/src/constant_fold.ml b/src/constant_fold.ml index 7a35226e..45d3efe0 100644 --- a/src/constant_fold.ml +++ b/src/constant_fold.ml @@ -59,10 +59,25 @@ module StringMap = Map.Make(String);; false = no folding, true = perform constant folding. *) let optimize_constant_fold = ref false -let exp_of_value = +let rec fexp_of_ctor (field, value) = + FE_aux (FE_Fexp (mk_id field, exp_of_value value), no_annot) + +and exp_of_value = let open Value in function | V_int n -> mk_lit_exp (L_num n) + | V_bit Sail_lib.B0 -> mk_lit_exp L_zero + | V_bit Sail_lib.B1 -> mk_lit_exp L_one + | V_bool true -> mk_lit_exp L_true + | V_bool false -> mk_lit_exp L_false + | V_string str -> mk_lit_exp (L_string str) + | V_record ctors -> + mk_exp (E_record (FES_aux (FES_Fexps (List.map fexp_of_ctor (StringMap.bindings ctors), false), no_annot))) + | V_vector vs -> + mk_exp (E_vector (List.map exp_of_value vs)) + | V_tuple vs -> + mk_exp (E_tuple (List.map exp_of_value vs)) + | V_unit -> mk_lit_exp L_unit | _ -> failwith "No expression for value" (* We want to avoid evaluating things like print statements at compile @@ -85,15 +100,23 @@ let safe_primops = "Elf_loader.elf_tohost" ] -let is_literal = function - | E_aux (E_lit _, _) -> true +let rec is_constant (E_aux (e_aux, _)) = + match e_aux with + | E_lit _ -> true + | E_vector exps -> List.for_all is_constant exps + | E_record (FES_aux (FES_Fexps (fexps, _), _)) -> List.for_all is_constant_fexp fexps + | E_cast (_, exp) -> is_constant exp + | E_tuple exps -> List.for_all is_constant exps | _ -> false +and is_constant_fexp (FE_aux (FE_Fexp (_, exp), _)) = is_constant exp + (* Wrapper around interpreter that repeatedly steps until done. *) let rec run ast frame = match frame with | Interpreter.Done (state, v) -> v - | Interpreter.Step _ -> + | Interpreter.Step (lazy_str, _, _, _) -> + prerr_endline (Lazy.force lazy_str); run ast (Interpreter.eval_frame ast frame) | Interpreter.Break frame -> run ast (Interpreter.eval_frame ast frame) @@ -115,35 +138,57 @@ let rec run ast frame = - Throws an exception that isn't caught. *) -let rewrite_constant_function_calls' ast = +let rec rewrite_constant_function_calls' ast = + let rewrite_count = ref 0 in + let ok () = incr rewrite_count in + let not_ok () = decr rewrite_count in + let lstate, gstate = Interpreter.initial_state ast safe_primops in let gstate = { gstate with Interpreter.allow_registers = false } in + let evaluate e_aux annot = + let initial_monad = Interpreter.return (E_aux (e_aux, annot)) in + try + begin + let v = run ast (Interpreter.Step (lazy "", (lstate, gstate), initial_monad, [])) in + let exp = exp_of_value v in + try (ok (); Type_check.check_exp (env_of_annot annot) exp (typ_of_annot annot)) with + | Type_error (l, err) -> + (* A type error here would be unexpected, so don't ignore it! *) + Util.warn ("Type error when folding constants in " + ^ string_of_exp (E_aux (e_aux, annot)) + ^ "\n" ^ Type_error.string_of_type_error err); + not_ok (); + E_aux (e_aux, annot) + end + with + (* Otherwise if anything goes wrong when trying to constant + fold, just continue without optimising. *) + | _ -> E_aux (e_aux, annot) + in + let rw_funcall e_aux annot = match e_aux with - | E_app (id, args) when List.for_all is_literal args -> - begin - let initial_monad = Interpreter.return (E_aux (e_aux, annot)) in - try - begin - let v = run ast (Interpreter.Step (lazy "", (lstate, gstate), initial_monad, [])) in - let exp = exp_of_value v in - try Type_check.check_exp (env_of_annot annot) exp (typ_of_annot annot) with - | Type_error (l, err) -> - (* A type error here would be unexpected, so don't ignore it! *) - Util.warn ("Type error when folding constants in " - ^ string_of_exp (E_aux (e_aux, annot)) - ^ "\n" ^ Type_error.string_of_type_error err); - E_aux (e_aux, annot) - end - with - (* Otherwise if anything goes wrong when trying to constant - fold, just continue without optimising. *) - | _ -> E_aux (e_aux, annot) - end + | E_app (id, args) when List.for_all is_constant args -> + evaluate e_aux annot + + | E_field (exp, id) when is_constant exp -> + evaluate e_aux annot + + | E_if (E_aux (E_lit (L_aux (L_true, _)), _), then_exp, _) -> ok (); then_exp + | E_if (E_aux (E_lit (L_aux (L_false, _)), _), _, else_exp) -> ok (); else_exp + + | E_let (LB_aux (LB_val (P_aux (P_id id, _), bind), _), exp) when is_constant bind -> + ok (); + subst id bind exp + | E_let (LB_aux (LB_val (P_aux (P_typ (typ, P_aux (P_id id, _)), annot), bind), _), exp) + when is_constant bind -> + ok (); + subst id (E_aux (E_cast (typ, bind), annot)) exp + | _ -> E_aux (e_aux, annot) in let rw_exp = { @@ -151,7 +196,11 @@ let rewrite_constant_function_calls' ast = e_aux = (fun (e_aux, annot) -> rw_funcall e_aux annot) } in let rw_defs = { rewriters_base with rewrite_exp = (fun _ -> fold_exp rw_exp) } in - rewrite_defs_base rw_defs ast + let ast = rewrite_defs_base rw_defs ast in + (* We keep iterating until we have no more re-writes to do *) + if !rewrite_count > 0 + then rewrite_constant_function_calls' ast + else ast let rewrite_constant_function_calls ast = if !optimize_constant_fold then diff --git a/src/gen_lib/sail2_operators_bitlists.lem b/src/gen_lib/sail2_operators_bitlists.lem index 3f8b7510..186e0a09 100644 --- a/src/gen_lib/sail2_operators_bitlists.lem +++ b/src/gen_lib/sail2_operators_bitlists.lem @@ -10,16 +10,16 @@ open import Sail2_prompt val uint_maybe : list bitU -> maybe integer let uint_maybe v = unsigned v let uint_fail v = maybe_fail "uint" (unsigned v) -let uint_oracle v = - bools_of_bits_oracle v >>= (fun bs -> +let uint_nondet v = + bools_of_bits_nondet v >>= (fun bs -> return (int_of_bools false bs)) let uint v = maybe_failwith (uint_maybe v) val sint_maybe : list bitU -> maybe integer let sint_maybe v = signed v let sint_fail v = maybe_fail "sint" (signed v) -let sint_oracle v = - bools_of_bits_oracle v >>= (fun bs -> +let sint_nondet v = + bools_of_bits_nondet v >>= (fun bs -> return (int_of_bools true bs)) let sint v = maybe_failwith (sint_maybe v) @@ -43,13 +43,13 @@ let vector_truncate bs len = extz_bv len bs val vec_of_bits_maybe : list bitU -> maybe (list bitU) val vec_of_bits_fail : forall 'rv 'e. list bitU -> monad 'rv (list bitU) 'e -val vec_of_bits_oracle : forall 'rv 'e. list bitU -> monad 'rv (list bitU) 'e +val vec_of_bits_nondet : forall 'rv 'e. list bitU -> monad 'rv (list bitU) 'e val vec_of_bits_failwith : list bitU -> list bitU val vec_of_bits : list bitU -> list bitU let inline vec_of_bits bits = bits let inline vec_of_bits_maybe bits = Just bits let inline vec_of_bits_fail bits = return bits -let inline vec_of_bits_oracle bits = return bits +let inline vec_of_bits_nondet bits = return bits let inline vec_of_bits_failwith bits = bits val access_vec_inc : list bitU -> integer -> bitU @@ -62,13 +62,13 @@ val update_vec_inc : list bitU -> integer -> bitU -> list bitU let update_vec_inc = update_bv_inc let update_vec_inc_maybe v i b = Just (update_vec_inc v i b) let update_vec_inc_fail v i b = return (update_vec_inc v i b) -let update_vec_inc_oracle v i b = return (update_vec_inc v i b) +let update_vec_inc_nondet v i b = return (update_vec_inc v i b) val update_vec_dec : list bitU -> integer -> bitU -> list bitU let update_vec_dec = update_bv_dec let update_vec_dec_maybe v i b = Just (update_vec_dec v i b) let update_vec_dec_fail v i b = return (update_vec_dec v i b) -let update_vec_dec_oracle v i b = return (update_vec_dec v i b) +let update_vec_dec_nondet v i b = return (update_vec_dec v i b) val subrange_vec_inc : list bitU -> integer -> integer -> list bitU let subrange_vec_inc = subrange_bv_inc @@ -89,19 +89,19 @@ val cons_vec : bitU -> list bitU -> list bitU let cons_vec = cons_bv let cons_vec_maybe b v = Just (cons_vec b v) let cons_vec_fail b v = return (cons_vec b v) -let cons_vec_oracle b v = return (cons_vec b v) +let cons_vec_nondet b v = return (cons_vec b v) val cast_unit_vec : bitU -> list bitU let cast_unit_vec = cast_unit_bv let cast_unit_vec_maybe b = Just (cast_unit_vec b) let cast_unit_vec_fail b = return (cast_unit_vec b) -let cast_unit_vec_oracle b = return (cast_unit_vec b) +let cast_unit_vec_nondet b = return (cast_unit_vec b) val vec_of_bit : integer -> bitU -> list bitU let vec_of_bit = bv_of_bit let vec_of_bit_maybe len b = Just (vec_of_bit len b) let vec_of_bit_fail len b = return (vec_of_bit len b) -let vec_of_bit_oracle len b = return (vec_of_bit len b) +let vec_of_bit_nondet len b = return (vec_of_bit len b) val msb : list bitU -> bitU let msb = most_significant @@ -109,7 +109,7 @@ let msb = most_significant val int_of_vec_maybe : bool -> list bitU -> maybe integer let int_of_vec_maybe = int_of_bv let int_of_vec_fail sign v = maybe_fail "int_of_vec" (int_of_vec_maybe sign v) -let int_of_vec_oracle sign v = bools_of_bits_oracle v >>= (fun v -> return (int_of_bools sign v)) +let int_of_vec_nondet sign v = bools_of_bits_nondet v >>= (fun v -> return (int_of_bools sign v)) let int_of_vec sign v = maybe_failwith (int_of_vec_maybe sign v) val string_of_bits : list bitU -> string @@ -146,30 +146,18 @@ let mult_vec = arith_op_double_bl integerMult false let mults_vec = arith_op_double_bl integerMult true val add_vec_int : list bitU -> integer -> list bitU -val adds_vec_int : list bitU -> integer -> list bitU val sub_vec_int : list bitU -> integer -> list bitU -val subs_vec_int : list bitU -> integer -> list bitU val mult_vec_int : list bitU -> integer -> list bitU -val mults_vec_int : list bitU -> integer -> list bitU let add_vec_int l r = arith_op_bv_int integerAdd false l r -let adds_vec_int l r = arith_op_bv_int integerAdd true l r let sub_vec_int l r = arith_op_bv_int integerMinus false l r -let subs_vec_int l r = arith_op_bv_int integerMinus true l r let mult_vec_int l r = arith_op_double_bl integerMult false l (of_int (length l) r) -let mults_vec_int l r = arith_op_double_bl integerMult true l (of_int (length l) r) val add_int_vec : integer -> list bitU -> list bitU -val adds_int_vec : integer -> list bitU -> list bitU val sub_int_vec : integer -> list bitU -> list bitU -val subs_int_vec : integer -> list bitU -> list bitU val mult_int_vec : integer -> list bitU -> list bitU -val mults_int_vec : integer -> list bitU -> list bitU let add_int_vec l r = arith_op_int_bv integerAdd false l r -let adds_int_vec l r = arith_op_int_bv integerAdd true l r let sub_int_vec l r = arith_op_int_bv integerMinus false l r -let subs_int_vec l r = arith_op_int_bv integerMinus true l r let mult_int_vec l r = arith_op_double_bl integerMult false (of_int (length r) l) r -let mults_int_vec l r = arith_op_double_bl integerMult true (of_int (length r) l) r val add_vec_bit : list bitU -> bitU -> list bitU val adds_vec_bit : list bitU -> bitU -> list bitU @@ -179,25 +167,25 @@ val subs_vec_bit : list bitU -> bitU -> list bitU let add_vec_bool l r = arith_op_bv_bool integerAdd false l r let add_vec_bit_maybe l r = arith_op_bv_bit integerAdd false l r let add_vec_bit_fail l r = maybe_fail "add_vec_bit" (add_vec_bit_maybe l r) -let add_vec_bit_oracle l r = bool_of_bitU_oracle r >>= (fun r -> return (add_vec_bool l r)) +let add_vec_bit_nondet l r = bool_of_bitU_nondet r >>= (fun r -> return (add_vec_bool l r)) let add_vec_bit l r = fromMaybe (repeat [BU] (length l)) (add_vec_bit_maybe l r) let adds_vec_bool l r = arith_op_bv_bool integerAdd true l r let adds_vec_bit_maybe l r = arith_op_bv_bit integerAdd true l r let adds_vec_bit_fail l r = maybe_fail "adds_vec_bit" (adds_vec_bit_maybe l r) -let adds_vec_bit_oracle l r = bool_of_bitU_oracle r >>= (fun r -> return (adds_vec_bool l r)) +let adds_vec_bit_nondet l r = bool_of_bitU_nondet r >>= (fun r -> return (adds_vec_bool l r)) let adds_vec_bit l r = fromMaybe (repeat [BU] (length l)) (adds_vec_bit_maybe l r) let sub_vec_bool l r = arith_op_bv_bool integerMinus false l r let sub_vec_bit_maybe l r = arith_op_bv_bit integerMinus false l r let sub_vec_bit_fail l r = maybe_fail "sub_vec_bit" (sub_vec_bit_maybe l r) -let sub_vec_bit_oracle l r = bool_of_bitU_oracle r >>= (fun r -> return (sub_vec_bool l r)) +let sub_vec_bit_nondet l r = bool_of_bitU_nondet r >>= (fun r -> return (sub_vec_bool l r)) let sub_vec_bit l r = fromMaybe (repeat [BU] (length l)) (sub_vec_bit_maybe l r) let subs_vec_bool l r = arith_op_bv_bool integerMinus true l r let subs_vec_bit_maybe l r = arith_op_bv_bit integerMinus true l r let subs_vec_bit_fail l r = maybe_fail "sub_vec_bit" (subs_vec_bit_maybe l r) -let subs_vec_bit_oracle l r = bool_of_bitU_oracle r >>= (fun r -> return (subs_vec_bool l r)) +let subs_vec_bit_nondet l r = bool_of_bitU_nondet r >>= (fun r -> return (subs_vec_bool l r)) let subs_vec_bit l r = fromMaybe (repeat [BU] (length l)) (subs_vec_bit_maybe l r) (*val add_overflow_vec : list bitU -> list bitU -> (list bitU * bitU * bitU) @@ -236,47 +224,47 @@ let rotr = rotr_bv val mod_vec : list bitU -> list bitU -> list bitU val mod_vec_maybe : list bitU -> list bitU -> maybe (list bitU) val mod_vec_fail : forall 'rv 'e. list bitU -> list bitU -> monad 'rv (list bitU) 'e -val mod_vec_oracle : forall 'rv 'e. list bitU -> list bitU -> monad 'rv (list bitU) 'e +val mod_vec_nondet : forall 'rv 'e. list bitU -> list bitU -> monad 'rv (list bitU) 'e let mod_vec l r = fromMaybe (repeat [BU] (length l)) (mod_bv l r) let mod_vec_maybe l r = mod_bv l r let mod_vec_fail l r = maybe_fail "mod_vec" (mod_bv l r) -let mod_vec_oracle l r = of_bits_oracle (mod_vec l r) +let mod_vec_nondet l r = of_bits_nondet (mod_vec l r) val quot_vec : list bitU -> list bitU -> list bitU val quot_vec_maybe : list bitU -> list bitU -> maybe (list bitU) val quot_vec_fail : forall 'rv 'e. list bitU -> list bitU -> monad 'rv (list bitU) 'e -val quot_vec_oracle : forall 'rv 'e. list bitU -> list bitU -> monad 'rv (list bitU) 'e +val quot_vec_nondet : forall 'rv 'e. list bitU -> list bitU -> monad 'rv (list bitU) 'e let quot_vec l r = fromMaybe (repeat [BU] (length l)) (quot_bv l r) let quot_vec_maybe l r = quot_bv l r let quot_vec_fail l r = maybe_fail "quot_vec" (quot_bv l r) -let quot_vec_oracle l r = of_bits_oracle (quot_vec l r) +let quot_vec_nondet l r = of_bits_nondet (quot_vec l r) val quots_vec : list bitU -> list bitU -> list bitU val quots_vec_maybe : list bitU -> list bitU -> maybe (list bitU) val quots_vec_fail : forall 'rv 'e. list bitU -> list bitU -> monad 'rv (list bitU) 'e -val quots_vec_oracle : forall 'rv 'e. list bitU -> list bitU -> monad 'rv (list bitU) 'e +val quots_vec_nondet : forall 'rv 'e. list bitU -> list bitU -> monad 'rv (list bitU) 'e let quots_vec l r = fromMaybe (repeat [BU] (length l)) (quots_bv l r) let quots_vec_maybe l r = quots_bv l r let quots_vec_fail l r = maybe_fail "quots_vec" (quots_bv l r) -let quots_vec_oracle l r = of_bits_oracle (quots_vec l r) +let quots_vec_nondet l r = of_bits_nondet (quots_vec l r) val mod_vec_int : list bitU -> integer -> list bitU val mod_vec_int_maybe : list bitU -> integer -> maybe (list bitU) val mod_vec_int_fail : forall 'rv 'e. list bitU -> integer -> monad 'rv (list bitU) 'e -val mod_vec_int_oracle : forall 'rv 'e. list bitU -> integer -> monad 'rv (list bitU) 'e +val mod_vec_int_nondet : forall 'rv 'e. list bitU -> integer -> monad 'rv (list bitU) 'e let mod_vec_int l r = fromMaybe (repeat [BU] (length l)) (mod_bv_int l r) let mod_vec_int_maybe l r = mod_bv_int l r let mod_vec_int_fail l r = maybe_fail "mod_vec_int" (mod_bv_int l r) -let mod_vec_int_oracle l r = of_bits_oracle (mod_vec_int l r) +let mod_vec_int_nondet l r = of_bits_nondet (mod_vec_int l r) val quot_vec_int : list bitU -> integer -> list bitU val quot_vec_int_maybe : list bitU -> integer -> maybe (list bitU) val quot_vec_int_fail : forall 'rv 'e. list bitU -> integer -> monad 'rv (list bitU) 'e -val quot_vec_int_oracle : forall 'rv 'e. list bitU -> integer -> monad 'rv (list bitU) 'e +val quot_vec_int_nondet : forall 'rv 'e. list bitU -> integer -> monad 'rv (list bitU) 'e let quot_vec_int l r = fromMaybe (repeat [BU] (length l)) (quot_bv_int l r) let quot_vec_int_maybe l r = quot_bv_int l r let quot_vec_int_fail l r = maybe_fail "quot_vec_int" (quot_bv_int l r) -let quot_vec_int_oracle l r = of_bits_oracle (quot_vec_int l r) +let quot_vec_int_nondet l r = of_bits_nondet (quot_vec_int l r) val replicate_bits : list bitU -> integer -> list bitU let replicate_bits = replicate_bits_bv @@ -285,8 +273,8 @@ val duplicate : bitU -> integer -> list bitU let duplicate = duplicate_bit_bv let duplicate_maybe b n = Just (duplicate b n) let duplicate_fail b n = return (duplicate b n) -let duplicate_oracle b n = - bool_of_bitU_oracle b >>= (fun b -> +let duplicate_nondet b n = + bool_of_bitU_nondet b >>= (fun b -> return (duplicate (bitU_of_bool b) n)) val reverse_endianness : list bitU -> list bitU diff --git a/src/gen_lib/sail2_operators_mwords.lem b/src/gen_lib/sail2_operators_mwords.lem index 1e4d63ba..a7fb7c50 100644 --- a/src/gen_lib/sail2_operators_mwords.lem +++ b/src/gen_lib/sail2_operators_mwords.lem @@ -10,21 +10,21 @@ open import Sail2_prompt let inline uint v = unsignedIntegerFromWord v let uint_maybe v = Just (uint v) let uint_fail v = return (uint v) -let uint_oracle v = return (uint v) +let uint_nondet v = return (uint v) let inline sint v = signedIntegerFromWord v let sint_maybe v = Just (sint v) let sint_fail v = return (sint v) -let sint_oracle v = return (sint v) +let sint_nondet v = return (sint v) val vec_of_bits_maybe : forall 'a. Size 'a => list bitU -> maybe (mword 'a) val vec_of_bits_fail : forall 'rv 'a 'e. Size 'a => list bitU -> monad 'rv (mword 'a) 'e -val vec_of_bits_oracle : forall 'rv 'a 'e. Size 'a => list bitU -> monad 'rv (mword 'a) 'e +val vec_of_bits_nondet : forall 'rv 'a 'e. Size 'a => list bitU -> monad 'rv (mword 'a) 'e val vec_of_bits_failwith : forall 'a. Size 'a => list bitU -> mword 'a val vec_of_bits : forall 'a. Size 'a => list bitU -> mword 'a let vec_of_bits_maybe bits = of_bits bits let vec_of_bits_fail bits = of_bits_fail bits -let vec_of_bits_oracle bits = of_bits_oracle bits +let vec_of_bits_nondet bits = of_bits_nondet bits let vec_of_bits_failwith bits = of_bits_failwith bits let vec_of_bits bits = of_bits_failwith bits @@ -38,8 +38,8 @@ let update_vec_dec_maybe w i b = update_mword_dec w i b let update_vec_dec_fail w i b = bool_of_bitU_fail b >>= (fun b -> return (update_mword_bool_dec w i b)) -let update_vec_dec_oracle w i b = - bool_of_bitU_oracle b >>= (fun b -> +let update_vec_dec_nondet w i b = + bool_of_bitU_nondet b >>= (fun b -> return (update_mword_bool_dec w i b)) let update_vec_dec w i b = maybe_failwith (update_vec_dec_maybe w i b) @@ -47,8 +47,8 @@ let update_vec_inc_maybe w i b = update_mword_inc w i b let update_vec_inc_fail w i b = bool_of_bitU_fail b >>= (fun b -> return (update_mword_bool_inc w i b)) -let update_vec_inc_oracle w i b = - bool_of_bitU_oracle b >>= (fun b -> +let update_vec_inc_nondet w i b = + bool_of_bitU_nondet b >>= (fun b -> return (update_mword_bool_inc w i b)) let update_vec_inc w i b = maybe_failwith (update_vec_inc_maybe w i b) @@ -89,21 +89,21 @@ val cons_vec_bool : forall 'a 'b 'c. Size 'a, Size 'b => bool -> mword 'a -> mwo let cons_vec_bool b w = wordFromBitlist (b :: bitlistFromWord w) let cons_vec_maybe b w = Maybe.map (fun b -> cons_vec_bool b w) (bool_of_bitU b) let cons_vec_fail b w = bool_of_bitU_fail b >>= (fun b -> return (cons_vec_bool b w)) -let cons_vec_oracle b w = bool_of_bitU_oracle b >>= (fun b -> return (cons_vec_bool b w)) +let cons_vec_nondet b w = bool_of_bitU_nondet b >>= (fun b -> return (cons_vec_bool b w)) let cons_vec b w = maybe_failwith (cons_vec_maybe b w) val vec_of_bool : forall 'a. Size 'a => integer -> bool -> mword 'a let vec_of_bool _ b = wordFromBitlist [b] let vec_of_bit_maybe len b = Maybe.map (vec_of_bool len) (bool_of_bitU b) let vec_of_bit_fail len b = bool_of_bitU_fail b >>= (fun b -> return (vec_of_bool len b)) -let vec_of_bit_oracle len b = bool_of_bitU_oracle b >>= (fun b -> return (vec_of_bool len b)) +let vec_of_bit_nondet len b = bool_of_bitU_nondet b >>= (fun b -> return (vec_of_bool len b)) let vec_of_bit len b = maybe_failwith (vec_of_bit_maybe len b) val cast_bool_vec : bool -> mword ty1 let cast_bool_vec b = vec_of_bool 1 b let cast_unit_vec_maybe b = vec_of_bit_maybe 1 b let cast_unit_vec_fail b = bool_of_bitU_fail b >>= (fun b -> return (cast_bool_vec b)) -let cast_unit_vec_oracle b = bool_of_bitU_oracle b >>= (fun b -> return (cast_bool_vec b)) +let cast_unit_vec_nondet b = bool_of_bitU_nondet b >>= (fun b -> return (cast_bool_vec b)) let cast_unit_vec b = maybe_failwith (cast_unit_vec_maybe b) val msb : forall 'a. Size 'a => mword 'a -> bitU @@ -143,30 +143,18 @@ let mult_vec l r = arith_op_bv integerMult false (zeroExtend l : mword 'b) (ze let mults_vec l r = arith_op_bv integerMult true (signExtend l : mword 'b) (signExtend r : mword 'b) val add_vec_int : forall 'a. Size 'a => mword 'a -> integer -> mword 'a -val adds_vec_int : forall 'a. Size 'a => mword 'a -> integer -> mword 'a val sub_vec_int : forall 'a. Size 'a => mword 'a -> integer -> mword 'a -val subs_vec_int : forall 'a. Size 'a => mword 'a -> integer -> mword 'a val mult_vec_int : forall 'a 'b. Size 'a, Size 'b => mword 'a -> integer -> mword 'b -val mults_vec_int : forall 'a 'b. Size 'a, Size 'b => mword 'a -> integer -> mword 'b let add_vec_int l r = arith_op_bv_int integerAdd false l r -let adds_vec_int l r = arith_op_bv_int integerAdd true l r let sub_vec_int l r = arith_op_bv_int integerMinus false l r -let subs_vec_int l r = arith_op_bv_int integerMinus true l r let mult_vec_int l r = arith_op_bv_int integerMult false (zeroExtend l : mword 'b) r -let mults_vec_int l r = arith_op_bv_int integerMult true (signExtend l : mword 'b) r val add_int_vec : forall 'a. Size 'a => integer -> mword 'a -> mword 'a -val adds_int_vec : forall 'a. Size 'a => integer -> mword 'a -> mword 'a val sub_int_vec : forall 'a. Size 'a => integer -> mword 'a -> mword 'a -val subs_int_vec : forall 'a. Size 'a => integer -> mword 'a -> mword 'a val mult_int_vec : forall 'a 'b. Size 'a, Size 'b => integer -> mword 'a -> mword 'b -val mults_int_vec : forall 'a 'b. Size 'a, Size 'b => integer -> mword 'a -> mword 'b let add_int_vec l r = arith_op_int_bv integerAdd false l r -let adds_int_vec l r = arith_op_int_bv integerAdd true l r let sub_int_vec l r = arith_op_int_bv integerMinus false l r -let subs_int_vec l r = arith_op_int_bv integerMinus true l r let mult_int_vec l r = arith_op_int_bv integerMult false l (zeroExtend r : mword 'b) -let mults_int_vec l r = arith_op_int_bv integerMult true l (signExtend r : mword 'b) val add_vec_bool : forall 'a. Size 'a => mword 'a -> bool -> mword 'a val adds_vec_bool : forall 'a. Size 'a => mword 'a -> bool -> mword 'a @@ -176,25 +164,25 @@ val subs_vec_bool : forall 'a. Size 'a => mword 'a -> bool -> mword 'a let add_vec_bool l r = arith_op_bv_bool integerAdd false l r let add_vec_bit_maybe l r = Maybe.map (add_vec_bool l) (bool_of_bitU r) let add_vec_bit_fail l r = bool_of_bitU_fail r >>= (fun r -> return (add_vec_bool l r)) -let add_vec_bit_oracle l r = bool_of_bitU_oracle r >>= (fun r -> return (add_vec_bool l r)) +let add_vec_bit_nondet l r = bool_of_bitU_nondet r >>= (fun r -> return (add_vec_bool l r)) let add_vec_bit l r = maybe_failwith (add_vec_bit_maybe l r) let adds_vec_bool l r = arith_op_bv_bool integerAdd true l r let adds_vec_bit_maybe l r = Maybe.map (adds_vec_bool l) (bool_of_bitU r) let adds_vec_bit_fail l r = bool_of_bitU_fail r >>= (fun r -> return (adds_vec_bool l r)) -let adds_vec_bit_oracle l r = bool_of_bitU_oracle r >>= (fun r -> return (adds_vec_bool l r)) +let adds_vec_bit_nondet l r = bool_of_bitU_nondet r >>= (fun r -> return (adds_vec_bool l r)) let adds_vec_bit l r = maybe_failwith (adds_vec_bit_maybe l r) let sub_vec_bool l r = arith_op_bv_bool integerMinus false l r let sub_vec_bit_maybe l r = Maybe.map (sub_vec_bool l) (bool_of_bitU r) let sub_vec_bit_fail l r = bool_of_bitU_fail r >>= (fun r -> return (sub_vec_bool l r)) -let sub_vec_bit_oracle l r = bool_of_bitU_oracle r >>= (fun r -> return (sub_vec_bool l r)) +let sub_vec_bit_nondet l r = bool_of_bitU_nondet r >>= (fun r -> return (sub_vec_bool l r)) let sub_vec_bit l r = maybe_failwith (sub_vec_bit_maybe l r) let subs_vec_bool l r = arith_op_bv_bool integerMinus true l r let subs_vec_bit_maybe l r = Maybe.map (subs_vec_bool l) (bool_of_bitU r) let subs_vec_bit_fail l r = bool_of_bitU_fail r >>= (fun r -> return (subs_vec_bool l r)) -let subs_vec_bit_oracle l r = bool_of_bitU_oracle r >>= (fun r -> return (subs_vec_bool l r)) +let subs_vec_bit_nondet l r = bool_of_bitU_nondet r >>= (fun r -> return (subs_vec_bool l r)) let subs_vec_bit l r = maybe_failwith (subs_vec_bit_maybe l r) (* TODO @@ -238,66 +226,66 @@ let rotr = rotr_mword val mod_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> mword 'a val mod_vec_maybe : forall 'a. Size 'a => mword 'a -> mword 'a -> maybe (mword 'a) val mod_vec_fail : forall 'rv 'a 'e. Size 'a => mword 'a -> mword 'a -> monad 'rv (mword 'a) 'e -val mod_vec_oracle : forall 'rv 'a 'e. Size 'a => mword 'a -> mword 'a -> monad 'rv (mword 'a) 'e +val mod_vec_nondet : forall 'rv 'a 'e. Size 'a => mword 'a -> mword 'a -> monad 'rv (mword 'a) 'e let mod_vec l r = mod_mword l r let mod_vec_maybe l r = mod_bv l r let mod_vec_fail l r = maybe_fail "mod_vec" (mod_bv l r) -let mod_vec_oracle l r = +let mod_vec_nondet l r = match (mod_bv l r) with | Just w -> return w - | Nothing -> mword_oracle () + | Nothing -> mword_nondet () end val quot_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> mword 'a val quot_vec_maybe : forall 'a. Size 'a => mword 'a -> mword 'a -> maybe (mword 'a) val quot_vec_fail : forall 'rv 'a 'e. Size 'a => mword 'a -> mword 'a -> monad 'rv (mword 'a) 'e -val quot_vec_oracle : forall 'rv 'a 'e. Size 'a => mword 'a -> mword 'a -> monad 'rv (mword 'a) 'e +val quot_vec_nondet : forall 'rv 'a 'e. Size 'a => mword 'a -> mword 'a -> monad 'rv (mword 'a) 'e let quot_vec l r = quot_mword l r let quot_vec_maybe l r = quot_bv l r let quot_vec_fail l r = maybe_fail "quot_vec" (quot_bv l r) -let quot_vec_oracle l r = +let quot_vec_nondet l r = match (quot_bv l r) with | Just w -> return w - | Nothing -> mword_oracle () + | Nothing -> mword_nondet () end val quots_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> mword 'a val quots_vec_maybe : forall 'a. Size 'a => mword 'a -> mword 'a -> maybe (mword 'a) val quots_vec_fail : forall 'rv 'a 'e. Size 'a => mword 'a -> mword 'a -> monad 'rv (mword 'a) 'e -val quots_vec_oracle : forall 'rv 'a 'e. Size 'a => mword 'a -> mword 'a -> monad 'rv (mword 'a) 'e +val quots_vec_nondet : forall 'rv 'a 'e. Size 'a => mword 'a -> mword 'a -> monad 'rv (mword 'a) 'e let quots_vec l r = quots_mword l r let quots_vec_maybe l r = quots_bv l r let quots_vec_fail l r = maybe_fail "quots_vec" (quots_bv l r) -let quots_vec_oracle l r = +let quots_vec_nondet l r = match (quots_bv l r) with | Just w -> return w - | Nothing -> mword_oracle () + | Nothing -> mword_nondet () end val mod_vec_int : forall 'a. Size 'a => mword 'a -> integer -> mword 'a val mod_vec_int_maybe : forall 'a. Size 'a => mword 'a -> integer -> maybe (mword 'a) val mod_vec_int_fail : forall 'rv 'a 'e. Size 'a => mword 'a -> integer -> monad 'rv (mword 'a) 'e -val mod_vec_int_oracle : forall 'rv 'a 'e. Size 'a => mword 'a -> integer -> monad 'rv (mword 'a) 'e +val mod_vec_int_nondet : forall 'rv 'a 'e. Size 'a => mword 'a -> integer -> monad 'rv (mword 'a) 'e let mod_vec_int l r = mod_mword_int l r let mod_vec_int_maybe l r = mod_bv_int l r let mod_vec_int_fail l r = maybe_fail "mod_vec_int" (mod_bv_int l r) -let mod_vec_int_oracle l r = +let mod_vec_int_nondet l r = match (mod_bv_int l r) with | Just w -> return w - | Nothing -> mword_oracle () + | Nothing -> mword_nondet () end val quot_vec_int : forall 'a. Size 'a => mword 'a -> integer -> mword 'a val quot_vec_int_maybe : forall 'a. Size 'a => mword 'a -> integer -> maybe (mword 'a) val quot_vec_int_fail : forall 'rv 'a 'e. Size 'a => mword 'a -> integer -> monad 'rv (mword 'a) 'e -val quot_vec_int_oracle : forall 'rv 'a 'e. Size 'a => mword 'a -> integer -> monad 'rv (mword 'a) 'e +val quot_vec_int_nondet : forall 'rv 'a 'e. Size 'a => mword 'a -> integer -> monad 'rv (mword 'a) 'e let quot_vec_int l r = quot_mword_int l r let quot_vec_int_maybe l r = quot_bv_int l r let quot_vec_int_fail l r = maybe_fail "quot_vec_int" (quot_bv_int l r) -let quot_vec_int_oracle l r = +let quot_vec_int_nondet l r = match (quot_bv_int l r) with | Just w -> return w - | Nothing -> mword_oracle () + | Nothing -> mword_nondet () end val replicate_bits : forall 'a 'b. Size 'a, Size 'b => mword 'a -> integer -> mword 'b @@ -307,7 +295,7 @@ val duplicate_bool : forall 'a. Size 'a => bool -> integer -> mword 'a let duplicate_bool b n = wordFromBitlist (repeat [b] n) let duplicate_maybe b n = Maybe.map (fun b -> duplicate_bool b n) (bool_of_bitU b) let duplicate_fail b n = bool_of_bitU_fail b >>= (fun b -> return (duplicate_bool b n)) -let duplicate_oracle b n = bool_of_bitU_oracle b >>= (fun b -> return (duplicate_bool b n)) +let duplicate_nondet b n = bool_of_bitU_nondet b >>= (fun b -> return (duplicate_bool b n)) let duplicate b n = maybe_failwith (duplicate_maybe b n) val reverse_endianness : forall 'a. Size 'a => mword 'a -> mword 'a diff --git a/src/gen_lib/sail2_prompt.lem b/src/gen_lib/sail2_prompt.lem index 08a47052..e01cc051 100644 --- a/src/gen_lib/sail2_prompt.lem +++ b/src/gen_lib/sail2_prompt.lem @@ -51,31 +51,31 @@ let bool_of_bitU_fail = function | BU -> Fail "bool_of_bitU" end -val bool_of_bitU_oracle : forall 'rv 'e. bitU -> monad 'rv bool 'e -let bool_of_bitU_oracle = function +val bool_of_bitU_nondet : forall 'rv 'e. bitU -> monad 'rv bool 'e +let bool_of_bitU_nondet = function | B0 -> return false | B1 -> return true | BU -> undefined_bool () end -val bools_of_bits_oracle : forall 'rv 'e. list bitU -> monad 'rv (list bool) 'e -let bools_of_bits_oracle bits = +val bools_of_bits_nondet : forall 'rv 'e. list bitU -> monad 'rv (list bool) 'e +let bools_of_bits_nondet bits = foreachM bits [] (fun b bools -> - bool_of_bitU_oracle b >>= (fun b -> + bool_of_bitU_nondet b >>= (fun b -> return (bools ++ [b]))) -val of_bits_oracle : forall 'rv 'a 'e. Bitvector 'a => list bitU -> monad 'rv 'a 'e -let of_bits_oracle bits = - bools_of_bits_oracle bits >>= (fun bs -> +val of_bits_nondet : forall 'rv 'a 'e. Bitvector 'a => list bitU -> monad 'rv 'a 'e +let of_bits_nondet bits = + bools_of_bits_nondet bits >>= (fun bs -> return (of_bools bs)) val of_bits_fail : forall 'rv 'a 'e. Bitvector 'a => list bitU -> monad 'rv 'a 'e let of_bits_fail bits = maybe_fail "of_bits" (of_bits bits) -val mword_oracle : forall 'rv 'a 'e. Size 'a => unit -> monad 'rv (mword 'a) 'e -let mword_oracle () = - bools_of_bits_oracle (repeat [BU] (integerFromNat size)) >>= (fun bs -> +val mword_nondet : forall 'rv 'a 'e. Size 'a => unit -> monad 'rv (mword 'a) 'e +let mword_nondet () = + bools_of_bits_nondet (repeat [BU] (integerFromNat size)) >>= (fun bs -> return (wordFromBitlist bs)) val whileM : forall 'rv 'vars 'e. 'vars -> ('vars -> monad 'rv bool 'e) -> @@ -93,6 +93,16 @@ let rec untilM vars cond body = cond vars >>= fun cond_val -> if cond_val then return vars else untilM vars cond body +val internal_pick : forall 'rv 'a 'e. list 'a -> monad 'rv 'a 'e +let internal_pick xs = + (* Use sufficiently many undefined bits and convert into an index into the list *) + bools_of_bits_nondet (repeat [BU] (length_list xs)) >>= fun bs -> + let idx = (natFromNatural (nat_of_bools bs)) mod List.length xs in + match index xs idx with + | Just x -> return x + | Nothing -> Fail "internal_pick" + end + (*let write_two_regs r1 r2 vec = let is_inc = let is_inc_r1 = is_inc_of_reg r1 in diff --git a/src/gen_lib/sail2_prompt_monad.lem b/src/gen_lib/sail2_prompt_monad.lem index 745589e2..78b1615e 100644 --- a/src/gen_lib/sail2_prompt_monad.lem +++ b/src/gen_lib/sail2_prompt_monad.lem @@ -29,6 +29,7 @@ type monad 'regval 'a 'e = | Read_reg of register_name * ('regval -> monad 'regval 'a 'e) (* Request to write register *) | Write_reg of register_name * 'regval * monad 'regval 'a 'e + (* Request to choose a Boolean, e.g. to resolve an undefined bit *) | Undefined of (bool -> monad 'regval 'a 'e) (* Print debugging or tracing information *) | Print of string * monad 'regval 'a 'e diff --git a/src/gen_lib/sail2_state.lem b/src/gen_lib/sail2_state.lem index 82ac35d8..f703dead 100644 --- a/src/gen_lib/sail2_state.lem +++ b/src/gen_lib/sail2_state.lem @@ -41,31 +41,31 @@ let bool_of_bitU_fail = function | BU -> failS "bool_of_bitU" end -val bool_of_bitU_oracleS : forall 'rv 'e. bitU -> monadS 'rv bool 'e -let bool_of_bitU_oracleS = function +val bool_of_bitU_nondetS : forall 'rv 'e. bitU -> monadS 'rv bool 'e +let bool_of_bitU_nondetS = function | B0 -> returnS false | B1 -> returnS true | BU -> undefined_boolS () end -val bools_of_bits_oracleS : forall 'rv 'e. list bitU -> monadS 'rv (list bool) 'e -let bools_of_bits_oracleS bits = +val bools_of_bits_nondetS : forall 'rv 'e. list bitU -> monadS 'rv (list bool) 'e +let bools_of_bits_nondetS bits = foreachS bits [] (fun b bools -> - bool_of_bitU_oracleS b >>$= (fun b -> + bool_of_bitU_nondetS b >>$= (fun b -> returnS (bools ++ [b]))) -val of_bits_oracleS : forall 'rv 'a 'e. Bitvector 'a => list bitU -> monadS 'rv 'a 'e -let of_bits_oracleS bits = - bools_of_bits_oracleS bits >>$= (fun bs -> +val of_bits_nondetS : forall 'rv 'a 'e. Bitvector 'a => list bitU -> monadS 'rv 'a 'e +let of_bits_nondetS bits = + bools_of_bits_nondetS bits >>$= (fun bs -> returnS (of_bools bs)) val of_bits_failS : forall 'rv 'a 'e. Bitvector 'a => list bitU -> monadS 'rv 'a 'e let of_bits_failS bits = maybe_failS "of_bits" (of_bits bits) -val mword_oracleS : forall 'rv 'a 'e. Size 'a => unit -> monadS 'rv (mword 'a) 'e -let mword_oracleS () = - bools_of_bits_oracleS (repeat [BU] (integerFromNat size)) >>$= (fun bs -> +val mword_nondetS : forall 'rv 'a 'e. Size 'a => unit -> monadS 'rv (mword 'a) 'e +let mword_nondetS () = + bools_of_bits_nondetS (repeat [BU] (integerFromNat size)) >>$= (fun bs -> returnS (wordFromBitlist bs)) @@ -83,3 +83,13 @@ let rec untilS vars cond body s = (body vars >>$= (fun vars s' -> (cond vars >>$= (fun cond_val s'' -> if cond_val then returnS vars s'' else untilS vars cond body s'')) s')) s + +val internal_pickS : forall 'rv 'a 'e. list 'a -> monadS 'rv 'a 'e +let internal_pickS xs = + (* Use sufficiently many undefined bits and convert into an index into the list *) + bools_of_bits_nondetS (repeat [BU] (length_list xs)) >>$= fun bs -> + let idx = (natFromNatural (nat_of_bools bs)) mod List.length xs in + match index xs idx with + | Just x -> returnS x + | Nothing -> failS "internal_pick" + end diff --git a/src/gen_lib/sail2_state_monad.lem b/src/gen_lib/sail2_state_monad.lem index f207699f..30b296cc 100644 --- a/src/gen_lib/sail2_state_monad.lem +++ b/src/gen_lib/sail2_state_monad.lem @@ -13,20 +13,15 @@ type sequential_state 'regs = memstate : memstate; tagstate : tagstate; write_ea : maybe (write_kind * integer * integer); - last_exclusive_operation_was_load : bool; - (* Random bool generator for use as an undefined bit oracle *) - next_bool : nat -> (bool * nat); - seed : nat |> + last_exclusive_operation_was_load : bool |> -val init_state : forall 'regs. 'regs -> (nat -> (bool* nat)) -> nat -> sequential_state 'regs -let init_state regs o s = +val init_state : forall 'regs. 'regs -> sequential_state 'regs +let init_state regs = <| regstate = regs; memstate = Map.empty; tagstate = Map.empty; write_ea = Nothing; - last_exclusive_operation_was_load = false; - next_bool = o; - seed = s |> + last_exclusive_operation_was_load = false |> type ex 'e = | Failure of string @@ -69,10 +64,7 @@ val failS : forall 'regs 'a 'e. string -> monadS 'regs 'a 'e let failS msg s = {(Ex (Failure msg), s)} val undefined_boolS : forall 'regval 'regs 'a 'e. unit -> monadS 'regs bool 'e -let undefined_boolS () = - readS (fun s -> s.next_bool (s.seed)) >>$= (fun (b, seed) -> - updateS (fun s -> <| s with seed = seed |>) >>$ - returnS b) +let undefined_boolS () = chooseS {false; true} val exitS : forall 'regs 'e 'a. unit -> monadS 'regs 'a 'e let exitS () = failS "exit" diff --git a/src/gen_lib/sail2_string.lem b/src/gen_lib/sail2_string.lem index d0e40ad4..ba3a2d51 100644 --- a/src/gen_lib/sail2_string.lem +++ b/src/gen_lib/sail2_string.lem @@ -31,7 +31,7 @@ let string_append = stringAppend val maybeIntegerOfString : string -> maybe integer let maybeIntegerOfString _ = Nothing (* TODO FIXME *) -declare ocaml target_rep function maybeIntegerOfString = `(fun s -> match int_of_string_opt s with None -> None | Some i -> Some (Nat_big_num.of_int i))` +declare ocaml target_rep function maybeIntegerOfString = `(fun s -> match int_of_string s with i -> Some (Nat_big_num.of_int i) | exception Failure _ -> None )` (*********************************************** * end stuff that should be in Lem Num_extra * diff --git a/src/gen_lib/sail2_values.lem b/src/gen_lib/sail2_values.lem index 0f384cab..fd742fb1 100644 --- a/src/gen_lib/sail2_values.lem +++ b/src/gen_lib/sail2_values.lem @@ -47,12 +47,14 @@ let power_real b e = realPowInteger b e*) val print_endline : string -> unit let print_endline _ = () -declare ocaml target_rep function print_endline = `print_endline` +(* declare ocaml target_rep function print_endline = `print_endline` *) val prerr_endline : string -> unit let prerr_endline _ = () declare ocaml target_rep function prerr_endline = `prerr_endline` +let prerr x = prerr_endline x + val print_int : string -> integer -> unit let print_int msg i = print_endline (msg ^ (stringFromInteger i)) diff --git a/src/interpreter.ml b/src/interpreter.ml index 00846d73..99d5889a 100644 --- a/src/interpreter.ml +++ b/src/interpreter.ml @@ -232,7 +232,6 @@ let is_value_fexp (FE_aux (FE_Fexp (id, exp), _)) = is_value exp let value_of_fexp (FE_aux (FE_Fexp (id, exp), _)) = (string_of_id id, value_of_exp exp) let rec build_letchain id lbs (E_aux (_, annot) as exp) = - (* print_endline ("LETCHAIN " ^ string_of_exp exp); *) match lbs with | [] -> exp | lb :: lbs when IdSet.mem id (letbind_pat_ids lb)-> @@ -311,7 +310,6 @@ let rec step (E_aux (e_aux, annot) as orig_exp) = else failwith "Match failure" - | E_vector_subrange (vec, n, m) -> wrap (E_app (mk_id "vector_subrange_dec", [vec; n; m])) | E_vector_access (vec, n) -> diff --git a/src/isail.ml b/src/isail.ml index 593167f9..4adc1cd2 100644 --- a/src/isail.ml +++ b/src/isail.ml @@ -127,7 +127,12 @@ let rec run () = print_endline ("Result = " ^ Value.string_of_value v); current_mode := Normal | Step (out, state, _, stack) -> - current_mode := Evaluation (eval_frame !interactive_ast frame); + begin + try + current_mode := Evaluation (eval_frame !interactive_ast frame) + with + | Failure str -> print_endline str; current_mode := Normal + end; run () | Break frame -> print_endline "Breakpoint"; @@ -147,7 +152,12 @@ let rec run_steps n = print_endline ("Result = " ^ Value.string_of_value v); current_mode := Normal | Step (out, state, _, stack) -> - current_mode := Evaluation (eval_frame !interactive_ast frame); + begin + try + current_mode := Evaluation (eval_frame !interactive_ast frame) + with + | Failure str -> print_endline str; current_mode := Normal + end; run_steps (n - 1) | Break frame -> print_endline "Breakpoint"; @@ -352,9 +362,14 @@ let handle_input' input = print_endline ("Result = " ^ Value.string_of_value v); current_mode := Normal | Step (out, state, _, stack) -> - interactive_state := state; - current_mode := Evaluation (eval_frame !interactive_ast frame); - print_program () + begin + try + interactive_state := state; + current_mode := Evaluation (eval_frame !interactive_ast frame); + print_program () + with + | Failure str -> print_endline str; current_mode := Normal + end | Break frame -> print_endline "Breakpoint"; current_mode := Evaluation frame diff --git a/src/lem_interp/sail2_instr_kinds.lem b/src/lem_interp/sail2_instr_kinds.lem index 13e5304e..3d238676 100644 --- a/src/lem_interp/sail2_instr_kinds.lem +++ b/src/lem_interp/sail2_instr_kinds.lem @@ -151,6 +151,10 @@ type barrier_kind = | Barrier_RISCV_r_r | Barrier_RISCV_rw_w | Barrier_RISCV_w_w + | Barrier_RISCV_w_rw + | Barrier_RISCV_rw_r + | Barrier_RISCV_r_w + | Barrier_RISCV_w_r | Barrier_RISCV_i (* X86 *) | Barrier_x86_MFENCE @@ -176,6 +180,10 @@ instance (Show barrier_kind) | Barrier_RISCV_r_r -> "Barrier_RISCV_r_r" | Barrier_RISCV_rw_w -> "Barrier_RISCV_rw_w" | Barrier_RISCV_w_w -> "Barrier_RISCV_w_w" + | Barrier_RISCV_w_rw -> "Barrier_RISCV_w_rw" + | Barrier_RISCV_rw_r -> "Barrier_RISCV_rw_r" + | Barrier_RISCV_r_w -> "Barrier_RISCV_r_w" + | Barrier_RISCV_w_r -> "Barrier_RISCV_w_r" | Barrier_RISCV_i -> "Barrier_RISCV_i" | Barrier_x86_MFENCE -> "Barrier_x86_MFENCE" end @@ -211,7 +219,7 @@ instance (Show instruction_kind) | IK_mem_read read_kind -> "IK_mem_read " ^ (show read_kind) | IK_mem_write write_kind -> "IK_mem_write " ^ (show write_kind) | IK_mem_rmw (r, w) -> "IK_mem_rmw " ^ (show r) ^ " " ^ (show w) - | IK_branch -> "IK_branch" + | IK_branch () -> "IK_branch" | IK_trans trans_kind -> "IK_trans " ^ (show trans_kind) | IK_simple () -> "IK_simple" end @@ -288,7 +296,11 @@ instance (EnumerationType barrier_kind) | Barrier_RISCV_r_r -> 15 | Barrier_RISCV_rw_w -> 16 | Barrier_RISCV_w_w -> 17 - | Barrier_RISCV_i -> 18 - | Barrier_x86_MFENCE -> 19 + | Barrier_RISCV_w_rw -> 18 + | Barrier_RISCV_rw_r -> 19 + | Barrier_RISCV_r_w -> 20 + | Barrier_RISCV_w_r -> 21 + | Barrier_RISCV_i -> 22 + | Barrier_x86_MFENCE -> 23 end end diff --git a/src/ocaml_backend.ml b/src/ocaml_backend.ml index 5ffb1647..236c4222 100644 --- a/src/ocaml_backend.ml +++ b/src/ocaml_backend.ml @@ -167,10 +167,13 @@ let ocaml_lit (L_aux (lit_aux, _)) = | L_one -> string "B1" | L_true -> string "true" | L_false -> string "false" - | L_num n -> if Big_int.equal n Big_int.zero - then string "Big_int.zero" - else parens (string "Big_int.of_int" ^^ space - ^^ string "(" ^^ string (Big_int.to_string n) ^^ string ")") + | L_num n -> + if Big_int.equal n Big_int.zero then + string "Big_int.zero" + else if Big_int.less_equal (Big_int.of_int min_int) n && Big_int.less_equal n (Big_int.of_int max_int) then + parens (string "Big_int.of_int" ^^ space ^^ parens (string (Big_int.to_string n))) + else + parens (string "Big_int.of_string" ^^ space ^^ dquotes (string (Big_int.to_string n))) | L_undef -> failwith "undefined should have been re-written prior to ocaml backend" | L_string str -> string_lit str | L_real str -> parens (string "real_of_string" ^^ space ^^ dquotes (string (String.escaped str))) @@ -389,6 +392,10 @@ let ocaml_dec_spec ctx (DEC_aux (reg, _)) = separate space [string "let"; zencode ctx id; colon; parens (ocaml_typ ctx typ); string "ref"; equals; string "ref"; parens (ocaml_exp ctx (initial_value_for id ctx.register_inits))] + | DEC_config (id, typ, exp) -> + separate space [string "let"; zencode ctx id; colon; + parens (ocaml_typ ctx typ); string "ref"; equals; + string "ref"; parens (ocaml_exp ctx exp)] | _ -> failwith "Unsupported register declaration" let first_function = ref true diff --git a/src/pretty_print_coq.ml b/src/pretty_print_coq.ml index ffe376e0..74e97a29 100644 --- a/src/pretty_print_coq.ml +++ b/src/pretty_print_coq.ml @@ -96,12 +96,14 @@ type context = { kid_renames : kid KBindings.t; (* Plain tyvar -> tyvar renames *) kid_id_renames : id KBindings.t; (* tyvar -> argument renames *) bound_nexps : NexpSet.t; + build_ex_return : bool; } let empty_ctxt = { early_ret = false; kid_renames = KBindings.empty; kid_id_renames = KBindings.empty; - bound_nexps = NexpSet.empty + bound_nexps = NexpSet.empty; + build_ex_return = false; } let langlebar = string "<|" @@ -135,7 +137,10 @@ let rec fix_id remove_tick name = match name with | "GT" | "EQ" | "Z" + | "O" + | "S" | "mod" + | "M" -> name ^ "'" | _ -> if String.contains name '#' then @@ -146,15 +151,17 @@ let rec fix_id remove_tick name = match name with fix_id remove_tick (String.concat "__" (Util.split_on_char '^' name)) else if name.[0] = '\'' then let var = String.sub name 1 (String.length name - 1) in - if remove_tick then var else (var ^ "'") + if remove_tick then fix_id remove_tick var else (var ^ "'") else if is_number_char(name.[0]) then ("v" ^ name ^ "'") else name -let doc_id (Id_aux(i,_)) = +let string_id (Id_aux(i,_)) = match i with - | Id i -> string (fix_id false i) - | DeIid x -> string (Util.zencode_string ("op " ^ x)) + | Id i -> fix_id false i + | DeIid x -> Util.zencode_string ("op " ^ x) + +let doc_id id = string (string_id id) let doc_id_type (Id_aux(i,_)) = match i with @@ -318,7 +325,7 @@ let drop_duplicate_atoms kids ty = in aux_typ ty (* TODO: parens *) -let rec doc_nc ctx (NC_aux (nc,_)) = +let rec doc_nc_prop ctx (NC_aux (nc,_)) = match nc with | NC_equal (ne1, ne2) -> doc_op equals (doc_nexp ctx ne1) (doc_nexp ctx ne2) | NC_bounded_ge (ne1, ne2) -> doc_op (string ">=") (doc_nexp ctx ne1) (doc_nexp ctx ne2) @@ -328,11 +335,27 @@ let rec doc_nc ctx (NC_aux (nc,_)) = separate space [string "In"; doc_var_lem ctx kid; brackets (separate (string "; ") (List.map (fun i -> string (Nat_big_num.to_string i)) is))] - | NC_or (nc1, nc2) -> doc_op (string "\\/") (doc_nc ctx nc1) (doc_nc ctx nc2) - | NC_and (nc1, nc2) -> doc_op (string "/\\") (doc_nc ctx nc1) (doc_nc ctx nc2) + | NC_or (nc1, nc2) -> doc_op (string "\\/") (doc_nc_prop ctx nc1) (doc_nc_prop ctx nc2) + | NC_and (nc1, nc2) -> doc_op (string "/\\") (doc_nc_prop ctx nc1) (doc_nc_prop ctx nc2) | NC_true -> string "True" | NC_false -> string "False" +(* TODO: parens *) +let rec doc_nc_exp ctx (NC_aux (nc,_)) = + match nc with + | NC_equal (ne1, ne2) -> doc_op (string "=?") (doc_nexp ctx ne1) (doc_nexp ctx ne2) + | NC_bounded_ge (ne1, ne2) -> doc_op (string ">=?") (doc_nexp ctx ne1) (doc_nexp ctx ne2) + | NC_bounded_le (ne1, ne2) -> doc_op (string "<=?") (doc_nexp ctx ne1) (doc_nexp ctx ne2) + | NC_not_equal (ne1, ne2) -> string "negb" ^^ space ^^ parens (doc_op (string "=?") (doc_nexp ctx ne1) (doc_nexp ctx ne2)) + | NC_set (kid, is) -> (* TODO: is this a good translation? *) + separate space [string "member_Z_list"; doc_var_lem ctx kid; + brackets (separate (string "; ") + (List.map (fun i -> string (Nat_big_num.to_string i)) is))] + | NC_or (nc1, nc2) -> doc_op (string "||") (doc_nc_exp ctx nc1) (doc_nc_exp ctx nc2) + | NC_and (nc1, nc2) -> doc_op (string "&&") (doc_nc_exp ctx nc1) (doc_nc_exp ctx nc2) + | NC_true -> string "true" + | NC_false -> string "false" + let maybe_expand_range_type (Typ_aux (typ,l) as full_typ) = match typ with | Typ_app(Id_aux (Id "range", _), [Typ_arg_aux(Typ_arg_nexp low,_); @@ -347,7 +370,7 @@ let maybe_expand_range_type (Typ_aux (typ,l) as full_typ) = let expand_range_type typ = Util.option_default typ (maybe_expand_range_type typ) let doc_arithfact ctxt nc = - string "ArithFact" ^^ space ^^ parens (doc_nc ctxt nc) + string "ArithFact" ^^ space ^^ parens (doc_nc_prop ctxt nc) (* When making changes here, check whether they affect lem_tyvars_of_typ *) let doc_typ, doc_atomic_typ = @@ -381,7 +404,7 @@ let doc_typ, doc_atomic_typ = let tpp = match elem_typ with | Typ_aux (Typ_id (Id_aux (Id "bit",_)),_) -> string "mword " ^^ doc_nexp ctx (nexp_simp m) - | _ -> string "list" ^^ space ^^ typ elem_typ in + | _ -> string "vec" ^^ space ^^ typ elem_typ ^^ space ^^ doc_nexp ctx (nexp_simp m) in if atyp_needed then parens tpp else tpp | Typ_app(Id_aux (Id "register", _), [Typ_arg_aux (Typ_arg_typ etyp, _)]) -> let tpp = string "register_ref regstate register_value " ^^ typ etyp in @@ -424,7 +447,7 @@ let doc_typ, doc_atomic_typ = List.fold_left add_tyvar tpp kids | None -> match nc with - | NC_aux (NC_true,_) -> List.fold_left add_tyvar (string "Z") (List.tl kids) +(* | NC_aux (NC_true,_) -> List.fold_left add_tyvar (string "Z") (List.tl kids)*) | _ -> List.fold_left add_tyvar (doc_arithfact ctx nc) kids end and doc_typ_arg (Typ_arg_aux(t,_)) = match t with @@ -537,9 +560,9 @@ let doc_typquant_items ctx delimit (TypQ_aux (tq,_)) = let doc_typquant_items_separate ctx delimit (TypQ_aux (tq,_)) = match tq with | TypQ_tq qis -> - separate_opt space (doc_quant_item_id ctx delimit) qis, - separate_opt space (doc_quant_item_constr ctx delimit) qis - | TypQ_no_forall -> empty, empty + Util.map_filter (doc_quant_item_id ctx delimit) qis, + Util.map_filter (doc_quant_item_constr ctx delimit) qis + | TypQ_no_forall -> [], [] let doc_typquant ctx (TypQ_aux(tq,_)) typ = match tq with | TypQ_tq ((_ :: _) as qs) -> @@ -687,14 +710,34 @@ let typ_id_of (Typ_aux (typ, l)) = match typ with | Typ_app (id, _) -> id | _ -> raise (Reporting_basic.err_unreachable l "failed to get type id") +(* TODO: maybe Nexp_exp, division? *) +(* Evaluation of constant nexp subexpressions, because Coq will be able to do those itself *) +let rec nexp_const_eval (Nexp_aux (n,l) as nexp) = + let binop f re l n1 n2 = + match nexp_const_eval n1, nexp_const_eval n2 with + | Nexp_aux (Nexp_constant c1,_), Nexp_aux (Nexp_constant c2,_) -> + Nexp_aux (Nexp_constant (f c1 c2),l) + | n1', n2' -> Nexp_aux (re n1' n2',l) + in + let unop f re l n1 = + match nexp_const_eval n1 with + | Nexp_aux (Nexp_constant c1,_) -> Nexp_aux (Nexp_constant (f c1),l) + | n1' -> Nexp_aux (re n1',l) + in + match n with + | Nexp_times (n1,n2) -> binop Big_int.mul (fun n1 n2 -> Nexp_times (n1,n2)) l n1 n2 + | Nexp_sum (n1,n2) -> binop Big_int.add (fun n1 n2 -> Nexp_sum (n1,n2)) l n1 n2 + | Nexp_minus (n1,n2) -> binop Big_int.sub (fun n1 n2 -> Nexp_minus (n1,n2)) l n1 n2 + | Nexp_neg n1 -> unop Big_int.negate (fun n -> Nexp_neg n) l n1 + | _ -> nexp + (* Decide whether two nexps used in a vector size are similar; if not a cast will be inserted *) -let similar_nexps n1 n2 = +let similar_nexps env n1 n2 = let rec same_nexp_shape (Nexp_aux (n1,_)) (Nexp_aux (n2,_)) = match n1, n2 with - | Nexp_id _, Nexp_id _ - | Nexp_var _, Nexp_var _ - -> true + | Nexp_id _, Nexp_id _ -> true + | Nexp_var k1, Nexp_var k2 -> prove env (nc_eq (nvar k1) (nvar k2)) | Nexp_constant c1, Nexp_constant c2 -> Nat_big_num.equal c1 c2 | Nexp_app (f1,args1), Nexp_app (f2,args2) -> Id.compare f1 f2 == 0 && List.for_all2 same_nexp_shape args1 args2 @@ -706,7 +749,48 @@ let similar_nexps n1 n2 = | Nexp_neg n1, Nexp_neg n2 -> same_nexp_shape n1 n2 | _ -> false - in if same_nexp_shape n1 n2 then true else false + in if same_nexp_shape (nexp_const_eval n1) (nexp_const_eval n2) then true else false + +let constraint_fns = ["Z.leb"; "Z.geb"; "Z.ltb"; "Z.gtb"; "Z.eqb"; "neq_atom"] + +let condition_produces_constraint exp = + (* Cheat a little - this isn't quite the right environment for subexpressions + but will have all of the relevant functions in it. *) + let env = env_of exp in + Rewriter.fold_exp + { (Rewriter.pure_exp_alg false (||)) with + Rewriter.e_app = fun (f,bs) -> + List.exists (fun x -> x) bs || + (let name = if Env.is_extern f env "coq" + then Env.get_extern f env "coq" + else string_id f in + List.exists (fun id -> String.compare name id == 0) constraint_fns) + } exp + +(* For most functions whose return types are non-trivial atoms we return a + dependent pair with a proof that the result is the expected integer. This + is redundant for basic arithmetic functions and functions which we unfold + in the constraint solver. *) +let no_Z_proof_fns = ["Z.add"; "Z.sub"; "Z.opp"; "Z.mul"; "length_mword"; "length"] + +let is_no_Z_proof_fn env id = + if Env.is_extern id env "coq" + then + let s = Env.get_extern id env "coq" in + List.exists (fun x -> String.compare x s == 0) no_Z_proof_fns + else false + +let replace_atom_return_type ret_typ = + (* TODO: more complex uses of atom *) + match ret_typ with + | Typ_aux (Typ_app (Id_aux (Id "atom",_), [Typ_arg_aux (Typ_arg_nexp nexp,_)]),l) -> + let kid = mk_kid "_retval" in (* TODO: collision avoidance *) + true, Typ_aux (Typ_exist ([kid], nc_eq (nvar kid) nexp, atom_typ (nvar kid)),Generated l) + | Typ_aux (Typ_id (Id_aux (Id "nat",_)),l) -> + let kid = mk_kid "_retval" in + true, Typ_aux (Typ_exist ([kid], nc_gteq (nvar kid) (nconstant Nat_big_num.zero), atom_typ (nvar kid)),Generated l) + | _ -> false, ret_typ + let prefix_recordtype = true let report = Reporting_basic.err_unreachable @@ -815,7 +899,7 @@ let doc_exp_lem, doc_let_lem = | Id_aux (Id "foreach", _) -> begin match args with - | [exp1; exp2; exp3; ord_exp; vartuple; body] -> + | [from_exp; to_exp; step_exp; ord_exp; vartuple; body] -> let loopvar, body = match body with | E_aux (E_let (LB_aux (LB_val (_, _), _), E_aux (E_let (LB_aux (LB_val (_, _), _), @@ -826,13 +910,13 @@ let doc_exp_lem, doc_let_lem = | (P_aux (P_id id, _))), _), _), body), _), _), _)), _)), _) -> id, body | _ -> raise (Reporting_basic.err_unreachable l ("Unable to find loop variable in " ^ string_of_exp body)) in - let step = match ord_exp with - | E_aux (E_lit (L_aux (L_false, _)), _) -> - parens (separate space [string "integerNegate"; expY exp3]) - | _ -> expY exp3 + let dir = match ord_exp with + | E_aux (E_lit (L_aux (L_false, _)), _) -> "_down" + | E_aux (E_lit (L_aux (L_true, _)), _) -> "_up" + | _ -> raise (Reporting_basic.err_unreachable l ("Unexpected loop direction " ^ string_of_exp ord_exp)) in - let combinator = if effectful (effect_of body) then "foreachM" else "foreach" in - let indices_pp = parens (separate space [string "index_list"; expY exp1; expY exp2; step]) in + let combinator = if effectful (effect_of body) then "foreach_ZM" else "foreach_Z" in + let combinator = combinator ^ dir in let used_vars_body = find_e_ids body in let body_lambda = (* Work around indentation issues in Lem when translating @@ -840,18 +924,20 @@ let doc_exp_lem, doc_let_lem = match fst (uncast_exp vartuple) with | E_aux (E_tuple _, _) when not (IdSet.mem (mk_id "varstup") used_vars_body)-> - separate space [string "fun"; doc_id loopvar; string "varstup"; bigarrow] + separate space [string "fun"; doc_id loopvar; string "_"; string "varstup"; bigarrow] ^^ break 1 ^^ - separate space [string "let"; expY vartuple; string ":= varstup in"] + separate space [string "let"; squote ^^ expY vartuple; string ":= varstup in"] | E_aux (E_lit (L_aux (L_unit, _)), _) when not (IdSet.mem (mk_id "unit_var") used_vars_body) -> - separate space [string "fun"; doc_id loopvar; string "unit_var"; bigarrow] + separate space [string "fun"; doc_id loopvar; string "_"; string "unit_var"; bigarrow] | _ -> - separate space [string "fun"; doc_id loopvar; expY vartuple; bigarrow] + separate space [string "fun"; doc_id loopvar; string "_"; expY vartuple; bigarrow] in parens ( (prefix 2 1) - ((separate space) [string combinator; indices_pp; expY vartuple]) + ((separate space) [string combinator; + expY from_exp; expY to_exp; expY step_exp; + expY vartuple]) (parens (prefix 2 1 (group body_lambda) (expN body)) ) @@ -879,7 +965,7 @@ let doc_exp_lem, doc_let_lem = | E_aux (E_tuple _, _) when not (IdSet.mem (mk_id "varstup") used_vars_body)-> separate space [string "fun varstup"; bigarrow] ^^ break 1 ^^ - separate space [string "let"; expY varstuple; string ":= varstup in"] + separate space [string "let"; squote ^^ expY varstuple; string ":= varstup in"] | E_aux (E_lit (L_aux (L_unit, _)), _) when not (IdSet.mem (mk_id "unit_var") used_vars_body) -> separate space [string "fun unit_var"; bigarrow] @@ -946,17 +1032,35 @@ let doc_exp_lem, doc_let_lem = (* TODO: more sophisticated check *) match destruct_exist env arg_ty, destruct_exist env typ_from_fn with | Some _, None -> parens (string "projT1 " ^^ arg_pp) + (* Usually existentials have already been built elsewhere, but this + is useful for (e.g.) ranges *) + | None, Some _ -> parens (string "build_ex " ^^ arg_pp) | _, _ -> arg_pp in let epp = hang 2 (flow (break 1) (call :: List.map2 doc_arg args arg_typs)) in - (* Unpack existential result *) - let inst = instantiation_of full_exp in + (* Decide whether to unpack an existential result, pack one, or cast. + To do this we compare the expected type stored in the checked expression + with the inferred type. *) + let inst = + match instantiation_of_without_type full_exp with + | x -> x + (* Not all function applications can be inferred, so try falling back to the + type inferred when we know the target type. + TODO: there are probably some edge cases where this won't pick up a need + to cast. *) + | exception _ -> instantiation_of full_exp + in let inst = KBindings.fold (fun k u m -> KBindings.add (orig_kid k) u m) inst KBindings.empty in - let ret_typ_inst = subst_unifiers inst ret_typ in + let ret_typ_inst = + subst_unifiers inst ret_typ + in let unpack,build_ex,autocast = let ann_typ = Env.expand_synonyms env (typ_of_annot (l,annot)) in let ann_typ = expand_range_type ann_typ in let ret_typ_inst = expand_range_type (Env.expand_synonyms env ret_typ_inst) in + let ret_typ_inst = + if is_no_Z_proof_fn env f then ret_typ_inst + else snd (replace_atom_return_type ret_typ_inst) in let unpack, build_ex, in_typ, out_typ = match ret_typ_inst, ann_typ with | Typ_aux (Typ_exist (_,_,t1),_), Typ_aux (Typ_exist (_,_,t2),_) -> @@ -968,10 +1072,12 @@ let doc_exp_lem, doc_let_lem = | t1, t2 -> false,false,t1,t2 in let autocast = - match destruct_vector env in_typ, destruct_vector env out_typ with - | Some (n1,_,t1), Some (n2,_,t2) - when is_bit_typ t1 && is_bit_typ t2 -> - not (similar_nexps n1 n2) + (* Avoid using helper functions which simplify the nexps *) + is_bitvector_typ in_typ && is_bitvector_typ out_typ && + match in_typ, out_typ with + | Typ_aux (Typ_app (_,[Typ_arg_aux (Typ_arg_nexp n1,_);_;_]),_), + Typ_aux (Typ_app (_,[Typ_arg_aux (Typ_arg_nexp n2,_);_;_]),_) -> + not (similar_nexps env n1 n2) | _ -> false in unpack,build_ex,autocast in @@ -1048,13 +1154,33 @@ let doc_exp_lem, doc_let_lem = (doc_fexp ctxt recordtyp) fexps)) in if aexp_needed then parens epp else epp | E_record_update(e,(FES_aux(FES_Fexps(fexps,_),_))) -> - let recordtyp = match annot with + let recordtyp, env = match annot with | Some (env, Typ_aux (Typ_id tid,_), _) | Some (env, Typ_aux (Typ_app (tid, _), _), _) when Env.is_record tid env -> - tid + tid, env | _ -> raise (report l ("cannot get record type from annot " ^ string_of_annot annot ^ " of exp " ^ string_of_exp full_exp)) in - enclose_record_update (doc_op (string "with") (expY e) (separate_map semi_sp (doc_fexp ctxt recordtyp) fexps)) + if List.length fexps > 1 then + let _,fields = Env.get_record recordtyp env in + let var, let_pp = + match e with + | E_aux (E_id id,_) -> id, empty + | _ -> let v = mk_id "_record" in (* TODO: collision avoid *) + v, separate space [string "let "; doc_id v; coloneq; top_exp ctxt true e; string "in"] ^^ break 1 + in + let doc_field (_,id) = + match List.find (fun (FE_aux (FE_Fexp (id',_),_)) -> Id.compare id id' == 0) fexps with + | fexp -> doc_fexp ctxt recordtyp fexp + | exception Not_found -> + let fname = + if prefix_recordtype && string_of_id recordtyp <> "regstate" + then (string (string_of_id recordtyp ^ "_")) ^^ doc_id id + else doc_id id in + doc_op coloneq fname (doc_id var ^^ dot ^^ parens fname) + in let_pp ^^ enclose_record (align (separate_map (semi_sp ^^ break 1) + doc_field fields)) + else + enclose_record_update (doc_op (string "with") (expY e) (separate_map semi_sp (doc_fexp ctxt recordtyp) fexps)) | E_vector exps -> let t = Env.base_typ_of (env_of full_exp) (typ_of full_exp) in let start, (len, order, etyp) = @@ -1079,7 +1205,9 @@ let doc_exp_lem, doc_let_lem = if is_bit_typ etyp then let bepp = string "vec_of_bits" ^^ space ^^ align epp in (align (group (prefix 0 1 bepp (doc_tannot_lem ctxt (env_of full_exp) false t))), true) - else (epp,aexp_needed) in + else + let vepp = string "vec_of_list_len" ^^ space ^^ align epp in + (vepp,aexp_needed) in if aexp_needed then parens (align epp) else epp | E_vector_update(v,e1,e2) -> raise (Reporting_basic.err_unreachable l @@ -1100,7 +1228,8 @@ let doc_exp_lem, doc_let_lem = if effectful (effect_of e) then let try_catch = if ctxt.early_ret then "try_catchR" else "try_catch" in let epp = - group ((separate space [string try_catch; expY e; string "(function "]) ^/^ + (* TODO capture avoidance for __catch_val *) + group ((separate space [string try_catch; expY e; string "(fun __catch_val => match __catch_val with "]) ^/^ (separate_map (break 1) (doc_case ctxt exc_typ) pexps) ^/^ (string "end)")) in if aexp_needed then parens (align epp) else align epp @@ -1119,24 +1248,37 @@ let doc_exp_lem, doc_let_lem = | E_var(lexp, eq_exp, in_exp) -> raise (report l "E_vars should have been removed before pretty-printing") | E_internal_plet (pat,e1,e2) -> - let epp = - let b = match e1 with E_aux (E_if _,_) -> true | _ -> false in - let middle = - match fst (untyp_pat pat) with - | P_aux (P_wild,_) | P_aux (P_typ (_, P_aux (P_wild, _)), _) -> - string ">>" - | P_aux (P_id id,_) -> - separate space [string ">>= fun"; doc_id id; bigarrow] - | P_aux (P_typ (typ, P_aux (P_id id,_)),_) -> - separate space [string ">>= fun"; doc_id id; colon; doc_typ ctxt typ; bigarrow] - | _ -> - separate space [string ">>= fun"; squote ^^ doc_pat ctxt true (pat, typ_of e1); bigarrow] - in - infix 0 1 middle (expV b e1) (expN e2) - in - if aexp_needed then parens (align epp) else epp + begin + match pat, e1 with + | (P_aux (P_wild,_) | P_aux (P_typ (_, P_aux (P_wild, _)), _)), + (E_aux (E_assert (assert_e1,assert_e2),_)) -> + let epp = liftR (separate space [string "assert_exp'"; expY assert_e1; expY assert_e2]) in + let epp = infix 0 1 (string ">>= fun _ =>") epp (expN e2) in + if aexp_needed then parens (align epp) else align epp + | _ -> + let epp = + let b = match e1 with E_aux (E_if _,_) -> true | _ -> false in + let middle = + match fst (untyp_pat pat) with + | P_aux (P_wild,_) | P_aux (P_typ (_, P_aux (P_wild, _)), _) -> + string ">>" + | P_aux (P_id id,_) -> + separate space [string ">>= fun"; doc_id id; bigarrow] + | P_aux (P_typ (typ, P_aux (P_id id,_)),_) -> + separate space [string ">>= fun"; doc_id id; colon; doc_typ ctxt typ; bigarrow] + | _ -> + separate space [string ">>= fun"; squote ^^ doc_pat ctxt true (pat, typ_of e1); bigarrow] + in + infix 0 1 middle (expV b e1) (expN e2) + in + if aexp_needed then parens (align epp) else epp + end | E_internal_return (e1) -> - wrap_parens (align (separate space [string "returnm"; expY e1])) + let e1pp = expY e1 in + let valpp = if ctxt.build_ex_return + then parens (string "build_ex" ^^ space ^^ e1pp) + else e1pp in + wrap_parens (align (separate space [string "returnm"; valpp])) | E_sizeof nexp -> (match nexp_simp nexp with | Nexp_aux (Nexp_constant i, _) -> doc_lit (L_aux (L_num i, l)) @@ -1153,7 +1295,7 @@ let doc_exp_lem, doc_let_lem = parens (doc_typ ctxt (typ_of full_exp)); parens (doc_typ ctxt (typ_of r))] in align (parens (string "early_return" ^//^ expV true r ^//^ ta)) - | E_constraint _ -> string "true" + | E_constraint nc -> wrap_parens (doc_nc_exp ctxt nc) | E_comment _ | E_comment_struc _ -> empty | E_internal_cast _ | E_internal_exp _ | E_sizeof_internal _ | E_internal_exp_user _ | E_internal_value _ -> @@ -1168,7 +1310,9 @@ let doc_exp_lem, doc_let_lem = | _ -> prefix 2 1 (string "else") (top_exp ctxt false e) in (prefix 2 1 - (soft_surround 2 1 if_pp (string "sumbool_of_bool" ^^ space ^^ parens (top_exp ctxt true c)) (string "then")) + (soft_surround 2 1 if_pp + ((if condition_produces_constraint c then string "sumbool_of_bool" ^^ space else empty) + ^^ parens (top_exp ctxt true c)) (string "then")) (top_exp ctxt false t)) ^^ break 1 ^^ else_pp @@ -1404,17 +1548,28 @@ let demote_as_pattern i (P_aux (_,p_annot) as pat,typ) = E_aux (E_let (LB_aux (LB_val (pat, E_aux (E_id id, p_annot)),p_annot),e),e_ann) else (pat,typ), fun e -> e -(* Ideally we'd remove the duplication between type variables and atom - arguments, but for now we just add an equality constraint. *) +(* Add equality constraints between arguments and nexps, except in the case + that they've been merged. *) -let atom_constraint ctxt (pat, typ) = +let rec atom_constraint ctxt (pat, typ) = let typ = Env.base_typ_of (pat_env_of pat) typ in match pat, typ with | P_aux (P_id id, _), Typ_aux (Typ_app (Id_aux (Id "atom",_), - [Typ_arg_aux (Typ_arg_nexp (Nexp_aux (Nexp_var kid,_)),_)]),_) -> + [Typ_arg_aux (Typ_arg_nexp nexp,_)]),_) -> + (match nexp with + (* When the kid is mapped to the id, we don't need a constraint *) + | Nexp_aux (Nexp_var kid,_) + when (try Id.compare (KBindings.find kid ctxt.kid_id_renames) id == 0 with _ -> false) -> + None + | _ -> + Some (bquote ^^ braces (string "ArithFact" ^^ space ^^ + parens (doc_op equals (doc_id id) (doc_nexp ctxt nexp))))) + | P_aux (P_id id, _), + Typ_aux (Typ_id (Id_aux (Id "nat",_)),_) -> Some (bquote ^^ braces (string "ArithFact" ^^ space ^^ - parens (doc_op equals (doc_id id) (doc_var_lem ctxt kid)))) + parens (doc_op (string ">=") (doc_id id) (string "0")))) + | P_aux (P_typ (_,p),_), _ -> atom_constraint ctxt (p, typ) | _ -> None let all_ids pexp = @@ -1485,14 +1640,13 @@ let merge_kids_atoms pats = let gone,map,_ = List.fold_left try_eliminate (KidSet.empty, KBindings.empty, KidSet.empty) pats in gone,map -let doc_binder ctxt (P_aux (p,ann) as pat, typ) = - let env = env_of_annot ann in - let exp_typ = Env.expand_synonyms env typ in - match p with - | P_id id - | P_typ (_,P_aux (P_id id,_)) when Util.is_none (is_auto_decomposed_exist env exp_typ) -> - parens (separate space [doc_id id; colon; doc_typ ctxt typ]) - | _ -> squote ^^ parens (separate space [doc_pat ctxt true (pat, exp_typ); colon; doc_typ ctxt typ]) +let merge_var_patterns map pats = + let map,pats = List.fold_left (fun (map,pats) (pat, typ) -> + match pat with + | P_aux (P_var (P_aux (P_id id,_), TP_aux (TP_var kid,_)),ann) -> + KBindings.add kid id map, (P_aux (P_id id,ann), typ) :: pats + | _ -> map, (pat,typ)::pats) (map,[]) pats + in map, List.rev pats let doc_funcl (FCL_aux(FCL_Funcl(id, pexp), annot)) = let (tq,typ) = Env.get_val_spec_orig id (env_of_annot annot) in @@ -1500,38 +1654,70 @@ let doc_funcl (FCL_aux(FCL_Funcl(id, pexp), annot)) = | Typ_aux (Typ_fn (arg_typ, ret_typ, eff),_) -> arg_typ, ret_typ, eff | _ -> failwith ("Function " ^ string_of_id id ^ " does not have function type") in + let build_ex, ret_typ = replace_atom_return_type ret_typ in let ids_to_avoid = all_ids pexp in let kids_used = tyvars_of_typquant tq in let pat,guard,exp,(l,_) = destruct_pexp pexp in let pats, bind = untuple_args_pat arg_typ pat in let pats, binds = List.split (Util.list_mapi demote_as_pattern pats) in let eliminated_kids, kid_to_arg_rename = merge_kids_atoms pats in + let kid_to_arg_rename, pats = merge_var_patterns kid_to_arg_rename pats in let kids_used = KidSet.diff kids_used eliminated_kids in let ctxt = { early_ret = contains_early_return exp; kid_renames = mk_kid_renames ids_to_avoid kids_used; kid_id_renames = kid_to_arg_rename; - bound_nexps = NexpSet.union (lem_nexps_of_typ typ) (typeclass_nexps typ) } in + bound_nexps = NexpSet.union (lem_nexps_of_typ typ) (typeclass_nexps typ); + build_ex_return = effectful eff && build_ex; + } in (* Put the constraints after pattern matching so that any type variable that's been replaced by one of the term-level arguments is bound. *) let quantspp, constrspp = doc_typquant_items_separate ctxt braces tq in let exp = List.fold_left (fun body f -> f body) (bind exp) binds in - let patspp = separate_map space (doc_binder ctxt) pats in - let atom_constr_pp = separate_opt space (atom_constraint ctxt) pats in + let used_a_pattern = ref false in + let doc_binder (P_aux (p,ann) as pat, typ) = + let env = env_of_annot ann in + let exp_typ = Env.expand_synonyms env typ in + match p with + | P_id id + | P_typ (_,P_aux (P_id id,_)) when Util.is_none (is_auto_decomposed_exist env exp_typ) -> + parens (separate space [doc_id id; colon; doc_typ ctxt typ]) + | _ -> + (used_a_pattern := true; + squote ^^ parens (separate space [doc_pat ctxt true (pat, exp_typ); colon; doc_typ ctxt typ])) + in + let patspp = separate_map space doc_binder pats in + let atom_constrs = Util.map_filter (atom_constraint ctxt) pats in + let atom_constr_pp = separate space atom_constrs in let retpp = if effectful eff then string "M" ^^ space ^^ parens (doc_typ ctxt ret_typ) else doc_typ ctxt ret_typ in + let idpp = doc_id id in + (* Work around Coq bug 7975 about pattern binders followed by implicit arguments *) + let implicitargs = + if !used_a_pattern && List.length constrspp + List.length atom_constrs > 0 then + break 1 ^^ separate space + ([string "Arguments"; idpp;] @ + List.map (fun _ -> string "{_}") quantspp @ + List.map (fun _ -> string "_") pats @ + List.map (fun _ -> string "{_}") constrspp @ + List.map (fun _ -> string "{_}") atom_constrs) + ^^ dot + else empty + in let _ = match guard with | None -> () | _ -> raise (Reporting_basic.err_unreachable l "guarded pattern expression should have been rewritten before pretty-printing") in + let bodypp = doc_fun_body ctxt exp in + let bodypp = if effectful eff || not build_ex then bodypp else string "build_ex" ^^ parens bodypp in group (prefix 3 1 - (separate space [doc_id id; quantspp; patspp; constrspp; atom_constr_pp] ^/^ - colon ^^ space ^^ retpp ^^ coloneq) - (doc_fun_body ctxt exp ^^ dot)) + (separate space ([idpp] @ quantspp @ [patspp] @ constrspp @ [atom_constr_pp]) ^/^ + separate space [colon; retpp; coloneq]) + (bodypp ^^ dot)) ^^ implicitargs let get_id = function | [] -> failwith "FD_function with empty list" @@ -1658,9 +1844,16 @@ let doc_axiom_typschm typ_env (TypSchm_aux (TypSchm_ts (tqs,typ),l) as ts) = | _ -> parens (underscore ^^ string " : " ^^ doc_typ empty_ctxt typ) in let arg_typs_pp = separate space (List.map doc_typ' typs) in + let _, ret_ty = replace_atom_return_type ret_ty in let ret_typ_pp = doc_typ empty_ctxt ret_ty in + let ret_typ_pp = + if effectful eff + then string "M" ^^ space ^^ parens ret_typ_pp + else ret_typ_pp + in let tyvars_pp, constrs_pp = doc_typquant_items_separate empty_ctxt braces tqs in - string "forall" ^/^ tyvars_pp ^/^ arg_typs_pp ^/^ constrs_pp ^^ comma ^/^ ret_typ_pp + string "forall" ^/^ separate space tyvars_pp ^/^ + arg_typs_pp ^/^ separate space constrs_pp ^^ comma ^/^ ret_typ_pp | _ -> doc_typschm empty_ctxt true ts let doc_val_spec unimplemented (VS_aux (VS_val_spec(tys,id,_,_),ann)) = @@ -1800,6 +1993,8 @@ try (fun lib -> separate space [string "Require Import";string lib] ^^ dot) defs_modules;hardline; string "Import ListNotations."; hardline; + string "Open Scope string."; hardline; + string "Open Scope bool."; hardline; (* Put the body into a Section so that we can define some values with Let to put them into the local context, where tactics can see them *) string "Section Content."; diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml index c872d420..9897bb7c 100644 --- a/src/pretty_print_lem.ml +++ b/src/pretty_print_lem.ml @@ -334,14 +334,14 @@ let replace_typ_size ctxt env (Typ_aux (t,a)) = let mk_typ nexp = Some (Typ_aux (Typ_app (id, [Typ_arg_aux (Typ_arg_nexp nexp,Parse_ast.Unknown);ord;typ']),a)) in - match Type_check.solve env size with - | Some n -> mk_typ (nconstant n) - | None -> - let is_equal nexp = - prove env (NC_aux (NC_equal (size,nexp),Parse_ast.Unknown)) - in match List.find is_equal (NexpSet.elements ctxt.bound_nexps) with - | nexp -> mk_typ nexp - | exception Not_found -> None + let is_equal nexp = + prove env (NC_aux (NC_equal (size,nexp),Parse_ast.Unknown)) + in match List.find is_equal (NexpSet.elements ctxt.bound_nexps) with + | nexp -> mk_typ nexp + | exception Not_found -> + match Type_check.solve env size with + | Some n -> mk_typ (nconstant n) + | None -> None end | _ -> None diff --git a/src/process_file.ml b/src/process_file.ml index 9603e986..c3e1b510 100644 --- a/src/process_file.ml +++ b/src/process_file.ml @@ -398,12 +398,12 @@ let rewrite rewriters defs = let rewrite_ast = rewrite [("initial", Rewriter.rewrite_defs)] let rewrite_undefined bitvectors = rewrite [("undefined", fun x -> Rewrites.rewrite_undefined bitvectors x)] let rewrite_ast_lem = rewrite Rewrites.rewrite_defs_lem -let rewrite_ast_coq = rewrite Rewrites.rewrite_defs_lem +let rewrite_ast_coq = rewrite Rewrites.rewrite_defs_coq let rewrite_ast_ocaml = rewrite Rewrites.rewrite_defs_ocaml let rewrite_ast_c ast = ast |> rewrite Rewrites.rewrite_defs_c - |> Constant_fold.rewrite_constant_function_calls + |> rewrite [("constant_fold", Constant_fold.rewrite_constant_function_calls)] let rewrite_ast_interpreter = rewrite Rewrites.rewrite_defs_interpreter let rewrite_ast_check = rewrite Rewrites.rewrite_defs_check diff --git a/src/rewrites.ml b/src/rewrites.ml index 214ca571..246a2670 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -1168,6 +1168,25 @@ let case_exp e t cs = (* let efr = union_effs (List.map effect_of_pexp ps) in *) fix_eff_exp (annot_exp (E_case (e,ps)) l env t) +(* Rewrite guarded patterns into a combination of if-expressions and + unguarded pattern matches + + Strategy: + - Split clauses into groups where the first pattern subsumes all the + following ones + - Translate the groups in reverse order, using the next group as a + fall-through target, if there is one + - Within a group, + - translate the sequence of clauses to an if-then-else cascade using the + guards as long as the patterns are equivalent modulo substitution, or + - recursively translate the remaining clauses to a pattern match if + there is a difference in the patterns. + + TODO: Compare this more closely with the algorithm in the CPP'18 paper of + Spector-Zabusky et al, who seem to use the opposite grouping and merging + strategy to ours: group *mutually exclusive* clauses, and try to merge them + into a pattern match first instead of an if-then-else cascade. +*) let rewrite_guarded_clauses l cs = let rec group fallthrough clauses = let add_clause (pat,cls,annot) c = (pat,cls @ [c],annot) in @@ -2285,6 +2304,7 @@ let rewrite_type_def_typs rw_typ rw_typquant rw_typschm (TD_aux (td, annot)) = let rewrite_dec_spec_typs rw_typ (DEC_aux (ds, annot)) = match ds with | DEC_reg (typ, id) -> DEC_aux (DEC_reg (rw_typ typ, id), annot) + | DEC_config (id, typ, exp) -> DEC_aux (DEC_config (id, rw_typ typ, exp), annot) | _ -> assert false (* Remove overload definitions and cast val specs from the @@ -2606,7 +2626,7 @@ let rewrite_defs_letbind_effects = match lexp_aux with | LEXP_id _ -> k lexp | LEXP_deref exp -> - n_exp exp (fun exp -> + n_exp_name exp (fun exp -> k (fix_eff_lexp (LEXP_aux (LEXP_deref exp, annot)))) | LEXP_memory (id,es) -> n_exp_nameL es (fun es -> @@ -3305,8 +3325,15 @@ let rewrite_defs_mapping_patterns = in pexp_rewriters rewrite_pexp +let rewrite_lit_lem (L_aux (lit, _)) = match lit with + | L_num _ | L_string _ | L_hex _ | L_bin _ | L_real _ -> true + | _ -> false + +let rewrite_no_strings (L_aux (lit, _)) = match lit with + | L_string _ -> false + | _ -> true -let rewrite_defs_pat_lits = +let rewrite_defs_pat_lits rewrite_lit = let rewrite_pexp (Pat_aux (pexp_aux, annot) as pexp) = let guards = ref [] in let counter = ref 0 in @@ -3314,7 +3341,7 @@ let rewrite_defs_pat_lits = let rewrite_pat = function (* HACK: ignore strings for now *) | P_lit (L_aux (L_string _, _)) as p_aux, p_annot -> P_aux (p_aux, p_annot) - | P_lit lit, p_annot -> + | P_lit lit, p_annot when rewrite_lit lit -> let env = env_of_annot p_annot in let typ = typ_of_annot p_annot in let id = mk_id ("p" ^ string_of_int !counter ^ "#") in @@ -3513,11 +3540,12 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = let eff = union_eff_exps [c;e1;e2] in let v = E_aux (E_if (c,e1,e2), (gen_loc el, Some (env, typ, eff))) in Added_vars (v, tuple_pat (if overwrite then varpats else pat :: varpats)) - | E_case (e1,ps) -> - (* after rewrite_defs_letbind_effects e1 needs no rewriting *) + | E_case (e1,ps) | E_try (e1, ps) -> + let is_case = match expaux with E_case _ -> true | _ -> false in let vars, varpats = - ps - |> List.map (fun (Pat_aux ((Pat_exp (_,e)|Pat_when (_,_,e)),_)) -> e) + (* for E_case, e1 needs no rewriting after rewrite_defs_letbind_effects *) + ((if is_case then [] else [e1]) @ + List.map (fun (Pat_aux ((Pat_exp (_,e)|Pat_when (_,_,e)),_)) -> e) ps) |> List.map find_updated_vars |> List.fold_left IdSet.union IdSet.empty |> IdSet.inter used_vars @@ -3528,8 +3556,10 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = Pat_aux (Pat_exp (p,rewrite_var_updates e),a) | Pat_aux (Pat_when (p,g,e),a) -> Pat_aux (Pat_when (p,g,rewrite_var_updates e),a)) ps in - Same_vars (E_aux (E_case (e1,ps),annot)) + let expaux = if is_case then E_case (e1, ps) else E_try (e1, ps) in + Same_vars (E_aux (expaux, annot)) else + let e1 = if is_case then e1 else rewrite_var_updates (add_vars overwrite e1 vars) in let rewrite_pexp (Pat_aux (pexp, (l, _))) = match pexp with | Pat_exp (pat, exp) -> let exp = rewrite_var_updates (add_vars overwrite exp vars) in @@ -3538,10 +3568,12 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) = | Pat_when _ -> raise (Reporting_basic.err_unreachable l "Guarded patterns should have been rewritten already") in + let ps = List.map rewrite_pexp ps in + let expaux = if is_case then E_case (e1, ps) else E_try (e1, ps) in let typ = match ps with | Pat_aux ((Pat_exp (_,first)|Pat_when (_,_,first)),_) :: _ -> typ_of first | _ -> unit_typ in - let v = fix_eff_exp (annot_exp (E_case (e1, List.map rewrite_pexp ps)) pl env typ) in + let v = fix_eff_exp (annot_exp expaux pl env typ) in Added_vars (v, tuple_pat (if overwrite then varpats else pat :: varpats)) | E_assign (lexp,vexp) -> let mk_id_pat id = match Env.lookup_id id env with @@ -3969,6 +4001,263 @@ let rewrite_defs_realise_mappings (Defs defs) = in Defs (List.map rewrite_def defs |> List.flatten) + +(* Rewrite to make all pattern matches in Coq output exhaustive. + Assumes that guards, vector patterns, etc have been rewritten already. *) + +let opt_coq_warn_nonexhaustive = ref false + +module MakeExhaustive = +struct + +type rlit = + | RL_unit + | RL_zero + | RL_one + | RL_true + | RL_false + | RL_inf + +let string_of_rlit = function + | RL_unit -> "()" + | RL_zero -> "bitzero" + | RL_one -> "bitone" + | RL_true -> "true" + | RL_false -> "false" + | RL_inf -> "..." + +let rlit_of_lit (L_aux (l,_)) = + match l with + | L_unit -> RL_unit + | L_zero -> RL_zero + | L_one -> RL_one + | L_true -> RL_true + | L_false -> RL_false + | L_num _ | L_hex _ | L_bin _ | L_string _ | L_real _ -> RL_inf + | L_undef -> assert false + +let inv_rlit_of_lit (L_aux (l,_)) = + match l with + | L_unit -> [] + | L_zero -> [RL_one] + | L_one -> [RL_zero] + | L_true -> [RL_false] + | L_false -> [RL_true] + | L_num _ | L_hex _ | L_bin _ | L_string _ | L_real _ -> [RL_inf] + | L_undef -> assert false + +type residual_pattern = + | RP_any + | RP_lit of rlit + | RP_enum of id + | RP_app of id * residual_pattern list + | RP_tup of residual_pattern list + | RP_nil + | RP_cons of residual_pattern * residual_pattern + +let rec string_of_rp = function + | RP_any -> "_" + | RP_lit rlit -> string_of_rlit rlit + | RP_enum id -> string_of_id id + | RP_app (f,args) -> string_of_id f ^ "(" ^ String.concat "," (List.map string_of_rp args) ^ ")" + | RP_tup rps -> "(" ^ String.concat "," (List.map string_of_rp rps) ^ ")" + | RP_nil -> "[| |]" + | RP_cons (rp1,rp2) -> string_of_rp rp1 ^ "::" ^ string_of_rp rp2 + +type ctx = { + env : Env.t; + enum_to_rest: (residual_pattern list) Bindings.t; + constructor_to_rest: (residual_pattern list) Bindings.t +} + +let make_enum_mappings ids m = + IdSet.fold (fun id m -> + Bindings.add id + (List.map (fun e -> RP_enum e) (IdSet.elements (IdSet.remove id ids))) m) + ids + m + +let make_cstr_mappings env ids m = + let ids = IdSet.elements ids in + let constructors = List.map + (fun id -> + let _,ty = Env.get_val_spec id env in + let args = match ty with + | Typ_aux (Typ_fn (Typ_aux (Typ_tup tys,_),_,_),_) -> List.map (fun _ -> RP_any) tys + | _ -> [RP_any] + in RP_app (id,args)) ids in + let rec aux ids acc l = + match ids, l with + | [], [] -> m + | id::ids, rp::t -> + let m = aux ids (acc@[rp]) t in + Bindings.add id (acc@t) m + | _ -> assert false + in aux ids [] constructors + +let ctx_from_pattern_completeness_ctx env = + let ctx = Env.pattern_completeness_ctx env in + { env = env; + enum_to_rest = Bindings.fold (fun _ ids m -> make_enum_mappings ids m) + ctx.Pattern_completeness.enums Bindings.empty; + constructor_to_rest = Bindings.fold (fun _ ids m -> make_cstr_mappings env ids m) + ctx.Pattern_completeness.variants Bindings.empty + } + +let printprefix = ref " " + +let rec remove_clause_from_pattern ctx (P_aux (rm_pat,ann)) res_pat = + let subpats rm_pats res_pats = + let res_pats' = List.map2 (remove_clause_from_pattern ctx) rm_pats res_pats in + let rec aux acc fixed residual = + match fixed, residual with + | [], [] -> [] + | (fh::ft), (rh::rt) -> + let rt' = aux (acc@[fh]) ft rt in + let newr = List.map (fun x -> acc @ (x::ft)) rh in + newr @ rt' + | _,_ -> assert false (* impossible because we managed map2 above *) + in aux [] res_pats res_pats' + in + let inconsistent () = + raise (Reporting_basic.err_unreachable (fst ann) + ("Inconsistency during exhaustiveness analysis with " ^ + string_of_rp res_pat)) + in + (*let _ = print_endline (!printprefix ^ "pat: " ^string_of_pat (P_aux (rm_pat,ann))) in + let _ = print_endline (!printprefix ^ "res_pat: " ^string_of_rp res_pat) in + let _ = printprefix := " " ^ !printprefix in*) + let rp' = + match rm_pat with + | P_wild -> [] + | P_id id when (match Env.lookup_id id ctx.env with Unbound | Local _ -> true | _ -> false) -> [] + | P_lit lit -> + (match res_pat with + | RP_any -> List.map (fun l -> RP_lit l) (inv_rlit_of_lit lit) + | RP_lit RL_inf -> [res_pat] + | RP_lit lit' -> if lit' = rlit_of_lit lit then [] else [res_pat] + | _ -> inconsistent ()) + | P_as (p,_) + | P_typ (_,p) + | P_var (p,_) + -> remove_clause_from_pattern ctx p res_pat + | P_id id -> + (match Env.lookup_id id ctx.env with + | Enum enum -> + (match res_pat with + | RP_any -> Bindings.find id ctx.enum_to_rest + | RP_enum id' -> if Id.compare id id' == 0 then [] else [res_pat] + | _ -> inconsistent ()) + | _ -> assert false) + | P_tup rm_pats -> + let previous_res_pats = + match res_pat with + | RP_tup res_pats -> res_pats + | RP_any -> List.map (fun _ -> RP_any) rm_pats + | _ -> inconsistent () + in + let res_pats' = subpats rm_pats previous_res_pats in + List.map (fun rps -> RP_tup rps) res_pats' + | P_app (id,args) -> + (match res_pat with + | RP_app (id',residual_args) -> + if Id.compare id id' == 0 then + let res_pats' = subpats args residual_args in + List.map (fun rps -> RP_app (id,rps)) res_pats' + else [res_pat] + | RP_any -> + let res_args = subpats args (List.map (fun _ -> RP_any) args) in + (List.map (fun l -> (RP_app (id,l))) res_args) @ + (Bindings.find id ctx.constructor_to_rest) + | _ -> inconsistent () + ) + | P_list ps -> + (match ps with + | p1::ptl -> remove_clause_from_pattern ctx (P_aux (P_cons (p1,P_aux (P_list ptl,ann)),ann)) res_pat + | [] -> + match res_pat with + | RP_any -> [RP_cons (RP_any,RP_any)] + | RP_cons _ -> [res_pat] + | RP_nil -> [] + | _ -> inconsistent ()) + | P_cons (p1,p2) -> begin + let rp',rps = + match res_pat with + | RP_cons (rp1,rp2) -> [], Some [rp1;rp2] + | RP_any -> [RP_nil], Some [RP_any;RP_any] + | RP_nil -> [RP_nil], None + | _ -> inconsistent () + in + match rps with + | None -> rp' + | Some rps -> + let res_pats = subpats [p1;p2] rps in + rp' @ List.map (function [rp1;rp2] -> RP_cons (rp1,rp2) | _ -> assert false) res_pats + end + | P_record _ -> + raise (Reporting_basic.err_unreachable (fst ann) + "Record pattern not supported") + | P_vector _ + | P_vector_concat _ + | P_string_append _ -> + raise (Reporting_basic.err_unreachable (fst ann) + "Found pattern that should have been rewritten away in earlier stage") + + (*in let _ = printprefix := String.sub (!printprefix) 0 (String.length !printprefix - 2) + in let _ = print_endline (!printprefix ^ "res_pats': " ^ String.concat "; " (List.map string_of_rp rp'))*) + in rp' + +let process_pexp env = + let ctx = ctx_from_pattern_completeness_ctx env in + fun rps patexp -> + (*let _ = print_endline ("res_pats: " ^ String.concat "; " (List.map string_of_rp rps)) in + let _ = print_endline ("pat: " ^ string_of_pexp patexp) in*) + match patexp with + | Pat_aux (Pat_exp (p,_),_) -> + List.concat (List.map (remove_clause_from_pattern ctx p) rps) + | Pat_aux (Pat_when _,(l,_)) -> + raise (Reporting_basic.err_unreachable l + "Guarded pattern should have been rewritten away") + +let rewrite_case (e,ann) = + match e with + | E_case (e1,cases) -> + begin + let env = env_of_annot ann in + let rps = List.fold_left (process_pexp env) [RP_any] cases in + match rps with + | [] -> E_aux (E_case (e1,cases),ann) + | (example::_) -> + + let _ = + if !opt_coq_warn_nonexhaustive + then Reporting_basic.print_err false false + (fst ann) "Non-exhaustive matching" ("Example: " ^ string_of_rp example) in + + let l = Parse_ast.Generated Parse_ast.Unknown in + let p = P_aux (P_wild, (l, None)) in + let ann' = Some (env, typ_of_annot ann, mk_effect [BE_escape]) in + (* TODO: use an expression that specifically indicates a failed pattern match *) + let b = E_aux (E_exit (E_aux (E_lit (L_aux (L_unit, l)),(l,None))),(l,ann')) in + E_aux (E_case (e1,cases@[Pat_aux (Pat_exp (p,b),(l,None))]),ann) + end + | _ -> E_aux (e,ann) + +let rewrite = + let alg = { id_exp_alg with e_aux = rewrite_case } in + rewrite_defs_base + { rewrite_exp = (fun _ -> fold_exp alg) + ; rewrite_pat = rewrite_pat + ; rewrite_let = rewrite_let + ; rewrite_lexp = rewrite_lexp + ; rewrite_fun = rewrite_fun + ; rewrite_def = rewrite_def + ; rewrite_defs = rewrite_defs_base + } + + +end + let recheck_defs defs = fst (Type_error.check initial_env defs) let remove_mapping_valspecs (Defs defs) = @@ -3985,7 +4274,7 @@ let rewrite_defs_lem = [ ("remove_mapping_valspecs", remove_mapping_valspecs); ("pat_string_append", rewrite_defs_pat_string_append); ("mapping_builtins", rewrite_defs_mapping_patterns); - ("pat_lits", rewrite_defs_pat_lits); + ("pat_lits", rewrite_defs_pat_lits rewrite_lit_lem); ("vector_concat_assignments", rewrite_vector_concat_assignments); ("tuple_assignments", rewrite_tuple_assignments); ("simple_assignments", rewrite_simple_assignments); @@ -4017,13 +4306,52 @@ let rewrite_defs_lem = [ ("recheck_defs", recheck_defs) ] +let rewrite_defs_coq = [ + ("realise_mappings", rewrite_defs_realise_mappings); + ("remove_mapping_valspecs", remove_mapping_valspecs); + ("pat_string_append", rewrite_defs_pat_string_append); + ("mapping_builtins", rewrite_defs_mapping_patterns); + ("pat_lits", rewrite_defs_pat_lits rewrite_lit_lem); + ("vector_concat_assignments", rewrite_vector_concat_assignments); + ("tuple_assignments", rewrite_tuple_assignments); + ("simple_assignments", rewrite_simple_assignments); + ("remove_vector_concat", rewrite_defs_remove_vector_concat); + ("remove_bitvector_pats", rewrite_defs_remove_bitvector_pats); + ("remove_numeral_pats", rewrite_defs_remove_numeral_pats); + ("guarded_pats", rewrite_defs_guarded_pats); + ("bitvector_exps", rewrite_bitvector_exps); + (* ("register_ref_writes", rewrite_register_ref_writes); *) + ("nexp_ids", rewrite_defs_nexp_ids); + ("fix_val_specs", rewrite_fix_val_specs); + ("split_execute", rewrite_split_fun_constr_pats "execute"); + ("recheck_defs", recheck_defs); + ("exp_lift_assign", rewrite_defs_exp_lift_assign); + (* ("constraint", rewrite_constraint); *) + (* ("remove_assert", rewrite_defs_remove_assert); *) + ("top_sort_defs", top_sort_defs); + ("trivial_sizeof", rewrite_trivial_sizeof); + ("sizeof", rewrite_sizeof); + ("early_return", rewrite_defs_early_return); + ("make_cases_exhaustive", MakeExhaustive.rewrite); + ("fix_val_specs", rewrite_fix_val_specs); + ("recheck_defs", recheck_defs); + ("remove_blocks", rewrite_defs_remove_blocks); + ("letbind_effects", rewrite_defs_letbind_effects); + ("remove_e_assign", rewrite_defs_remove_e_assign); + ("internal_lets", rewrite_defs_internal_lets); + ("remove_superfluous_letbinds", rewrite_defs_remove_superfluous_letbinds); + ("remove_superfluous_returns", rewrite_defs_remove_superfluous_returns); + ("merge function clauses", merge_funcls); + ("recheck_defs", recheck_defs) + ] + let rewrite_defs_ocaml = [ (* ("undefined", rewrite_undefined); *) ("no_effect_check", (fun defs -> opt_no_effects := true; defs)); ("realise_mappings", rewrite_defs_realise_mappings); ("pat_string_append", rewrite_defs_pat_string_append); ("mapping_builtins", rewrite_defs_mapping_patterns); - ("pat_lits", rewrite_defs_pat_lits); + ("pat_lits", rewrite_defs_pat_lits rewrite_no_strings); ("vector_concat_assignments", rewrite_vector_concat_assignments); ("tuple_assignments", rewrite_tuple_assignments); ("simple_assignments", rewrite_simple_assignments); @@ -4045,7 +4373,7 @@ let rewrite_defs_c = [ ("realise_mappings", rewrite_defs_realise_mappings); ("pat_string_append", rewrite_defs_pat_string_append); ("mapping_builtins", rewrite_defs_mapping_patterns); - ("pat_lits", rewrite_defs_pat_lits); + ("pat_lits", rewrite_defs_pat_lits rewrite_no_strings); ("vector_concat_assignments", rewrite_vector_concat_assignments); ("tuple_assignments", rewrite_tuple_assignments); ("simple_assignments", rewrite_simple_assignments); diff --git a/src/rewrites.mli b/src/rewrites.mli index 70cb75af..7d6bc0b2 100644 --- a/src/rewrites.mli +++ b/src/rewrites.mli @@ -66,6 +66,13 @@ val rewrite_defs_interpreter : (string * (tannot defs -> tannot defs)) list (* Perform rewrites to exclude AST nodes not supported for lem out*) val rewrite_defs_lem : (string * (tannot defs -> tannot defs)) list +(* Perform rewrites to exclude AST nodes not supported for coq out*) +val rewrite_defs_coq : (string * (tannot defs -> tannot defs)) list + +(* Warn about matches where we add a default case for Coq because they're not + exhaustive *) +val opt_coq_warn_nonexhaustive : bool ref + (* Perform rewrites to exclude AST nodes not supported for C compilation *) val rewrite_defs_c : (string * (tannot defs -> tannot defs)) list diff --git a/src/sail.ml b/src/sail.ml index 944eb9ff..5b7c9dbf 100644 --- a/src/sail.ml +++ b/src/sail.ml @@ -118,6 +118,9 @@ let options = Arg.align ([ ( "-Oconstant_fold", Arg.Set Constant_fold.optimize_constant_fold, " Apply constant folding optimizations"); + ( "-static", + Arg.Set C_backend.opt_static, + " Make generated C functions static"); ( "-trace", Arg.Tuple [Arg.Set C_backend.opt_trace; Arg.Set Ocaml_backend.opt_trace_ocaml], " Instrument ouput with tracing"); @@ -145,6 +148,9 @@ let options = Arg.align ([ ( "-dcoq_undef_axioms", Arg.Set Pretty_print_coq.opt_undef_axioms, "Generate axioms for functions that are declared but not defined"); + ( "-dcoq_warn_nonex", + Arg.Set Rewrites.opt_coq_warn_nonexhaustive, + "Generate warnings for non-exhaustive pattern matches in the Coq backend"); ( "-latex_prefix", Arg.String (fun prefix -> Latex.opt_prefix_latex := prefix), " set a custom prefix for generated latex command (default sail)"); diff --git a/src/sail_lib.ml b/src/sail_lib.ml index ab621342..16b1d3cc 100644 --- a/src/sail_lib.ml +++ b/src/sail_lib.ml @@ -419,7 +419,7 @@ let get_mem_page p = try Mem.find p !mem_pages with Not_found -> - let new_page = Bytes.create page_size_bytes in + let new_page = Bytes.make page_size_bytes '\000' in mem_pages := Mem.add p new_page !mem_pages; new_page @@ -457,7 +457,7 @@ let write_ram' (data_size, addr, data) = end let write_ram (addr_size, data_size, hex_ram, addr, data) = - write_ram' (data_size, uint addr, data) + write_ram' (data_size, uint addr, data); true let wram addr byte = let bytes = Bytes.make 1 (char_of_int byte) in diff --git a/src/state.ml b/src/state.ml index c591f753..245d450c 100644 --- a/src/state.ml +++ b/src/state.ml @@ -367,15 +367,15 @@ let generate_isa_lemmas mwords (Defs defs : tannot defs) = let id' = remove_trailing_underscores id in separate_map hardline string [ "lemma liftS_read_reg_" ^ id ^ "[liftState_simp]:"; - " \"liftS (read_reg " ^ id ^ "_ref) = readS (" ^ id' ^ " \\<circ> regstate)\""; + " \"\\<lbrakk>read_reg " ^ id ^ "_ref\\<rbrakk>\\<^sub>S = readS (" ^ id' ^ " \\<circ> regstate)\""; " by (auto simp: liftState_read_reg_readS register_defs)"; ""; "lemma liftS_write_reg_" ^ id ^ "[liftState_simp]:"; - " \"liftS (write_reg " ^ id ^ "_ref v) = updateS (regstate_update (" ^ id' ^ "_update (\\<lambda>_. v)))\""; + " \"\\<lbrakk>write_reg " ^ id ^ "_ref v\\<rbrakk>\\<^sub>S = updateS (regstate_update (" ^ id' ^ "_update (\\<lambda>_. v)))\""; " by (auto simp: liftState_write_reg_updateS register_defs)" ] in - string "abbreviation \"liftS \\<equiv> liftState (get_regval, set_regval)\"" ^^ + string "abbreviation liftS (\"\\<lbrakk>_\\<rbrakk>\\<^sub>S\") where \"liftS \\<equiv> liftState (get_regval, set_regval)\"" ^^ hardline ^^ hardline ^^ register_defs ^^ hardline ^^ hardline ^^ @@ -411,7 +411,7 @@ let rec regval_convs_coq (Typ_aux (t, _) as typ) = match t with let size = string_of_nexp (nexp_simp size) in let is_inc = if is_order_inc ord then "true" else "false" in let etyp_of, of_etyp = regval_convs_coq etyp in - "(fun v => vector_of_regval " ^ etyp_of ^ " v)", + "(fun v => vector_of_regval " ^ size ^ " " ^ etyp_of ^ " v)", "(fun v => regval_of_vector " ^ of_etyp ^ " " ^ size ^ " " ^ is_inc ^ " v)" | Typ_app (id, [Typ_arg_aux (Typ_arg_typ etyp, _)]) when string_of_id id = "list" -> @@ -430,12 +430,12 @@ let rec regval_convs_coq (Typ_aux (t, _) as typ) = match t with let register_refs_coq registers = let generic_convs = separate_map hardline string [ - "Definition vector_of_regval {a} (of_regval : register_value -> option a) (rv : register_value) : option (list a) := match rv with"; - " | Regval_vector (_, _, v) => just_list (List.map of_regval v)"; + "Definition vector_of_regval {a} n (of_regval : register_value -> option a) (rv : register_value) : option (vec a n) := match rv with"; + " | Regval_vector (n', _, v) => if n =? n' then map_bind (vec_of_list n) (just_list (List.map of_regval v)) else None"; " | _ => None"; "end."; ""; - "Definition regval_of_vector {a} (regval_of : a -> register_value) (size : Z) (is_inc : bool) (xs : list a) : register_value := Regval_vector (size, is_inc, List.map regval_of xs)."; + "Definition regval_of_vector {a} (regval_of : a -> register_value) (size : Z) (is_inc : bool) (xs : vec a size) : register_value := Regval_vector (size, is_inc, List.map regval_of (list_of_vec xs))."; ""; "Definition list_of_regval {a} (of_regval : register_value -> option a) (rv : register_value) : option (list a) := match rv with"; " | Regval_list v => just_list (List.map of_regval v)"; diff --git a/src/test/lib/run_test_interp.ml b/src/test/lib/run_test_interp.ml deleted file mode 100644 index 5f2ace1b..00000000 --- a/src/test/lib/run_test_interp.ml +++ /dev/null @@ -1,51 +0,0 @@ -(**************************************************************************) -(* Sail *) -(* *) -(* Copyright (c) 2013-2017 *) -(* Kathyrn Gray *) -(* Shaked Flur *) -(* Stephen Kell *) -(* Gabriel Kerneis *) -(* Robert Norton-Wright *) -(* Christopher Pulte *) -(* Peter Sewell *) -(* *) -(* All rights reserved. *) -(* *) -(* This software was developed by the University of Cambridge Computer *) -(* Laboratory as part of the Rigorous Engineering of Mainstream Systems *) -(* (REMS) project, funded by EPSRC grant EP/K008528/1. *) -(* *) -(* Redistribution and use in source and binary forms, with or without *) -(* modification, are permitted provided that the following conditions *) -(* are met: *) -(* 1. Redistributions of source code must retain the above copyright *) -(* notice, this list of conditions and the following disclaimer. *) -(* 2. Redistributions in binary form must reproduce the above copyright *) -(* notice, this list of conditions and the following disclaimer in *) -(* the documentation and/or other materials provided with the *) -(* distribution. *) -(* *) -(* THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' *) -(* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED *) -(* TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A *) -(* PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR *) -(* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, *) -(* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT *) -(* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF *) -(* USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND *) -(* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, *) -(* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT *) -(* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF *) -(* SUCH DAMAGE. *) -(**************************************************************************) - -open Interp_interface ;; -open Interp_inter_imp ;; -open Sail_impl_base ;; - -let doit () = - let context = build_context false Test_lem_ast.defs [] [] [] [] [] [] [] None [] in - translate_address context E_little_endian "run" None (address_of_integer (Nat_big_num.of_int 0));; - -doit () ;; diff --git a/src/type_check.ml b/src/type_check.ml index 810cd265..c73c7000 100644 --- a/src/type_check.ml +++ b/src/type_check.ml @@ -127,28 +127,11 @@ let is_unknown_type = function | (Typ_aux (Typ_internal_unknown, _)) -> true | _ -> false -(* An index_sort is a more general form of range type: it can either - be IS_int, which represents every natural number, or some set of - natural numbers given by an IS_prop expression of the form - {'n. f('n) <= g('n) /\ ...} *) -type index_sort = - | IS_int - | IS_prop of kid * (nexp * nexp) list - -let string_of_index_sort = function - | IS_int -> "INT" - | IS_prop (kid, constraints) -> - "{" ^ string_of_kid kid ^ " | " - ^ string_of_list " & " (fun (x, y) -> string_of_nexp x ^ " <= " ^ string_of_nexp y) constraints - ^ "}" - let is_atom (Typ_aux (typ_aux, _)) = match typ_aux with | Typ_app (f, [_]) when string_of_id f = "atom" -> true | _ -> false - - let rec strip_id = function | Id_aux (Id x, _) -> Id_aux (Id x, Parse_ast.Unknown) | Id_aux (DeIid x, _) -> Id_aux (DeIid x, Parse_ast.Unknown) @@ -367,6 +350,7 @@ module Env : sig val is_mapping : id -> t -> bool val add_record : id -> typquant -> (typ * id) list -> t -> t val is_record : id -> t -> bool + val get_record : id -> t -> typquant * (typ * id) list val get_accessor_fn : id -> id -> t -> typquant * typ val get_accessor : id -> id -> t -> typquant * typ * typ * effect val add_local : id -> mut * typ -> t -> t @@ -929,6 +913,8 @@ end = struct let is_record id env = Bindings.mem id env.records + let get_record id env = Bindings.find id env.records + let add_record id typq fields env = if bound_typ_id env id then typ_error (id_loc id) ("Cannot create record " ^ string_of_id id ^ ", type name is already bound") @@ -3311,6 +3297,12 @@ and instantiation_of (E_aux (exp_aux, (l, _)) as exp) = | E_app (f, xs) -> snd (infer_funapp' l (Env.no_casts env) f (Env.get_val_spec f env) (List.map strip_exp xs) (Some (typ_of exp))) | _ -> invalid_arg ("instantiation_of expected application, got " ^ string_of_exp exp) +and instantiation_of_without_type (E_aux (exp_aux, (l, _)) as exp) = + let env = env_of exp in + match exp_aux with + | E_app (f, xs) -> snd (infer_funapp' l (Env.no_casts env) f (Env.get_val_spec f env) (List.map strip_exp xs) None) + | _ -> invalid_arg ("instantiation_of expected application, got " ^ string_of_exp exp) + and infer_funapp' l env f (typq, f_typ) xs ret_ctx_typ = let annot_exp exp typ eff = E_aux (exp, (l, Some (env, typ, eff))) in let switch_annot env typ = function diff --git a/src/type_check.mli b/src/type_check.mli index 24443f2e..286c2be9 100644 --- a/src/type_check.mli +++ b/src/type_check.mli @@ -133,6 +133,9 @@ module Env : sig val is_record : id -> t -> bool + (** Returns record quantifiers and fields *) + val get_record : id -> t -> typquant * (typ * id) list + (** Return type is: quantifier, argument type, return type, effect *) val get_accessor : id -> id -> t -> typquant * typ * typ * effect @@ -191,6 +194,7 @@ module Env : sig val empty : t + val pattern_completeness_ctx : t -> Pattern_completeness.ctx end (** Push all the type variables and constraints from a typquant into @@ -349,6 +353,10 @@ val alpha_equivalent : Env.t -> typ -> typ -> bool (** Throws Invalid_argument if the argument is not a E_app expression *) val instantiation_of : tannot exp -> uvar KBindings.t +(** Doesn't use the type of the expression when calculating instantiations. + May fail if the arguments aren't sufficient to calculate all unifiers. *) +val instantiation_of_without_type : tannot exp -> uvar KBindings.t + (* Type variable instantiations that inference will extract from constraints *) val instantiate_simple_equations : quant_item list -> uvar KBindings.t diff --git a/src/value.ml b/src/value.ml index 41b52720..dccb216e 100644 --- a/src/value.ml +++ b/src/value.ml @@ -406,8 +406,8 @@ let value_read_ram = function let value_write_ram = function | [v1; v2; v3; v4; v5] -> - Sail_lib.write_ram (coerce_int v1, coerce_int v2, coerce_bv v3, coerce_bv v4, coerce_bv v5); - V_unit + let b = Sail_lib.write_ram (coerce_int v1, coerce_int v2, coerce_bv v3, coerce_bv v4, coerce_bv v5) in + V_bool(b) | _ -> failwith "value write_ram" let value_load_raw = function @@ -561,6 +561,7 @@ let primops = ("write_ram", value_write_ram); ("trace_memory_read", fun _ -> V_unit); ("trace_memory_write", fun _ -> V_unit); + ("get_time_ns", fun _ -> V_int (Sail_lib.get_time_ns())); ("load_raw", value_load_raw); ("to_real", value_to_real); ("eq_real", value_eq_real); diff --git a/src/value2.lem b/src/value2.lem index 33416503..e8a8262a 100644 --- a/src/value2.lem +++ b/src/value2.lem @@ -1,4 +1,4 @@ -(**************************************************************************) +(*========================================================================*) (* Sail *) (* *) (* Copyright (c) 2013-2017 *) @@ -46,7 +46,7 @@ (* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT *) (* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF *) (* SUCH DAMAGE. *) -(**************************************************************************) +(*========================================================================*) open import Pervasives open import Assert_extra @@ -70,17 +70,13 @@ type vl = | V_record of list (string * vl) | V_null (* Used for unitialized values and null pointers in C compilation *) -let primops extern args = - match (extern, args) with - | ("and_bool", [V_bool b1; V_bool b2]) -> V_bool (b1 && b2) - | ("and_bool", [V_nondet; V_bool false]) -> V_bool false - | ("and_bool", [V_bool false; V_nondet]) -> V_bool false - | ("and_bool", _) -> V_nondet - | ("or_bool", [V_bool b1; V_bool b2]) -> V_bool (b1 || b2) - | ("or_bool", [V_nondet; V_bool true]) -> V_bool true - | ("or_bool", [V_bool true; V_nondet]) -> V_bool true - | ("or_bool", _) -> V_nondet +let value_int_op_int op = function + | [V_int v1; V_int v2] -> V_int (op v1 v2) + | _ -> V_null +end - | _ -> failwith ("Unimplemented primitive operation " ^ extern) - end +let value_bool_op_int op = function + | [V_int v1; V_int v2] -> V_bool (op v1 v2) + | _ -> V_null +end |
