summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/c_backend.ml185
-rw-r--r--src/constant_fold.ml101
-rw-r--r--src/gen_lib/sail2_operators_bitlists.lem68
-rw-r--r--src/gen_lib/sail2_operators_mwords.lem74
-rw-r--r--src/gen_lib/sail2_prompt.lem32
-rw-r--r--src/gen_lib/sail2_prompt_monad.lem1
-rw-r--r--src/gen_lib/sail2_state.lem32
-rw-r--r--src/gen_lib/sail2_state_monad.lem18
-rw-r--r--src/gen_lib/sail2_string.lem2
-rw-r--r--src/gen_lib/sail2_values.lem4
-rw-r--r--src/interpreter.ml2
-rw-r--r--src/isail.ml25
-rw-r--r--src/lem_interp/sail2_instr_kinds.lem18
-rw-r--r--src/ocaml_backend.ml15
-rw-r--r--src/pretty_print_coq.ml361
-rw-r--r--src/pretty_print_lem.ml16
-rw-r--r--src/process_file.ml4
-rw-r--r--src/rewrites.ml352
-rw-r--r--src/rewrites.mli7
-rw-r--r--src/sail.ml6
-rw-r--r--src/sail_lib.ml4
-rw-r--r--src/state.ml14
-rw-r--r--src/test/lib/run_test_interp.ml51
-rw-r--r--src/type_check.ml26
-rw-r--r--src/type_check.mli8
-rw-r--r--src/value.ml5
-rw-r--r--src/value2.lem24
27 files changed, 1036 insertions, 419 deletions
diff --git a/src/c_backend.ml b/src/c_backend.ml
index 85c2eeb3..433e3d85 100644
--- a/src/c_backend.ml
+++ b/src/c_backend.ml
@@ -63,6 +63,7 @@ module Big_int = Nat_big_num
let c_verbosity = ref 0
let opt_ddump_flow_graphs = ref false
let opt_trace = ref false
+let opt_static = ref false
(* Optimization flags *)
let optimize_primops = ref false
@@ -222,7 +223,7 @@ let rec is_stack_ctyp ctyp = match ctyp with
let is_stack_typ ctx typ = is_stack_ctyp (ctyp_of_typ ctx typ)
let ctor_bindings = List.fold_left (fun map (id, ctyp) -> Bindings.add id ctyp map) Bindings.empty
-
+
(**************************************************************************)
(* 3. Optimization of primitives and literals *)
(**************************************************************************)
@@ -778,7 +779,7 @@ let rec compile_aval ctx = function
in
[idecl vector_ctyp gs;
iextern (CL_id (gs, vector_ctyp)) (mk_id "internal_vector_init") [(F_lit (V_int (Big_int.of_int len)), CT_int64)]]
- @ List.concat (List.mapi aval_set avals),
+ @ List.concat (List.mapi aval_set (if direction then List.rev avals else avals)),
(F_id gs, vector_ctyp),
[iclear vector_ctyp gs]
@@ -927,7 +928,10 @@ let label str =
let pointer_assign ctyp1 ctyp2 =
match ctyp1 with
| CT_ref ctyp1 when ctyp_equal ctyp1 ctyp2 -> true
- | CT_ref ctyp1 -> c_error "Incompatible type in pointer assignment"
+ | CT_ref ctyp1 ->
+ c_error (Printf.sprintf "Incompatible type in pointer assignment between %s and %s"
+ (string_of_ctyp ctyp1)
+ (string_of_ctyp ctyp2))
| _ -> false
let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) =
@@ -1050,11 +1054,16 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) =
(* FIXME: AE_record_update could be AV_record_update - would reduce some copying. *)
| AE_record_update (aval, fields, typ) ->
let ctyp = ctyp_of_typ ctx typ in
+ let ctors = match ctyp with
+ | CT_struct (_, ctors) -> List.fold_left (fun m (k, v) -> Bindings.add k v m) Bindings.empty ctors
+ | _ -> c_error "Cannot perform record update for non-record type"
+ in
let gs = gensym () in
let compile_fields (id, aval) =
let field_setup, cval, field_cleanup = compile_aval ctx aval in
field_setup
- @ [icopy (CL_field (gs, string_of_id id, cval_ctyp cval)) cval]
+ @ [icomment (string_of_ctyp ctyp)]
+ @ [icopy (CL_field (gs, string_of_id id, Bindings.find id ctors)) cval]
@ field_cleanup
in
let setup, cval, cleanup = compile_aval ctx aval in
@@ -1227,6 +1236,7 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) =
in
let loop_start_label = label "for_start_" in
+ let loop_end_label = label "for_end_" in
let body_setup, body_call, body_cleanup = compile_aexp ctx body in
let body_gs = gensym () in
@@ -1235,17 +1245,16 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) =
@ variable_init step_gs step_setup step_call step_cleanup
@ [iblock ([idecl CT_int64 loop_var;
icopy (CL_id (loop_var, CT_int64)) (F_id from_gs, CT_int64);
- ilabel loop_start_label;
idecl CT_unit body_gs;
- iblock (body_setup
+ iblock ([ilabel loop_start_label]
+ @ [ijump (F_op (F_id loop_var, (if is_inc then ">" else "<"), F_id to_gs), CT_bool) loop_end_label]
+ @ body_setup
@ [body_call (CL_id (body_gs, CT_unit))]
@ body_cleanup
- @ if is_inc then
- [icopy (CL_id (loop_var, CT_int64)) (F_op (F_id loop_var, "+", F_id step_gs), CT_int64);
- ijump (F_op (F_id loop_var, "<=", F_id to_gs), CT_bool) loop_start_label]
- else
- [icopy (CL_id (loop_var, CT_int64)) (F_op (F_id loop_var, "-", F_id step_gs), CT_int64);
- ijump (F_op (F_id loop_var, ">=", F_id to_gs), CT_bool) loop_start_label])])],
+ @ [icopy (CL_id (loop_var, CT_int64))
+ (F_op (F_id loop_var, (if is_inc then "+" else "-"), F_id step_gs), CT_int64)]
+ @ [igoto loop_start_label]);
+ ilabel loop_end_label])],
(fun clexp -> icopy clexp unit_fragment),
[]
@@ -1905,7 +1914,7 @@ let rec sgen_ctyp_name = function
| CT_string -> "sail_string"
| CT_real -> "real"
| CT_ref ctyp -> "ref_" ^ sgen_ctyp_name ctyp
-
+
let sgen_cval_param (frag, ctyp) =
match ctyp with
| CT_bits direction ->
@@ -1949,10 +1958,17 @@ let rec codegen_instr fid ctx (I_aux (instr, (_, l))) =
string (Printf.sprintf " %s = %s;" (sgen_clexp_pure clexp) (sgen_cval cval))
else
string (Printf.sprintf " COPY(%s)(%s, %s);" (sgen_ctyp_name lctyp) (sgen_clexp clexp) (sgen_cval cval))
+ else if pointer_assign lctyp rctyp then
+ let lctyp = match lctyp with
+ | CT_ref lctyp -> lctyp
+ | _ -> assert false
+ in
+ if is_stack_ctyp lctyp then
+ string (Printf.sprintf " *(%s) = %s;" (sgen_clexp_pure clexp) (sgen_cval cval))
+ else
+ string (Printf.sprintf " COPY(%s)(*(%s), %s);" (sgen_ctyp_name lctyp) (sgen_clexp clexp) (sgen_cval cval))
else
- if pointer_assign lctyp rctyp then
- string (Printf.sprintf " %s = &%s;" (sgen_clexp_pure clexp) (sgen_cval cval))
- else if is_stack_ctyp lctyp then
+ if is_stack_ctyp lctyp then
string (Printf.sprintf " %s = CONVERT_OF(%s, %s)(%s);"
(sgen_clexp_pure clexp) (sgen_ctyp_name lctyp) (sgen_ctyp_name rctyp) (sgen_cval cval))
else
@@ -2125,7 +2141,7 @@ let codegen_type_def ctx = function
| CTD_enum (id, ((first_id :: _) as ids)) ->
let codegen_eq =
let name = sgen_id id in
- string (Printf.sprintf "bool eq_%s(enum %s op1, enum %s op2) { return op1 == op2; }" name name name)
+ string (Printf.sprintf "static bool eq_%s(enum %s op1, enum %s op2) { return op1 == op2; }" name name name)
in
let codegen_undefined =
let name = sgen_id id in
@@ -2150,7 +2166,7 @@ let codegen_type_def ctx = function
string (Printf.sprintf "COPY(%s)(&rop->%s, op.%s);" (sgen_ctyp_name ctyp) (sgen_id id) (sgen_id id))
in
let codegen_setter id ctors =
- string (let n = sgen_id id in Printf.sprintf "void COPY(%s)(struct %s *rop, const struct %s op)" n n n) ^^ space
+ string (let n = sgen_id id in Printf.sprintf "static void COPY(%s)(struct %s *rop, const struct %s op)" n n n) ^^ space
^^ surround 2 0 lbrace
(separate_map hardline codegen_set (Bindings.bindings ctors))
rbrace
@@ -2162,16 +2178,23 @@ let codegen_type_def ctx = function
else []
in
let codegen_init f id ctors =
- string (let n = sgen_id id in Printf.sprintf "void %s(%s)(struct %s *op)" f n n) ^^ space
+ string (let n = sgen_id id in Printf.sprintf "static void %s(%s)(struct %s *op)" f n n) ^^ space
^^ surround 2 0 lbrace
(separate hardline (Bindings.bindings ctors |> List.map (codegen_field_init f) |> List.concat))
rbrace
in
- (*
let codegen_eq =
- string (Printf.sprintf "bool eq_%s(struct %s op1, struct %s op2) { return true; }" (sgen_id id) (sgen_id id) (sgen_id id))
+ let codegen_eq_test (id, ctyp) =
+ string (Printf.sprintf "EQUAL(%s)(op1.%s, op2.%s)" (sgen_ctyp_name ctyp) (sgen_id id) (sgen_id id))
+ in
+ string (Printf.sprintf "static bool EQUAL(%s)(struct %s op1, struct %s op2)" (sgen_id id) (sgen_id id) (sgen_id id))
+ ^^ space
+ ^^ surround 2 0 lbrace
+ (string "return" ^^ space
+ ^^ separate_map (string " && ") codegen_eq_test ctors
+ ^^ string ";")
+ rbrace
in
- *)
(* Generate the struct and add the generated functions *)
let codegen_ctor (id, ctyp) =
string (sgen_ctyp ctyp) ^^ space ^^ codegen_id id
@@ -2183,18 +2206,16 @@ let codegen_type_def ctx = function
rbrace
^^ semi ^^ twice hardline
^^ codegen_setter id (ctor_bindings ctors)
- ^^ if not (is_stack_ctyp struct_ctyp) then
- twice hardline
- ^^ codegen_init "CREATE" id (ctor_bindings ctors)
- ^^ twice hardline
- ^^ codegen_init "RECREATE" id (ctor_bindings ctors)
- ^^ twice hardline
- ^^ codegen_init "KILL" id (ctor_bindings ctors)
- else empty
- (*
+ ^^ (if not (is_stack_ctyp struct_ctyp) then
+ twice hardline
+ ^^ codegen_init "CREATE" id (ctor_bindings ctors)
+ ^^ twice hardline
+ ^^ codegen_init "RECREATE" id (ctor_bindings ctors)
+ ^^ twice hardline
+ ^^ codegen_init "KILL" id (ctor_bindings ctors)
+ else empty)
^^ twice hardline
^^ codegen_eq
- *)
| CTD_variant (id, tus) ->
let codegen_tu (ctor_id, ctyp) =
@@ -2215,7 +2236,7 @@ let codegen_type_def ctx = function
let codegen_init =
let n = sgen_id id in
let ctor_id, ctyp = List.hd tus in
- string (Printf.sprintf "void CREATE(%s)(struct %s *op)" n n)
+ string (Printf.sprintf "static void CREATE(%s)(struct %s *op)" n n)
^^ hardline
^^ surround 2 0 lbrace
(string (Printf.sprintf "op->kind = Kind_%s;" (sgen_id ctor_id)) ^^ hardline
@@ -2226,7 +2247,7 @@ let codegen_type_def ctx = function
in
let codegen_reinit =
let n = sgen_id id in
- string (Printf.sprintf "void RECREATE(%s)(struct %s *op) {}" n n)
+ string (Printf.sprintf "static void RECREATE(%s)(struct %s *op) {}" n n)
in
let clear_field v ctor_id ctyp =
if is_stack_ctyp ctyp then
@@ -2236,7 +2257,7 @@ let codegen_type_def ctx = function
in
let codegen_clear =
let n = sgen_id id in
- string (Printf.sprintf "void KILL(%s)(struct %s *op)" n n) ^^ hardline
+ string (Printf.sprintf "static void KILL(%s)(struct %s *op)" n n) ^^ hardline
^^ surround 2 0 lbrace
(each_ctor "op->" (clear_field "op") tus ^^ semi)
rbrace
@@ -2260,7 +2281,7 @@ let codegen_type_def ctx = function
^^ separate hardline (List.mapi tuple_set ctyps) ^^ hardline
| ctyp -> Printf.sprintf "%s op" (sgen_ctyp ctyp), empty
in
- string (Printf.sprintf "void %s(struct %s *rop, %s)" (sgen_id ctor_id) (sgen_id id) ctor_args) ^^ hardline
+ string (Printf.sprintf "static void %s(struct %s *rop, %s)" (sgen_id ctor_id) (sgen_id id) ctor_args) ^^ hardline
^^ surround 2 0 lbrace
(tuple
^^ each_ctor "rop->" (clear_field "rop") tus ^^ hardline
@@ -2281,7 +2302,7 @@ let codegen_type_def ctx = function
string (Printf.sprintf "CREATE(%s)(&rop->%s);" (sgen_ctyp_name ctyp) (sgen_id ctor_id))
^^ string (Printf.sprintf " COPY(%s)(&rop->%s, op.%s);" (sgen_ctyp_name ctyp) (sgen_id ctor_id) (sgen_id ctor_id))
in
- string (Printf.sprintf "void COPY(%s)(struct %s *rop, struct %s op)" n n n) ^^ hardline
+ string (Printf.sprintf "static void COPY(%s)(struct %s *rop, struct %s op)" n n n) ^^ hardline
^^ surround 2 0 lbrace
(each_ctor "rop->" (clear_field "rop") tus
^^ semi ^^ hardline
@@ -2290,6 +2311,21 @@ let codegen_type_def ctx = function
^^ each_ctor "op." set_field tus)
rbrace
in
+ let codegen_eq =
+ let codegen_eq_test ctor_id ctyp =
+ string (Printf.sprintf "return EQUAL(%s)(op1.%s, op2.%s);" (sgen_ctyp_name ctyp) (sgen_id ctor_id) (sgen_id ctor_id))
+ in
+ let rec codegen_eq_tests = function
+ | [] -> string "return false;"
+ | (ctor_id, ctyp) :: ctors ->
+ string (Printf.sprintf "if (op1.kind == Kind_%s && op2.kind == Kind_%s) " (sgen_id ctor_id) (sgen_id ctor_id)) ^^ lbrace ^^ hardline
+ ^^ jump 0 2 (codegen_eq_test ctor_id ctyp)
+ ^^ hardline ^^ rbrace ^^ string " else " ^^ codegen_eq_tests ctors
+ in
+ let n = sgen_id id in
+ string (Printf.sprintf "static bool EQUAL(%s)(struct %s op1, struct %s op2) " n n n)
+ ^^ surround 2 0 lbrace (codegen_eq_tests tus) rbrace
+ in
string (Printf.sprintf "// union %s" (string_of_id id)) ^^ hardline
^^ string "enum" ^^ space
^^ string ("kind_" ^ sgen_id id) ^^ space
@@ -2317,6 +2353,8 @@ let codegen_type_def ctx = function
^^ twice hardline
^^ codegen_setter
^^ twice hardline
+ ^^ codegen_eq
+ ^^ twice hardline
^^ separate_map (twice hardline) codegen_ctor tus
(* If this is the exception type, then we setup up some global variables to deal with exceptions. *)
^^ if string_of_id id = "exception" then
@@ -2363,10 +2401,10 @@ let codegen_node id ctyp =
^^ string (Printf.sprintf "typedef struct node_%s *%s;" (sgen_id id) (sgen_id id))
let codegen_list_init id =
- string (Printf.sprintf "void CREATE(%s)(%s *rop) { *rop = NULL; }" (sgen_id id) (sgen_id id))
+ string (Printf.sprintf "static void CREATE(%s)(%s *rop) { *rop = NULL; }" (sgen_id id) (sgen_id id))
let codegen_list_clear id ctyp =
- string (Printf.sprintf "void KILL(%s)(%s *rop) {\n" (sgen_id id) (sgen_id id))
+ string (Printf.sprintf "static void KILL(%s)(%s *rop) {\n" (sgen_id id) (sgen_id id))
^^ string (Printf.sprintf " if (*rop == NULL) return;")
^^ (if is_stack_ctyp ctyp then empty
else string (Printf.sprintf " KILL(%s)(&(*rop)->hd);\n" (sgen_ctyp_name ctyp)))
@@ -2375,7 +2413,7 @@ let codegen_list_clear id ctyp =
^^ string "}"
let codegen_list_set id ctyp =
- string (Printf.sprintf "void internal_set_%s(%s *rop, const %s op) {\n" (sgen_id id) (sgen_id id) (sgen_id id))
+ string (Printf.sprintf "static void internal_set_%s(%s *rop, const %s op) {\n" (sgen_id id) (sgen_id id) (sgen_id id))
^^ string " if (op == NULL) { *rop = NULL; return; };\n"
^^ string (Printf.sprintf " *rop = malloc(sizeof(struct node_%s));\n" (sgen_id id))
^^ (if is_stack_ctyp ctyp then
@@ -2386,14 +2424,14 @@ let codegen_list_set id ctyp =
^^ string (Printf.sprintf " internal_set_%s(&(*rop)->tl, op->tl);\n" (sgen_id id))
^^ string "}"
^^ twice hardline
- ^^ string (Printf.sprintf "void COPY(%s)(%s *rop, const %s op) {\n" (sgen_id id) (sgen_id id) (sgen_id id))
+ ^^ string (Printf.sprintf "static void COPY(%s)(%s *rop, const %s op) {\n" (sgen_id id) (sgen_id id) (sgen_id id))
^^ string (Printf.sprintf " KILL(%s)(rop);\n" (sgen_id id))
^^ string (Printf.sprintf " internal_set_%s(rop, op);\n" (sgen_id id))
^^ string "}"
let codegen_cons id ctyp =
let cons_id = mk_id ("cons#" ^ string_of_ctyp ctyp) in
- string (Printf.sprintf "void %s(%s *rop, const %s x, const %s xs) {\n" (sgen_id cons_id) (sgen_id id) (sgen_ctyp ctyp) (sgen_id id))
+ string (Printf.sprintf "static void %s(%s *rop, const %s x, const %s xs) {\n" (sgen_id cons_id) (sgen_id id) (sgen_ctyp ctyp) (sgen_id id))
^^ string (Printf.sprintf " *rop = malloc(sizeof(struct node_%s));\n" (sgen_id id))
^^ (if is_stack_ctyp ctyp then
string " (*rop)->hd = x;\n"
@@ -2405,9 +2443,9 @@ let codegen_cons id ctyp =
let codegen_pick id ctyp =
if is_stack_ctyp ctyp then
- string (Printf.sprintf "%s pick_%s(const %s xs) { return xs->hd; }" (sgen_ctyp ctyp) (sgen_ctyp_name ctyp) (sgen_id id))
+ string (Printf.sprintf "static %s pick_%s(const %s xs) { return xs->hd; }" (sgen_ctyp ctyp) (sgen_ctyp_name ctyp) (sgen_id id))
else
- string (Printf.sprintf "void pick_%s(%s *x, const %s xs) { COPY(%s)(x, xs->hd); }" (sgen_ctyp_name ctyp) (sgen_ctyp ctyp) (sgen_id id) (sgen_ctyp_name ctyp))
+ string (Printf.sprintf "static void pick_%s(%s *x, const %s xs) { COPY(%s)(x, xs->hd); }" (sgen_ctyp_name ctyp) (sgen_ctyp ctyp) (sgen_id id) (sgen_ctyp_name ctyp))
let codegen_list ctx ctyp =
let id = mk_id (string_of_ctyp (CT_list ctyp)) in
@@ -2435,10 +2473,10 @@ let codegen_vector ctx (direction, ctyp) =
^^ string (Printf.sprintf "typedef struct %s %s;" (sgen_id id) (sgen_id id))
in
let vector_init =
- string (Printf.sprintf "void CREATE(%s)(%s *rop) {\n rop->len = 0;\n rop->data = NULL;\n}" (sgen_id id) (sgen_id id))
+ string (Printf.sprintf "static void CREATE(%s)(%s *rop) {\n rop->len = 0;\n rop->data = NULL;\n}" (sgen_id id) (sgen_id id))
in
let vector_set =
- string (Printf.sprintf "void COPY(%s)(%s *rop, %s op) {\n" (sgen_id id) (sgen_id id) (sgen_id id))
+ string (Printf.sprintf "static void COPY(%s)(%s *rop, %s op) {\n" (sgen_id id) (sgen_id id) (sgen_id id))
^^ string (Printf.sprintf " KILL(%s)(rop);\n" (sgen_id id))
^^ string " rop->len = op.len;\n"
^^ string (Printf.sprintf " rop->data = malloc((rop->len) * sizeof(%s));\n" (sgen_ctyp ctyp))
@@ -2451,7 +2489,7 @@ let codegen_vector ctx (direction, ctyp) =
^^ string "}"
in
let vector_clear =
- string (Printf.sprintf "void KILL(%s)(%s *rop) {\n" (sgen_id id) (sgen_id id))
+ string (Printf.sprintf "static void KILL(%s)(%s *rop) {\n" (sgen_id id) (sgen_id id))
^^ (if is_stack_ctyp ctyp then empty
else
string " for (int i = 0; i < (rop->len); i++) {\n"
@@ -2461,7 +2499,7 @@ let codegen_vector ctx (direction, ctyp) =
^^ string "}"
in
let vector_update =
- string (Printf.sprintf "void vector_update_%s(%s *rop, %s op, mpz_t n, %s elem) {\n" (sgen_id id) (sgen_id id) (sgen_id id) (sgen_ctyp ctyp))
+ string (Printf.sprintf "static void vector_update_%s(%s *rop, %s op, mpz_t n, %s elem) {\n" (sgen_id id) (sgen_id id) (sgen_id id) (sgen_ctyp ctyp))
^^ string " int m = mpz_get_ui(n);\n"
^^ string " if (rop->data == op.data) {\n"
^^ string (if is_stack_ctyp ctyp then
@@ -2478,7 +2516,7 @@ let codegen_vector ctx (direction, ctyp) =
^^ string "}"
in
let internal_vector_update =
- string (Printf.sprintf "void internal_vector_update_%s(%s *rop, %s op, const int64_t n, %s elem) {\n" (sgen_id id) (sgen_id id) (sgen_id id) (sgen_ctyp ctyp))
+ string (Printf.sprintf "static void internal_vector_update_%s(%s *rop, %s op, const int64_t n, %s elem) {\n" (sgen_id id) (sgen_id id) (sgen_id id) (sgen_ctyp ctyp))
^^ string (if is_stack_ctyp ctyp then
" rop->data[n] = elem;\n"
else
@@ -2487,24 +2525,24 @@ let codegen_vector ctx (direction, ctyp) =
in
let vector_access =
if is_stack_ctyp ctyp then
- string (Printf.sprintf "%s vector_access_%s(%s op, mpz_t n) {\n" (sgen_ctyp ctyp) (sgen_id id) (sgen_id id))
+ string (Printf.sprintf "static %s vector_access_%s(%s op, mpz_t n) {\n" (sgen_ctyp ctyp) (sgen_id id) (sgen_id id))
^^ string " int m = mpz_get_ui(n);\n"
^^ string " return op.data[m];\n"
^^ string "}"
else
- string (Printf.sprintf "void vector_access_%s(%s *rop, %s op, mpz_t n) {\n" (sgen_id id) (sgen_ctyp ctyp) (sgen_id id))
+ string (Printf.sprintf "static void vector_access_%s(%s *rop, %s op, mpz_t n) {\n" (sgen_id id) (sgen_ctyp ctyp) (sgen_id id))
^^ string " int m = mpz_get_ui(n);\n"
^^ string (Printf.sprintf " COPY(%s)(rop, op.data[m]);\n" (sgen_ctyp_name ctyp))
^^ string "}"
in
let internal_vector_init =
- string (Printf.sprintf "void internal_vector_init_%s(%s *rop, const int64_t len) {\n" (sgen_id id) (sgen_id id))
+ string (Printf.sprintf "static void internal_vector_init_%s(%s *rop, const int64_t len) {\n" (sgen_id id) (sgen_id id))
^^ string " rop->len = len;\n"
^^ string (Printf.sprintf " rop->data = malloc(len * sizeof(%s));\n" (sgen_ctyp ctyp))
^^ string "}"
in
let vector_undefined =
- string (Printf.sprintf "void undefined_vector_%s(%s *rop, mpz_t len, %s elem) {\n" (sgen_id id) (sgen_id id) (sgen_ctyp ctyp))
+ string (Printf.sprintf "static void undefined_vector_%s(%s *rop, mpz_t len, %s elem) {\n" (sgen_id id) (sgen_id id) (sgen_ctyp ctyp))
^^ string (Printf.sprintf " rop->len = mpz_get_ui(len);\n")
^^ string (Printf.sprintf " rop->data = malloc((rop->len) * sizeof(%s));\n" (sgen_ctyp ctyp))
^^ string " for (int i = 0; i < (rop->len); i++) {\n"
@@ -2549,12 +2587,13 @@ let codegen_def' ctx = function
^^ string (Printf.sprintf "%s %s;" (sgen_ctyp ctyp) (sgen_id id))
| CDEF_spec (id, arg_ctyps, ret_ctyp) ->
+ let static = if !opt_static then "static " else "" in
if Env.is_extern id ctx.tc_env "c" then
empty
else if is_stack_ctyp ret_ctyp then
- string (Printf.sprintf "%s %s(%s);" (sgen_ctyp ret_ctyp) (sgen_id id) (Util.string_of_list ", " sgen_ctyp arg_ctyps))
+ string (Printf.sprintf "%s%s %s(%s);" static (sgen_ctyp ret_ctyp) (sgen_id id) (Util.string_of_list ", " sgen_ctyp arg_ctyps))
else
- string (Printf.sprintf "void %s(%s *rop, %s);" (sgen_id id) (sgen_ctyp ret_ctyp) (Util.string_of_list ", " sgen_ctyp arg_ctyps))
+ string (Printf.sprintf "%svoid %s(%s *rop, %s);" static (sgen_id id) (sgen_ctyp ret_ctyp) (Util.string_of_list ", " sgen_ctyp arg_ctyps))
| CDEF_fundef (id, ret_arg, args, instrs) as def ->
if !opt_ddump_flow_graphs then make_dot id (instrs_graph instrs) else ();
@@ -2578,10 +2617,12 @@ let codegen_def' ctx = function
match ret_arg with
| None ->
assert (is_stack_ctyp ret_ctyp);
- string (sgen_ctyp ret_ctyp) ^^ space ^^ codegen_id id ^^ parens (string args) ^^ hardline
+ (if !opt_static then string "static " else empty)
+ ^^ string (sgen_ctyp ret_ctyp) ^^ space ^^ codegen_id id ^^ parens (string args) ^^ hardline
| Some gs ->
assert (not (is_stack_ctyp ret_ctyp));
- string "void" ^^ space ^^ codegen_id id
+ (if !opt_static then string "static " else empty)
+ ^^ string "void" ^^ space ^^ codegen_id id
^^ parens (string (sgen_ctyp ret_ctyp ^ " *" ^ sgen_id gs ^ ", ") ^^ string args)
^^ hardline
in
@@ -2594,7 +2635,8 @@ let codegen_def' ctx = function
codegen_type_def ctx ctype_def
| CDEF_startup (id, instrs) ->
- let startup_header = string (Printf.sprintf "void startup_%s(void)" (sgen_id id)) in
+ let static = if !opt_static then "static " else "" in
+ let startup_header = string (Printf.sprintf "%svoid startup_%s(void)" static (sgen_id id)) in
separate_map hardline codegen_decl instrs
^^ twice hardline
^^ startup_header ^^ hardline
@@ -2603,7 +2645,8 @@ let codegen_def' ctx = function
^^ string "}"
| CDEF_finish (id, instrs) ->
- let finish_header = string (Printf.sprintf "void finish_%s(void)" (sgen_id id)) in
+ let static = if !opt_static then "static " else "" in
+ let finish_header = string (Printf.sprintf "%svoid finish_%s(void)" static (sgen_id id)) in
separate_map hardline codegen_decl (List.filter is_decl instrs)
^^ twice hardline
^^ finish_header ^^ hardline
@@ -2620,12 +2663,12 @@ let codegen_def' ctx = function
List.concat (List.map (fun (id, ctyp) -> [iclear ctyp id]) bindings)
in
separate_map hardline (fun (id, ctyp) -> string (Printf.sprintf "%s %s;" (sgen_ctyp ctyp) (sgen_id id))) bindings
- ^^ hardline ^^ string (Printf.sprintf "void create_letbind_%d(void) " number)
+ ^^ hardline ^^ string (Printf.sprintf "static void create_letbind_%d(void) " number)
^^ string "{"
^^ jump 0 2 (separate_map hardline codegen_alloc setup) ^^ hardline
^^ jump 0 2 (separate_map hardline (codegen_instr (mk_id "let") { ctx with no_raw = true }) instrs) ^^ hardline
^^ string "}"
- ^^ hardline ^^ string (Printf.sprintf "void kill_letbind_%d(void) " number)
+ ^^ hardline ^^ string (Printf.sprintf "static void kill_letbind_%d(void) " number)
^^ string "{"
^^ jump 0 2 (separate_map hardline (codegen_instr (mk_id "let") ctx) cleanup) ^^ hardline
^^ string "}"
@@ -2742,6 +2785,26 @@ let rec get_recursive_functions (Defs defs) =
match defs with
| DEF_internal_mutrec fundefs :: defs ->
IdSet.union (List.map id_of_fundef fundefs |> IdSet.of_list) (get_recursive_functions (Defs defs))
+
+ | (DEF_fundef fdef as def) :: defs ->
+ let open Rewriter in
+ let ids = ref IdSet.empty in
+ let collect_funcalls e_aux annot =
+ match e_aux with
+ | E_app (id, args) -> (ids := IdSet.add id !ids; E_aux (e_aux, annot))
+ | _ -> E_aux (e_aux, annot)
+ in
+ let map_exp = {
+ id_exp_alg with
+ e_aux = (fun (e_aux, annot) -> collect_funcalls e_aux annot)
+ } in
+ let map_defs = { rewriters_base with rewrite_exp = (fun _ -> fold_exp map_exp) } in
+ let _ = rewrite_def map_defs def in
+ if IdSet.mem (id_of_fundef fdef) !ids then
+ IdSet.add (id_of_fundef fdef) (get_recursive_functions (Defs defs))
+ else
+ get_recursive_functions (Defs defs)
+
| _ :: defs -> get_recursive_functions (Defs defs)
| [] -> IdSet.empty
diff --git a/src/constant_fold.ml b/src/constant_fold.ml
index 7a35226e..45d3efe0 100644
--- a/src/constant_fold.ml
+++ b/src/constant_fold.ml
@@ -59,10 +59,25 @@ module StringMap = Map.Make(String);;
false = no folding, true = perform constant folding. *)
let optimize_constant_fold = ref false
-let exp_of_value =
+let rec fexp_of_ctor (field, value) =
+ FE_aux (FE_Fexp (mk_id field, exp_of_value value), no_annot)
+
+and exp_of_value =
let open Value in
function
| V_int n -> mk_lit_exp (L_num n)
+ | V_bit Sail_lib.B0 -> mk_lit_exp L_zero
+ | V_bit Sail_lib.B1 -> mk_lit_exp L_one
+ | V_bool true -> mk_lit_exp L_true
+ | V_bool false -> mk_lit_exp L_false
+ | V_string str -> mk_lit_exp (L_string str)
+ | V_record ctors ->
+ mk_exp (E_record (FES_aux (FES_Fexps (List.map fexp_of_ctor (StringMap.bindings ctors), false), no_annot)))
+ | V_vector vs ->
+ mk_exp (E_vector (List.map exp_of_value vs))
+ | V_tuple vs ->
+ mk_exp (E_tuple (List.map exp_of_value vs))
+ | V_unit -> mk_lit_exp L_unit
| _ -> failwith "No expression for value"
(* We want to avoid evaluating things like print statements at compile
@@ -85,15 +100,23 @@ let safe_primops =
"Elf_loader.elf_tohost"
]
-let is_literal = function
- | E_aux (E_lit _, _) -> true
+let rec is_constant (E_aux (e_aux, _)) =
+ match e_aux with
+ | E_lit _ -> true
+ | E_vector exps -> List.for_all is_constant exps
+ | E_record (FES_aux (FES_Fexps (fexps, _), _)) -> List.for_all is_constant_fexp fexps
+ | E_cast (_, exp) -> is_constant exp
+ | E_tuple exps -> List.for_all is_constant exps
| _ -> false
+and is_constant_fexp (FE_aux (FE_Fexp (_, exp), _)) = is_constant exp
+
(* Wrapper around interpreter that repeatedly steps until done. *)
let rec run ast frame =
match frame with
| Interpreter.Done (state, v) -> v
- | Interpreter.Step _ ->
+ | Interpreter.Step (lazy_str, _, _, _) ->
+ prerr_endline (Lazy.force lazy_str);
run ast (Interpreter.eval_frame ast frame)
| Interpreter.Break frame ->
run ast (Interpreter.eval_frame ast frame)
@@ -115,35 +138,57 @@ let rec run ast frame =
- Throws an exception that isn't caught.
*)
-let rewrite_constant_function_calls' ast =
+let rec rewrite_constant_function_calls' ast =
+ let rewrite_count = ref 0 in
+ let ok () = incr rewrite_count in
+ let not_ok () = decr rewrite_count in
+
let lstate, gstate =
Interpreter.initial_state ast safe_primops
in
let gstate = { gstate with Interpreter.allow_registers = false } in
+ let evaluate e_aux annot =
+ let initial_monad = Interpreter.return (E_aux (e_aux, annot)) in
+ try
+ begin
+ let v = run ast (Interpreter.Step (lazy "", (lstate, gstate), initial_monad, [])) in
+ let exp = exp_of_value v in
+ try (ok (); Type_check.check_exp (env_of_annot annot) exp (typ_of_annot annot)) with
+ | Type_error (l, err) ->
+ (* A type error here would be unexpected, so don't ignore it! *)
+ Util.warn ("Type error when folding constants in "
+ ^ string_of_exp (E_aux (e_aux, annot))
+ ^ "\n" ^ Type_error.string_of_type_error err);
+ not_ok ();
+ E_aux (e_aux, annot)
+ end
+ with
+ (* Otherwise if anything goes wrong when trying to constant
+ fold, just continue without optimising. *)
+ | _ -> E_aux (e_aux, annot)
+ in
+
let rw_funcall e_aux annot =
match e_aux with
- | E_app (id, args) when List.for_all is_literal args ->
- begin
- let initial_monad = Interpreter.return (E_aux (e_aux, annot)) in
- try
- begin
- let v = run ast (Interpreter.Step (lazy "", (lstate, gstate), initial_monad, [])) in
- let exp = exp_of_value v in
- try Type_check.check_exp (env_of_annot annot) exp (typ_of_annot annot) with
- | Type_error (l, err) ->
- (* A type error here would be unexpected, so don't ignore it! *)
- Util.warn ("Type error when folding constants in "
- ^ string_of_exp (E_aux (e_aux, annot))
- ^ "\n" ^ Type_error.string_of_type_error err);
- E_aux (e_aux, annot)
- end
- with
- (* Otherwise if anything goes wrong when trying to constant
- fold, just continue without optimising. *)
- | _ -> E_aux (e_aux, annot)
- end
+ | E_app (id, args) when List.for_all is_constant args ->
+ evaluate e_aux annot
+
+ | E_field (exp, id) when is_constant exp ->
+ evaluate e_aux annot
+
+ | E_if (E_aux (E_lit (L_aux (L_true, _)), _), then_exp, _) -> ok (); then_exp
+ | E_if (E_aux (E_lit (L_aux (L_false, _)), _), _, else_exp) -> ok (); else_exp
+
+ | E_let (LB_aux (LB_val (P_aux (P_id id, _), bind), _), exp) when is_constant bind ->
+ ok ();
+ subst id bind exp
+ | E_let (LB_aux (LB_val (P_aux (P_typ (typ, P_aux (P_id id, _)), annot), bind), _), exp)
+ when is_constant bind ->
+ ok ();
+ subst id (E_aux (E_cast (typ, bind), annot)) exp
+
| _ -> E_aux (e_aux, annot)
in
let rw_exp = {
@@ -151,7 +196,11 @@ let rewrite_constant_function_calls' ast =
e_aux = (fun (e_aux, annot) -> rw_funcall e_aux annot)
} in
let rw_defs = { rewriters_base with rewrite_exp = (fun _ -> fold_exp rw_exp) } in
- rewrite_defs_base rw_defs ast
+ let ast = rewrite_defs_base rw_defs ast in
+ (* We keep iterating until we have no more re-writes to do *)
+ if !rewrite_count > 0
+ then rewrite_constant_function_calls' ast
+ else ast
let rewrite_constant_function_calls ast =
if !optimize_constant_fold then
diff --git a/src/gen_lib/sail2_operators_bitlists.lem b/src/gen_lib/sail2_operators_bitlists.lem
index 3f8b7510..186e0a09 100644
--- a/src/gen_lib/sail2_operators_bitlists.lem
+++ b/src/gen_lib/sail2_operators_bitlists.lem
@@ -10,16 +10,16 @@ open import Sail2_prompt
val uint_maybe : list bitU -> maybe integer
let uint_maybe v = unsigned v
let uint_fail v = maybe_fail "uint" (unsigned v)
-let uint_oracle v =
- bools_of_bits_oracle v >>= (fun bs ->
+let uint_nondet v =
+ bools_of_bits_nondet v >>= (fun bs ->
return (int_of_bools false bs))
let uint v = maybe_failwith (uint_maybe v)
val sint_maybe : list bitU -> maybe integer
let sint_maybe v = signed v
let sint_fail v = maybe_fail "sint" (signed v)
-let sint_oracle v =
- bools_of_bits_oracle v >>= (fun bs ->
+let sint_nondet v =
+ bools_of_bits_nondet v >>= (fun bs ->
return (int_of_bools true bs))
let sint v = maybe_failwith (sint_maybe v)
@@ -43,13 +43,13 @@ let vector_truncate bs len = extz_bv len bs
val vec_of_bits_maybe : list bitU -> maybe (list bitU)
val vec_of_bits_fail : forall 'rv 'e. list bitU -> monad 'rv (list bitU) 'e
-val vec_of_bits_oracle : forall 'rv 'e. list bitU -> monad 'rv (list bitU) 'e
+val vec_of_bits_nondet : forall 'rv 'e. list bitU -> monad 'rv (list bitU) 'e
val vec_of_bits_failwith : list bitU -> list bitU
val vec_of_bits : list bitU -> list bitU
let inline vec_of_bits bits = bits
let inline vec_of_bits_maybe bits = Just bits
let inline vec_of_bits_fail bits = return bits
-let inline vec_of_bits_oracle bits = return bits
+let inline vec_of_bits_nondet bits = return bits
let inline vec_of_bits_failwith bits = bits
val access_vec_inc : list bitU -> integer -> bitU
@@ -62,13 +62,13 @@ val update_vec_inc : list bitU -> integer -> bitU -> list bitU
let update_vec_inc = update_bv_inc
let update_vec_inc_maybe v i b = Just (update_vec_inc v i b)
let update_vec_inc_fail v i b = return (update_vec_inc v i b)
-let update_vec_inc_oracle v i b = return (update_vec_inc v i b)
+let update_vec_inc_nondet v i b = return (update_vec_inc v i b)
val update_vec_dec : list bitU -> integer -> bitU -> list bitU
let update_vec_dec = update_bv_dec
let update_vec_dec_maybe v i b = Just (update_vec_dec v i b)
let update_vec_dec_fail v i b = return (update_vec_dec v i b)
-let update_vec_dec_oracle v i b = return (update_vec_dec v i b)
+let update_vec_dec_nondet v i b = return (update_vec_dec v i b)
val subrange_vec_inc : list bitU -> integer -> integer -> list bitU
let subrange_vec_inc = subrange_bv_inc
@@ -89,19 +89,19 @@ val cons_vec : bitU -> list bitU -> list bitU
let cons_vec = cons_bv
let cons_vec_maybe b v = Just (cons_vec b v)
let cons_vec_fail b v = return (cons_vec b v)
-let cons_vec_oracle b v = return (cons_vec b v)
+let cons_vec_nondet b v = return (cons_vec b v)
val cast_unit_vec : bitU -> list bitU
let cast_unit_vec = cast_unit_bv
let cast_unit_vec_maybe b = Just (cast_unit_vec b)
let cast_unit_vec_fail b = return (cast_unit_vec b)
-let cast_unit_vec_oracle b = return (cast_unit_vec b)
+let cast_unit_vec_nondet b = return (cast_unit_vec b)
val vec_of_bit : integer -> bitU -> list bitU
let vec_of_bit = bv_of_bit
let vec_of_bit_maybe len b = Just (vec_of_bit len b)
let vec_of_bit_fail len b = return (vec_of_bit len b)
-let vec_of_bit_oracle len b = return (vec_of_bit len b)
+let vec_of_bit_nondet len b = return (vec_of_bit len b)
val msb : list bitU -> bitU
let msb = most_significant
@@ -109,7 +109,7 @@ let msb = most_significant
val int_of_vec_maybe : bool -> list bitU -> maybe integer
let int_of_vec_maybe = int_of_bv
let int_of_vec_fail sign v = maybe_fail "int_of_vec" (int_of_vec_maybe sign v)
-let int_of_vec_oracle sign v = bools_of_bits_oracle v >>= (fun v -> return (int_of_bools sign v))
+let int_of_vec_nondet sign v = bools_of_bits_nondet v >>= (fun v -> return (int_of_bools sign v))
let int_of_vec sign v = maybe_failwith (int_of_vec_maybe sign v)
val string_of_bits : list bitU -> string
@@ -146,30 +146,18 @@ let mult_vec = arith_op_double_bl integerMult false
let mults_vec = arith_op_double_bl integerMult true
val add_vec_int : list bitU -> integer -> list bitU
-val adds_vec_int : list bitU -> integer -> list bitU
val sub_vec_int : list bitU -> integer -> list bitU
-val subs_vec_int : list bitU -> integer -> list bitU
val mult_vec_int : list bitU -> integer -> list bitU
-val mults_vec_int : list bitU -> integer -> list bitU
let add_vec_int l r = arith_op_bv_int integerAdd false l r
-let adds_vec_int l r = arith_op_bv_int integerAdd true l r
let sub_vec_int l r = arith_op_bv_int integerMinus false l r
-let subs_vec_int l r = arith_op_bv_int integerMinus true l r
let mult_vec_int l r = arith_op_double_bl integerMult false l (of_int (length l) r)
-let mults_vec_int l r = arith_op_double_bl integerMult true l (of_int (length l) r)
val add_int_vec : integer -> list bitU -> list bitU
-val adds_int_vec : integer -> list bitU -> list bitU
val sub_int_vec : integer -> list bitU -> list bitU
-val subs_int_vec : integer -> list bitU -> list bitU
val mult_int_vec : integer -> list bitU -> list bitU
-val mults_int_vec : integer -> list bitU -> list bitU
let add_int_vec l r = arith_op_int_bv integerAdd false l r
-let adds_int_vec l r = arith_op_int_bv integerAdd true l r
let sub_int_vec l r = arith_op_int_bv integerMinus false l r
-let subs_int_vec l r = arith_op_int_bv integerMinus true l r
let mult_int_vec l r = arith_op_double_bl integerMult false (of_int (length r) l) r
-let mults_int_vec l r = arith_op_double_bl integerMult true (of_int (length r) l) r
val add_vec_bit : list bitU -> bitU -> list bitU
val adds_vec_bit : list bitU -> bitU -> list bitU
@@ -179,25 +167,25 @@ val subs_vec_bit : list bitU -> bitU -> list bitU
let add_vec_bool l r = arith_op_bv_bool integerAdd false l r
let add_vec_bit_maybe l r = arith_op_bv_bit integerAdd false l r
let add_vec_bit_fail l r = maybe_fail "add_vec_bit" (add_vec_bit_maybe l r)
-let add_vec_bit_oracle l r = bool_of_bitU_oracle r >>= (fun r -> return (add_vec_bool l r))
+let add_vec_bit_nondet l r = bool_of_bitU_nondet r >>= (fun r -> return (add_vec_bool l r))
let add_vec_bit l r = fromMaybe (repeat [BU] (length l)) (add_vec_bit_maybe l r)
let adds_vec_bool l r = arith_op_bv_bool integerAdd true l r
let adds_vec_bit_maybe l r = arith_op_bv_bit integerAdd true l r
let adds_vec_bit_fail l r = maybe_fail "adds_vec_bit" (adds_vec_bit_maybe l r)
-let adds_vec_bit_oracle l r = bool_of_bitU_oracle r >>= (fun r -> return (adds_vec_bool l r))
+let adds_vec_bit_nondet l r = bool_of_bitU_nondet r >>= (fun r -> return (adds_vec_bool l r))
let adds_vec_bit l r = fromMaybe (repeat [BU] (length l)) (adds_vec_bit_maybe l r)
let sub_vec_bool l r = arith_op_bv_bool integerMinus false l r
let sub_vec_bit_maybe l r = arith_op_bv_bit integerMinus false l r
let sub_vec_bit_fail l r = maybe_fail "sub_vec_bit" (sub_vec_bit_maybe l r)
-let sub_vec_bit_oracle l r = bool_of_bitU_oracle r >>= (fun r -> return (sub_vec_bool l r))
+let sub_vec_bit_nondet l r = bool_of_bitU_nondet r >>= (fun r -> return (sub_vec_bool l r))
let sub_vec_bit l r = fromMaybe (repeat [BU] (length l)) (sub_vec_bit_maybe l r)
let subs_vec_bool l r = arith_op_bv_bool integerMinus true l r
let subs_vec_bit_maybe l r = arith_op_bv_bit integerMinus true l r
let subs_vec_bit_fail l r = maybe_fail "sub_vec_bit" (subs_vec_bit_maybe l r)
-let subs_vec_bit_oracle l r = bool_of_bitU_oracle r >>= (fun r -> return (subs_vec_bool l r))
+let subs_vec_bit_nondet l r = bool_of_bitU_nondet r >>= (fun r -> return (subs_vec_bool l r))
let subs_vec_bit l r = fromMaybe (repeat [BU] (length l)) (subs_vec_bit_maybe l r)
(*val add_overflow_vec : list bitU -> list bitU -> (list bitU * bitU * bitU)
@@ -236,47 +224,47 @@ let rotr = rotr_bv
val mod_vec : list bitU -> list bitU -> list bitU
val mod_vec_maybe : list bitU -> list bitU -> maybe (list bitU)
val mod_vec_fail : forall 'rv 'e. list bitU -> list bitU -> monad 'rv (list bitU) 'e
-val mod_vec_oracle : forall 'rv 'e. list bitU -> list bitU -> monad 'rv (list bitU) 'e
+val mod_vec_nondet : forall 'rv 'e. list bitU -> list bitU -> monad 'rv (list bitU) 'e
let mod_vec l r = fromMaybe (repeat [BU] (length l)) (mod_bv l r)
let mod_vec_maybe l r = mod_bv l r
let mod_vec_fail l r = maybe_fail "mod_vec" (mod_bv l r)
-let mod_vec_oracle l r = of_bits_oracle (mod_vec l r)
+let mod_vec_nondet l r = of_bits_nondet (mod_vec l r)
val quot_vec : list bitU -> list bitU -> list bitU
val quot_vec_maybe : list bitU -> list bitU -> maybe (list bitU)
val quot_vec_fail : forall 'rv 'e. list bitU -> list bitU -> monad 'rv (list bitU) 'e
-val quot_vec_oracle : forall 'rv 'e. list bitU -> list bitU -> monad 'rv (list bitU) 'e
+val quot_vec_nondet : forall 'rv 'e. list bitU -> list bitU -> monad 'rv (list bitU) 'e
let quot_vec l r = fromMaybe (repeat [BU] (length l)) (quot_bv l r)
let quot_vec_maybe l r = quot_bv l r
let quot_vec_fail l r = maybe_fail "quot_vec" (quot_bv l r)
-let quot_vec_oracle l r = of_bits_oracle (quot_vec l r)
+let quot_vec_nondet l r = of_bits_nondet (quot_vec l r)
val quots_vec : list bitU -> list bitU -> list bitU
val quots_vec_maybe : list bitU -> list bitU -> maybe (list bitU)
val quots_vec_fail : forall 'rv 'e. list bitU -> list bitU -> monad 'rv (list bitU) 'e
-val quots_vec_oracle : forall 'rv 'e. list bitU -> list bitU -> monad 'rv (list bitU) 'e
+val quots_vec_nondet : forall 'rv 'e. list bitU -> list bitU -> monad 'rv (list bitU) 'e
let quots_vec l r = fromMaybe (repeat [BU] (length l)) (quots_bv l r)
let quots_vec_maybe l r = quots_bv l r
let quots_vec_fail l r = maybe_fail "quots_vec" (quots_bv l r)
-let quots_vec_oracle l r = of_bits_oracle (quots_vec l r)
+let quots_vec_nondet l r = of_bits_nondet (quots_vec l r)
val mod_vec_int : list bitU -> integer -> list bitU
val mod_vec_int_maybe : list bitU -> integer -> maybe (list bitU)
val mod_vec_int_fail : forall 'rv 'e. list bitU -> integer -> monad 'rv (list bitU) 'e
-val mod_vec_int_oracle : forall 'rv 'e. list bitU -> integer -> monad 'rv (list bitU) 'e
+val mod_vec_int_nondet : forall 'rv 'e. list bitU -> integer -> monad 'rv (list bitU) 'e
let mod_vec_int l r = fromMaybe (repeat [BU] (length l)) (mod_bv_int l r)
let mod_vec_int_maybe l r = mod_bv_int l r
let mod_vec_int_fail l r = maybe_fail "mod_vec_int" (mod_bv_int l r)
-let mod_vec_int_oracle l r = of_bits_oracle (mod_vec_int l r)
+let mod_vec_int_nondet l r = of_bits_nondet (mod_vec_int l r)
val quot_vec_int : list bitU -> integer -> list bitU
val quot_vec_int_maybe : list bitU -> integer -> maybe (list bitU)
val quot_vec_int_fail : forall 'rv 'e. list bitU -> integer -> monad 'rv (list bitU) 'e
-val quot_vec_int_oracle : forall 'rv 'e. list bitU -> integer -> monad 'rv (list bitU) 'e
+val quot_vec_int_nondet : forall 'rv 'e. list bitU -> integer -> monad 'rv (list bitU) 'e
let quot_vec_int l r = fromMaybe (repeat [BU] (length l)) (quot_bv_int l r)
let quot_vec_int_maybe l r = quot_bv_int l r
let quot_vec_int_fail l r = maybe_fail "quot_vec_int" (quot_bv_int l r)
-let quot_vec_int_oracle l r = of_bits_oracle (quot_vec_int l r)
+let quot_vec_int_nondet l r = of_bits_nondet (quot_vec_int l r)
val replicate_bits : list bitU -> integer -> list bitU
let replicate_bits = replicate_bits_bv
@@ -285,8 +273,8 @@ val duplicate : bitU -> integer -> list bitU
let duplicate = duplicate_bit_bv
let duplicate_maybe b n = Just (duplicate b n)
let duplicate_fail b n = return (duplicate b n)
-let duplicate_oracle b n =
- bool_of_bitU_oracle b >>= (fun b ->
+let duplicate_nondet b n =
+ bool_of_bitU_nondet b >>= (fun b ->
return (duplicate (bitU_of_bool b) n))
val reverse_endianness : list bitU -> list bitU
diff --git a/src/gen_lib/sail2_operators_mwords.lem b/src/gen_lib/sail2_operators_mwords.lem
index 1e4d63ba..a7fb7c50 100644
--- a/src/gen_lib/sail2_operators_mwords.lem
+++ b/src/gen_lib/sail2_operators_mwords.lem
@@ -10,21 +10,21 @@ open import Sail2_prompt
let inline uint v = unsignedIntegerFromWord v
let uint_maybe v = Just (uint v)
let uint_fail v = return (uint v)
-let uint_oracle v = return (uint v)
+let uint_nondet v = return (uint v)
let inline sint v = signedIntegerFromWord v
let sint_maybe v = Just (sint v)
let sint_fail v = return (sint v)
-let sint_oracle v = return (sint v)
+let sint_nondet v = return (sint v)
val vec_of_bits_maybe : forall 'a. Size 'a => list bitU -> maybe (mword 'a)
val vec_of_bits_fail : forall 'rv 'a 'e. Size 'a => list bitU -> monad 'rv (mword 'a) 'e
-val vec_of_bits_oracle : forall 'rv 'a 'e. Size 'a => list bitU -> monad 'rv (mword 'a) 'e
+val vec_of_bits_nondet : forall 'rv 'a 'e. Size 'a => list bitU -> monad 'rv (mword 'a) 'e
val vec_of_bits_failwith : forall 'a. Size 'a => list bitU -> mword 'a
val vec_of_bits : forall 'a. Size 'a => list bitU -> mword 'a
let vec_of_bits_maybe bits = of_bits bits
let vec_of_bits_fail bits = of_bits_fail bits
-let vec_of_bits_oracle bits = of_bits_oracle bits
+let vec_of_bits_nondet bits = of_bits_nondet bits
let vec_of_bits_failwith bits = of_bits_failwith bits
let vec_of_bits bits = of_bits_failwith bits
@@ -38,8 +38,8 @@ let update_vec_dec_maybe w i b = update_mword_dec w i b
let update_vec_dec_fail w i b =
bool_of_bitU_fail b >>= (fun b ->
return (update_mword_bool_dec w i b))
-let update_vec_dec_oracle w i b =
- bool_of_bitU_oracle b >>= (fun b ->
+let update_vec_dec_nondet w i b =
+ bool_of_bitU_nondet b >>= (fun b ->
return (update_mword_bool_dec w i b))
let update_vec_dec w i b = maybe_failwith (update_vec_dec_maybe w i b)
@@ -47,8 +47,8 @@ let update_vec_inc_maybe w i b = update_mword_inc w i b
let update_vec_inc_fail w i b =
bool_of_bitU_fail b >>= (fun b ->
return (update_mword_bool_inc w i b))
-let update_vec_inc_oracle w i b =
- bool_of_bitU_oracle b >>= (fun b ->
+let update_vec_inc_nondet w i b =
+ bool_of_bitU_nondet b >>= (fun b ->
return (update_mword_bool_inc w i b))
let update_vec_inc w i b = maybe_failwith (update_vec_inc_maybe w i b)
@@ -89,21 +89,21 @@ val cons_vec_bool : forall 'a 'b 'c. Size 'a, Size 'b => bool -> mword 'a -> mwo
let cons_vec_bool b w = wordFromBitlist (b :: bitlistFromWord w)
let cons_vec_maybe b w = Maybe.map (fun b -> cons_vec_bool b w) (bool_of_bitU b)
let cons_vec_fail b w = bool_of_bitU_fail b >>= (fun b -> return (cons_vec_bool b w))
-let cons_vec_oracle b w = bool_of_bitU_oracle b >>= (fun b -> return (cons_vec_bool b w))
+let cons_vec_nondet b w = bool_of_bitU_nondet b >>= (fun b -> return (cons_vec_bool b w))
let cons_vec b w = maybe_failwith (cons_vec_maybe b w)
val vec_of_bool : forall 'a. Size 'a => integer -> bool -> mword 'a
let vec_of_bool _ b = wordFromBitlist [b]
let vec_of_bit_maybe len b = Maybe.map (vec_of_bool len) (bool_of_bitU b)
let vec_of_bit_fail len b = bool_of_bitU_fail b >>= (fun b -> return (vec_of_bool len b))
-let vec_of_bit_oracle len b = bool_of_bitU_oracle b >>= (fun b -> return (vec_of_bool len b))
+let vec_of_bit_nondet len b = bool_of_bitU_nondet b >>= (fun b -> return (vec_of_bool len b))
let vec_of_bit len b = maybe_failwith (vec_of_bit_maybe len b)
val cast_bool_vec : bool -> mword ty1
let cast_bool_vec b = vec_of_bool 1 b
let cast_unit_vec_maybe b = vec_of_bit_maybe 1 b
let cast_unit_vec_fail b = bool_of_bitU_fail b >>= (fun b -> return (cast_bool_vec b))
-let cast_unit_vec_oracle b = bool_of_bitU_oracle b >>= (fun b -> return (cast_bool_vec b))
+let cast_unit_vec_nondet b = bool_of_bitU_nondet b >>= (fun b -> return (cast_bool_vec b))
let cast_unit_vec b = maybe_failwith (cast_unit_vec_maybe b)
val msb : forall 'a. Size 'a => mword 'a -> bitU
@@ -143,30 +143,18 @@ let mult_vec l r = arith_op_bv integerMult false (zeroExtend l : mword 'b) (ze
let mults_vec l r = arith_op_bv integerMult true (signExtend l : mword 'b) (signExtend r : mword 'b)
val add_vec_int : forall 'a. Size 'a => mword 'a -> integer -> mword 'a
-val adds_vec_int : forall 'a. Size 'a => mword 'a -> integer -> mword 'a
val sub_vec_int : forall 'a. Size 'a => mword 'a -> integer -> mword 'a
-val subs_vec_int : forall 'a. Size 'a => mword 'a -> integer -> mword 'a
val mult_vec_int : forall 'a 'b. Size 'a, Size 'b => mword 'a -> integer -> mword 'b
-val mults_vec_int : forall 'a 'b. Size 'a, Size 'b => mword 'a -> integer -> mword 'b
let add_vec_int l r = arith_op_bv_int integerAdd false l r
-let adds_vec_int l r = arith_op_bv_int integerAdd true l r
let sub_vec_int l r = arith_op_bv_int integerMinus false l r
-let subs_vec_int l r = arith_op_bv_int integerMinus true l r
let mult_vec_int l r = arith_op_bv_int integerMult false (zeroExtend l : mword 'b) r
-let mults_vec_int l r = arith_op_bv_int integerMult true (signExtend l : mword 'b) r
val add_int_vec : forall 'a. Size 'a => integer -> mword 'a -> mword 'a
-val adds_int_vec : forall 'a. Size 'a => integer -> mword 'a -> mword 'a
val sub_int_vec : forall 'a. Size 'a => integer -> mword 'a -> mword 'a
-val subs_int_vec : forall 'a. Size 'a => integer -> mword 'a -> mword 'a
val mult_int_vec : forall 'a 'b. Size 'a, Size 'b => integer -> mword 'a -> mword 'b
-val mults_int_vec : forall 'a 'b. Size 'a, Size 'b => integer -> mword 'a -> mword 'b
let add_int_vec l r = arith_op_int_bv integerAdd false l r
-let adds_int_vec l r = arith_op_int_bv integerAdd true l r
let sub_int_vec l r = arith_op_int_bv integerMinus false l r
-let subs_int_vec l r = arith_op_int_bv integerMinus true l r
let mult_int_vec l r = arith_op_int_bv integerMult false l (zeroExtend r : mword 'b)
-let mults_int_vec l r = arith_op_int_bv integerMult true l (signExtend r : mword 'b)
val add_vec_bool : forall 'a. Size 'a => mword 'a -> bool -> mword 'a
val adds_vec_bool : forall 'a. Size 'a => mword 'a -> bool -> mword 'a
@@ -176,25 +164,25 @@ val subs_vec_bool : forall 'a. Size 'a => mword 'a -> bool -> mword 'a
let add_vec_bool l r = arith_op_bv_bool integerAdd false l r
let add_vec_bit_maybe l r = Maybe.map (add_vec_bool l) (bool_of_bitU r)
let add_vec_bit_fail l r = bool_of_bitU_fail r >>= (fun r -> return (add_vec_bool l r))
-let add_vec_bit_oracle l r = bool_of_bitU_oracle r >>= (fun r -> return (add_vec_bool l r))
+let add_vec_bit_nondet l r = bool_of_bitU_nondet r >>= (fun r -> return (add_vec_bool l r))
let add_vec_bit l r = maybe_failwith (add_vec_bit_maybe l r)
let adds_vec_bool l r = arith_op_bv_bool integerAdd true l r
let adds_vec_bit_maybe l r = Maybe.map (adds_vec_bool l) (bool_of_bitU r)
let adds_vec_bit_fail l r = bool_of_bitU_fail r >>= (fun r -> return (adds_vec_bool l r))
-let adds_vec_bit_oracle l r = bool_of_bitU_oracle r >>= (fun r -> return (adds_vec_bool l r))
+let adds_vec_bit_nondet l r = bool_of_bitU_nondet r >>= (fun r -> return (adds_vec_bool l r))
let adds_vec_bit l r = maybe_failwith (adds_vec_bit_maybe l r)
let sub_vec_bool l r = arith_op_bv_bool integerMinus false l r
let sub_vec_bit_maybe l r = Maybe.map (sub_vec_bool l) (bool_of_bitU r)
let sub_vec_bit_fail l r = bool_of_bitU_fail r >>= (fun r -> return (sub_vec_bool l r))
-let sub_vec_bit_oracle l r = bool_of_bitU_oracle r >>= (fun r -> return (sub_vec_bool l r))
+let sub_vec_bit_nondet l r = bool_of_bitU_nondet r >>= (fun r -> return (sub_vec_bool l r))
let sub_vec_bit l r = maybe_failwith (sub_vec_bit_maybe l r)
let subs_vec_bool l r = arith_op_bv_bool integerMinus true l r
let subs_vec_bit_maybe l r = Maybe.map (subs_vec_bool l) (bool_of_bitU r)
let subs_vec_bit_fail l r = bool_of_bitU_fail r >>= (fun r -> return (subs_vec_bool l r))
-let subs_vec_bit_oracle l r = bool_of_bitU_oracle r >>= (fun r -> return (subs_vec_bool l r))
+let subs_vec_bit_nondet l r = bool_of_bitU_nondet r >>= (fun r -> return (subs_vec_bool l r))
let subs_vec_bit l r = maybe_failwith (subs_vec_bit_maybe l r)
(* TODO
@@ -238,66 +226,66 @@ let rotr = rotr_mword
val mod_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> mword 'a
val mod_vec_maybe : forall 'a. Size 'a => mword 'a -> mword 'a -> maybe (mword 'a)
val mod_vec_fail : forall 'rv 'a 'e. Size 'a => mword 'a -> mword 'a -> monad 'rv (mword 'a) 'e
-val mod_vec_oracle : forall 'rv 'a 'e. Size 'a => mword 'a -> mword 'a -> monad 'rv (mword 'a) 'e
+val mod_vec_nondet : forall 'rv 'a 'e. Size 'a => mword 'a -> mword 'a -> monad 'rv (mword 'a) 'e
let mod_vec l r = mod_mword l r
let mod_vec_maybe l r = mod_bv l r
let mod_vec_fail l r = maybe_fail "mod_vec" (mod_bv l r)
-let mod_vec_oracle l r =
+let mod_vec_nondet l r =
match (mod_bv l r) with
| Just w -> return w
- | Nothing -> mword_oracle ()
+ | Nothing -> mword_nondet ()
end
val quot_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> mword 'a
val quot_vec_maybe : forall 'a. Size 'a => mword 'a -> mword 'a -> maybe (mword 'a)
val quot_vec_fail : forall 'rv 'a 'e. Size 'a => mword 'a -> mword 'a -> monad 'rv (mword 'a) 'e
-val quot_vec_oracle : forall 'rv 'a 'e. Size 'a => mword 'a -> mword 'a -> monad 'rv (mword 'a) 'e
+val quot_vec_nondet : forall 'rv 'a 'e. Size 'a => mword 'a -> mword 'a -> monad 'rv (mword 'a) 'e
let quot_vec l r = quot_mword l r
let quot_vec_maybe l r = quot_bv l r
let quot_vec_fail l r = maybe_fail "quot_vec" (quot_bv l r)
-let quot_vec_oracle l r =
+let quot_vec_nondet l r =
match (quot_bv l r) with
| Just w -> return w
- | Nothing -> mword_oracle ()
+ | Nothing -> mword_nondet ()
end
val quots_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> mword 'a
val quots_vec_maybe : forall 'a. Size 'a => mword 'a -> mword 'a -> maybe (mword 'a)
val quots_vec_fail : forall 'rv 'a 'e. Size 'a => mword 'a -> mword 'a -> monad 'rv (mword 'a) 'e
-val quots_vec_oracle : forall 'rv 'a 'e. Size 'a => mword 'a -> mword 'a -> monad 'rv (mword 'a) 'e
+val quots_vec_nondet : forall 'rv 'a 'e. Size 'a => mword 'a -> mword 'a -> monad 'rv (mword 'a) 'e
let quots_vec l r = quots_mword l r
let quots_vec_maybe l r = quots_bv l r
let quots_vec_fail l r = maybe_fail "quots_vec" (quots_bv l r)
-let quots_vec_oracle l r =
+let quots_vec_nondet l r =
match (quots_bv l r) with
| Just w -> return w
- | Nothing -> mword_oracle ()
+ | Nothing -> mword_nondet ()
end
val mod_vec_int : forall 'a. Size 'a => mword 'a -> integer -> mword 'a
val mod_vec_int_maybe : forall 'a. Size 'a => mword 'a -> integer -> maybe (mword 'a)
val mod_vec_int_fail : forall 'rv 'a 'e. Size 'a => mword 'a -> integer -> monad 'rv (mword 'a) 'e
-val mod_vec_int_oracle : forall 'rv 'a 'e. Size 'a => mword 'a -> integer -> monad 'rv (mword 'a) 'e
+val mod_vec_int_nondet : forall 'rv 'a 'e. Size 'a => mword 'a -> integer -> monad 'rv (mword 'a) 'e
let mod_vec_int l r = mod_mword_int l r
let mod_vec_int_maybe l r = mod_bv_int l r
let mod_vec_int_fail l r = maybe_fail "mod_vec_int" (mod_bv_int l r)
-let mod_vec_int_oracle l r =
+let mod_vec_int_nondet l r =
match (mod_bv_int l r) with
| Just w -> return w
- | Nothing -> mword_oracle ()
+ | Nothing -> mword_nondet ()
end
val quot_vec_int : forall 'a. Size 'a => mword 'a -> integer -> mword 'a
val quot_vec_int_maybe : forall 'a. Size 'a => mword 'a -> integer -> maybe (mword 'a)
val quot_vec_int_fail : forall 'rv 'a 'e. Size 'a => mword 'a -> integer -> monad 'rv (mword 'a) 'e
-val quot_vec_int_oracle : forall 'rv 'a 'e. Size 'a => mword 'a -> integer -> monad 'rv (mword 'a) 'e
+val quot_vec_int_nondet : forall 'rv 'a 'e. Size 'a => mword 'a -> integer -> monad 'rv (mword 'a) 'e
let quot_vec_int l r = quot_mword_int l r
let quot_vec_int_maybe l r = quot_bv_int l r
let quot_vec_int_fail l r = maybe_fail "quot_vec_int" (quot_bv_int l r)
-let quot_vec_int_oracle l r =
+let quot_vec_int_nondet l r =
match (quot_bv_int l r) with
| Just w -> return w
- | Nothing -> mword_oracle ()
+ | Nothing -> mword_nondet ()
end
val replicate_bits : forall 'a 'b. Size 'a, Size 'b => mword 'a -> integer -> mword 'b
@@ -307,7 +295,7 @@ val duplicate_bool : forall 'a. Size 'a => bool -> integer -> mword 'a
let duplicate_bool b n = wordFromBitlist (repeat [b] n)
let duplicate_maybe b n = Maybe.map (fun b -> duplicate_bool b n) (bool_of_bitU b)
let duplicate_fail b n = bool_of_bitU_fail b >>= (fun b -> return (duplicate_bool b n))
-let duplicate_oracle b n = bool_of_bitU_oracle b >>= (fun b -> return (duplicate_bool b n))
+let duplicate_nondet b n = bool_of_bitU_nondet b >>= (fun b -> return (duplicate_bool b n))
let duplicate b n = maybe_failwith (duplicate_maybe b n)
val reverse_endianness : forall 'a. Size 'a => mword 'a -> mword 'a
diff --git a/src/gen_lib/sail2_prompt.lem b/src/gen_lib/sail2_prompt.lem
index 08a47052..e01cc051 100644
--- a/src/gen_lib/sail2_prompt.lem
+++ b/src/gen_lib/sail2_prompt.lem
@@ -51,31 +51,31 @@ let bool_of_bitU_fail = function
| BU -> Fail "bool_of_bitU"
end
-val bool_of_bitU_oracle : forall 'rv 'e. bitU -> monad 'rv bool 'e
-let bool_of_bitU_oracle = function
+val bool_of_bitU_nondet : forall 'rv 'e. bitU -> monad 'rv bool 'e
+let bool_of_bitU_nondet = function
| B0 -> return false
| B1 -> return true
| BU -> undefined_bool ()
end
-val bools_of_bits_oracle : forall 'rv 'e. list bitU -> monad 'rv (list bool) 'e
-let bools_of_bits_oracle bits =
+val bools_of_bits_nondet : forall 'rv 'e. list bitU -> monad 'rv (list bool) 'e
+let bools_of_bits_nondet bits =
foreachM bits []
(fun b bools ->
- bool_of_bitU_oracle b >>= (fun b ->
+ bool_of_bitU_nondet b >>= (fun b ->
return (bools ++ [b])))
-val of_bits_oracle : forall 'rv 'a 'e. Bitvector 'a => list bitU -> monad 'rv 'a 'e
-let of_bits_oracle bits =
- bools_of_bits_oracle bits >>= (fun bs ->
+val of_bits_nondet : forall 'rv 'a 'e. Bitvector 'a => list bitU -> monad 'rv 'a 'e
+let of_bits_nondet bits =
+ bools_of_bits_nondet bits >>= (fun bs ->
return (of_bools bs))
val of_bits_fail : forall 'rv 'a 'e. Bitvector 'a => list bitU -> monad 'rv 'a 'e
let of_bits_fail bits = maybe_fail "of_bits" (of_bits bits)
-val mword_oracle : forall 'rv 'a 'e. Size 'a => unit -> monad 'rv (mword 'a) 'e
-let mword_oracle () =
- bools_of_bits_oracle (repeat [BU] (integerFromNat size)) >>= (fun bs ->
+val mword_nondet : forall 'rv 'a 'e. Size 'a => unit -> monad 'rv (mword 'a) 'e
+let mword_nondet () =
+ bools_of_bits_nondet (repeat [BU] (integerFromNat size)) >>= (fun bs ->
return (wordFromBitlist bs))
val whileM : forall 'rv 'vars 'e. 'vars -> ('vars -> monad 'rv bool 'e) ->
@@ -93,6 +93,16 @@ let rec untilM vars cond body =
cond vars >>= fun cond_val ->
if cond_val then return vars else untilM vars cond body
+val internal_pick : forall 'rv 'a 'e. list 'a -> monad 'rv 'a 'e
+let internal_pick xs =
+ (* Use sufficiently many undefined bits and convert into an index into the list *)
+ bools_of_bits_nondet (repeat [BU] (length_list xs)) >>= fun bs ->
+ let idx = (natFromNatural (nat_of_bools bs)) mod List.length xs in
+ match index xs idx with
+ | Just x -> return x
+ | Nothing -> Fail "internal_pick"
+ end
+
(*let write_two_regs r1 r2 vec =
let is_inc =
let is_inc_r1 = is_inc_of_reg r1 in
diff --git a/src/gen_lib/sail2_prompt_monad.lem b/src/gen_lib/sail2_prompt_monad.lem
index 745589e2..78b1615e 100644
--- a/src/gen_lib/sail2_prompt_monad.lem
+++ b/src/gen_lib/sail2_prompt_monad.lem
@@ -29,6 +29,7 @@ type monad 'regval 'a 'e =
| Read_reg of register_name * ('regval -> monad 'regval 'a 'e)
(* Request to write register *)
| Write_reg of register_name * 'regval * monad 'regval 'a 'e
+ (* Request to choose a Boolean, e.g. to resolve an undefined bit *)
| Undefined of (bool -> monad 'regval 'a 'e)
(* Print debugging or tracing information *)
| Print of string * monad 'regval 'a 'e
diff --git a/src/gen_lib/sail2_state.lem b/src/gen_lib/sail2_state.lem
index 82ac35d8..f703dead 100644
--- a/src/gen_lib/sail2_state.lem
+++ b/src/gen_lib/sail2_state.lem
@@ -41,31 +41,31 @@ let bool_of_bitU_fail = function
| BU -> failS "bool_of_bitU"
end
-val bool_of_bitU_oracleS : forall 'rv 'e. bitU -> monadS 'rv bool 'e
-let bool_of_bitU_oracleS = function
+val bool_of_bitU_nondetS : forall 'rv 'e. bitU -> monadS 'rv bool 'e
+let bool_of_bitU_nondetS = function
| B0 -> returnS false
| B1 -> returnS true
| BU -> undefined_boolS ()
end
-val bools_of_bits_oracleS : forall 'rv 'e. list bitU -> monadS 'rv (list bool) 'e
-let bools_of_bits_oracleS bits =
+val bools_of_bits_nondetS : forall 'rv 'e. list bitU -> monadS 'rv (list bool) 'e
+let bools_of_bits_nondetS bits =
foreachS bits []
(fun b bools ->
- bool_of_bitU_oracleS b >>$= (fun b ->
+ bool_of_bitU_nondetS b >>$= (fun b ->
returnS (bools ++ [b])))
-val of_bits_oracleS : forall 'rv 'a 'e. Bitvector 'a => list bitU -> monadS 'rv 'a 'e
-let of_bits_oracleS bits =
- bools_of_bits_oracleS bits >>$= (fun bs ->
+val of_bits_nondetS : forall 'rv 'a 'e. Bitvector 'a => list bitU -> monadS 'rv 'a 'e
+let of_bits_nondetS bits =
+ bools_of_bits_nondetS bits >>$= (fun bs ->
returnS (of_bools bs))
val of_bits_failS : forall 'rv 'a 'e. Bitvector 'a => list bitU -> monadS 'rv 'a 'e
let of_bits_failS bits = maybe_failS "of_bits" (of_bits bits)
-val mword_oracleS : forall 'rv 'a 'e. Size 'a => unit -> monadS 'rv (mword 'a) 'e
-let mword_oracleS () =
- bools_of_bits_oracleS (repeat [BU] (integerFromNat size)) >>$= (fun bs ->
+val mword_nondetS : forall 'rv 'a 'e. Size 'a => unit -> monadS 'rv (mword 'a) 'e
+let mword_nondetS () =
+ bools_of_bits_nondetS (repeat [BU] (integerFromNat size)) >>$= (fun bs ->
returnS (wordFromBitlist bs))
@@ -83,3 +83,13 @@ let rec untilS vars cond body s =
(body vars >>$= (fun vars s' ->
(cond vars >>$= (fun cond_val s'' ->
if cond_val then returnS vars s'' else untilS vars cond body s'')) s')) s
+
+val internal_pickS : forall 'rv 'a 'e. list 'a -> monadS 'rv 'a 'e
+let internal_pickS xs =
+ (* Use sufficiently many undefined bits and convert into an index into the list *)
+ bools_of_bits_nondetS (repeat [BU] (length_list xs)) >>$= fun bs ->
+ let idx = (natFromNatural (nat_of_bools bs)) mod List.length xs in
+ match index xs idx with
+ | Just x -> returnS x
+ | Nothing -> failS "internal_pick"
+ end
diff --git a/src/gen_lib/sail2_state_monad.lem b/src/gen_lib/sail2_state_monad.lem
index f207699f..30b296cc 100644
--- a/src/gen_lib/sail2_state_monad.lem
+++ b/src/gen_lib/sail2_state_monad.lem
@@ -13,20 +13,15 @@ type sequential_state 'regs =
memstate : memstate;
tagstate : tagstate;
write_ea : maybe (write_kind * integer * integer);
- last_exclusive_operation_was_load : bool;
- (* Random bool generator for use as an undefined bit oracle *)
- next_bool : nat -> (bool * nat);
- seed : nat |>
+ last_exclusive_operation_was_load : bool |>
-val init_state : forall 'regs. 'regs -> (nat -> (bool* nat)) -> nat -> sequential_state 'regs
-let init_state regs o s =
+val init_state : forall 'regs. 'regs -> sequential_state 'regs
+let init_state regs =
<| regstate = regs;
memstate = Map.empty;
tagstate = Map.empty;
write_ea = Nothing;
- last_exclusive_operation_was_load = false;
- next_bool = o;
- seed = s |>
+ last_exclusive_operation_was_load = false |>
type ex 'e =
| Failure of string
@@ -69,10 +64,7 @@ val failS : forall 'regs 'a 'e. string -> monadS 'regs 'a 'e
let failS msg s = {(Ex (Failure msg), s)}
val undefined_boolS : forall 'regval 'regs 'a 'e. unit -> monadS 'regs bool 'e
-let undefined_boolS () =
- readS (fun s -> s.next_bool (s.seed)) >>$= (fun (b, seed) ->
- updateS (fun s -> <| s with seed = seed |>) >>$
- returnS b)
+let undefined_boolS () = chooseS {false; true}
val exitS : forall 'regs 'e 'a. unit -> monadS 'regs 'a 'e
let exitS () = failS "exit"
diff --git a/src/gen_lib/sail2_string.lem b/src/gen_lib/sail2_string.lem
index d0e40ad4..ba3a2d51 100644
--- a/src/gen_lib/sail2_string.lem
+++ b/src/gen_lib/sail2_string.lem
@@ -31,7 +31,7 @@ let string_append = stringAppend
val maybeIntegerOfString : string -> maybe integer
let maybeIntegerOfString _ = Nothing (* TODO FIXME *)
-declare ocaml target_rep function maybeIntegerOfString = `(fun s -> match int_of_string_opt s with None -> None | Some i -> Some (Nat_big_num.of_int i))`
+declare ocaml target_rep function maybeIntegerOfString = `(fun s -> match int_of_string s with i -> Some (Nat_big_num.of_int i) | exception Failure _ -> None )`
(***********************************************
* end stuff that should be in Lem Num_extra *
diff --git a/src/gen_lib/sail2_values.lem b/src/gen_lib/sail2_values.lem
index 0f384cab..fd742fb1 100644
--- a/src/gen_lib/sail2_values.lem
+++ b/src/gen_lib/sail2_values.lem
@@ -47,12 +47,14 @@ let power_real b e = realPowInteger b e*)
val print_endline : string -> unit
let print_endline _ = ()
-declare ocaml target_rep function print_endline = `print_endline`
+(* declare ocaml target_rep function print_endline = `print_endline` *)
val prerr_endline : string -> unit
let prerr_endline _ = ()
declare ocaml target_rep function prerr_endline = `prerr_endline`
+let prerr x = prerr_endline x
+
val print_int : string -> integer -> unit
let print_int msg i = print_endline (msg ^ (stringFromInteger i))
diff --git a/src/interpreter.ml b/src/interpreter.ml
index 00846d73..99d5889a 100644
--- a/src/interpreter.ml
+++ b/src/interpreter.ml
@@ -232,7 +232,6 @@ let is_value_fexp (FE_aux (FE_Fexp (id, exp), _)) = is_value exp
let value_of_fexp (FE_aux (FE_Fexp (id, exp), _)) = (string_of_id id, value_of_exp exp)
let rec build_letchain id lbs (E_aux (_, annot) as exp) =
- (* print_endline ("LETCHAIN " ^ string_of_exp exp); *)
match lbs with
| [] -> exp
| lb :: lbs when IdSet.mem id (letbind_pat_ids lb)->
@@ -311,7 +310,6 @@ let rec step (E_aux (e_aux, annot) as orig_exp) =
else
failwith "Match failure"
-
| E_vector_subrange (vec, n, m) ->
wrap (E_app (mk_id "vector_subrange_dec", [vec; n; m]))
| E_vector_access (vec, n) ->
diff --git a/src/isail.ml b/src/isail.ml
index 593167f9..4adc1cd2 100644
--- a/src/isail.ml
+++ b/src/isail.ml
@@ -127,7 +127,12 @@ let rec run () =
print_endline ("Result = " ^ Value.string_of_value v);
current_mode := Normal
| Step (out, state, _, stack) ->
- current_mode := Evaluation (eval_frame !interactive_ast frame);
+ begin
+ try
+ current_mode := Evaluation (eval_frame !interactive_ast frame)
+ with
+ | Failure str -> print_endline str; current_mode := Normal
+ end;
run ()
| Break frame ->
print_endline "Breakpoint";
@@ -147,7 +152,12 @@ let rec run_steps n =
print_endline ("Result = " ^ Value.string_of_value v);
current_mode := Normal
| Step (out, state, _, stack) ->
- current_mode := Evaluation (eval_frame !interactive_ast frame);
+ begin
+ try
+ current_mode := Evaluation (eval_frame !interactive_ast frame)
+ with
+ | Failure str -> print_endline str; current_mode := Normal
+ end;
run_steps (n - 1)
| Break frame ->
print_endline "Breakpoint";
@@ -352,9 +362,14 @@ let handle_input' input =
print_endline ("Result = " ^ Value.string_of_value v);
current_mode := Normal
| Step (out, state, _, stack) ->
- interactive_state := state;
- current_mode := Evaluation (eval_frame !interactive_ast frame);
- print_program ()
+ begin
+ try
+ interactive_state := state;
+ current_mode := Evaluation (eval_frame !interactive_ast frame);
+ print_program ()
+ with
+ | Failure str -> print_endline str; current_mode := Normal
+ end
| Break frame ->
print_endline "Breakpoint";
current_mode := Evaluation frame
diff --git a/src/lem_interp/sail2_instr_kinds.lem b/src/lem_interp/sail2_instr_kinds.lem
index 13e5304e..3d238676 100644
--- a/src/lem_interp/sail2_instr_kinds.lem
+++ b/src/lem_interp/sail2_instr_kinds.lem
@@ -151,6 +151,10 @@ type barrier_kind =
| Barrier_RISCV_r_r
| Barrier_RISCV_rw_w
| Barrier_RISCV_w_w
+ | Barrier_RISCV_w_rw
+ | Barrier_RISCV_rw_r
+ | Barrier_RISCV_r_w
+ | Barrier_RISCV_w_r
| Barrier_RISCV_i
(* X86 *)
| Barrier_x86_MFENCE
@@ -176,6 +180,10 @@ instance (Show barrier_kind)
| Barrier_RISCV_r_r -> "Barrier_RISCV_r_r"
| Barrier_RISCV_rw_w -> "Barrier_RISCV_rw_w"
| Barrier_RISCV_w_w -> "Barrier_RISCV_w_w"
+ | Barrier_RISCV_w_rw -> "Barrier_RISCV_w_rw"
+ | Barrier_RISCV_rw_r -> "Barrier_RISCV_rw_r"
+ | Barrier_RISCV_r_w -> "Barrier_RISCV_r_w"
+ | Barrier_RISCV_w_r -> "Barrier_RISCV_w_r"
| Barrier_RISCV_i -> "Barrier_RISCV_i"
| Barrier_x86_MFENCE -> "Barrier_x86_MFENCE"
end
@@ -211,7 +219,7 @@ instance (Show instruction_kind)
| IK_mem_read read_kind -> "IK_mem_read " ^ (show read_kind)
| IK_mem_write write_kind -> "IK_mem_write " ^ (show write_kind)
| IK_mem_rmw (r, w) -> "IK_mem_rmw " ^ (show r) ^ " " ^ (show w)
- | IK_branch -> "IK_branch"
+ | IK_branch () -> "IK_branch"
| IK_trans trans_kind -> "IK_trans " ^ (show trans_kind)
| IK_simple () -> "IK_simple"
end
@@ -288,7 +296,11 @@ instance (EnumerationType barrier_kind)
| Barrier_RISCV_r_r -> 15
| Barrier_RISCV_rw_w -> 16
| Barrier_RISCV_w_w -> 17
- | Barrier_RISCV_i -> 18
- | Barrier_x86_MFENCE -> 19
+ | Barrier_RISCV_w_rw -> 18
+ | Barrier_RISCV_rw_r -> 19
+ | Barrier_RISCV_r_w -> 20
+ | Barrier_RISCV_w_r -> 21
+ | Barrier_RISCV_i -> 22
+ | Barrier_x86_MFENCE -> 23
end
end
diff --git a/src/ocaml_backend.ml b/src/ocaml_backend.ml
index 5ffb1647..236c4222 100644
--- a/src/ocaml_backend.ml
+++ b/src/ocaml_backend.ml
@@ -167,10 +167,13 @@ let ocaml_lit (L_aux (lit_aux, _)) =
| L_one -> string "B1"
| L_true -> string "true"
| L_false -> string "false"
- | L_num n -> if Big_int.equal n Big_int.zero
- then string "Big_int.zero"
- else parens (string "Big_int.of_int" ^^ space
- ^^ string "(" ^^ string (Big_int.to_string n) ^^ string ")")
+ | L_num n ->
+ if Big_int.equal n Big_int.zero then
+ string "Big_int.zero"
+ else if Big_int.less_equal (Big_int.of_int min_int) n && Big_int.less_equal n (Big_int.of_int max_int) then
+ parens (string "Big_int.of_int" ^^ space ^^ parens (string (Big_int.to_string n)))
+ else
+ parens (string "Big_int.of_string" ^^ space ^^ dquotes (string (Big_int.to_string n)))
| L_undef -> failwith "undefined should have been re-written prior to ocaml backend"
| L_string str -> string_lit str
| L_real str -> parens (string "real_of_string" ^^ space ^^ dquotes (string (String.escaped str)))
@@ -389,6 +392,10 @@ let ocaml_dec_spec ctx (DEC_aux (reg, _)) =
separate space [string "let"; zencode ctx id; colon;
parens (ocaml_typ ctx typ); string "ref"; equals;
string "ref"; parens (ocaml_exp ctx (initial_value_for id ctx.register_inits))]
+ | DEC_config (id, typ, exp) ->
+ separate space [string "let"; zencode ctx id; colon;
+ parens (ocaml_typ ctx typ); string "ref"; equals;
+ string "ref"; parens (ocaml_exp ctx exp)]
| _ -> failwith "Unsupported register declaration"
let first_function = ref true
diff --git a/src/pretty_print_coq.ml b/src/pretty_print_coq.ml
index ffe376e0..74e97a29 100644
--- a/src/pretty_print_coq.ml
+++ b/src/pretty_print_coq.ml
@@ -96,12 +96,14 @@ type context = {
kid_renames : kid KBindings.t; (* Plain tyvar -> tyvar renames *)
kid_id_renames : id KBindings.t; (* tyvar -> argument renames *)
bound_nexps : NexpSet.t;
+ build_ex_return : bool;
}
let empty_ctxt = {
early_ret = false;
kid_renames = KBindings.empty;
kid_id_renames = KBindings.empty;
- bound_nexps = NexpSet.empty
+ bound_nexps = NexpSet.empty;
+ build_ex_return = false;
}
let langlebar = string "<|"
@@ -135,7 +137,10 @@ let rec fix_id remove_tick name = match name with
| "GT"
| "EQ"
| "Z"
+ | "O"
+ | "S"
| "mod"
+ | "M"
-> name ^ "'"
| _ ->
if String.contains name '#' then
@@ -146,15 +151,17 @@ let rec fix_id remove_tick name = match name with
fix_id remove_tick (String.concat "__" (Util.split_on_char '^' name))
else if name.[0] = '\'' then
let var = String.sub name 1 (String.length name - 1) in
- if remove_tick then var else (var ^ "'")
+ if remove_tick then fix_id remove_tick var else (var ^ "'")
else if is_number_char(name.[0]) then
("v" ^ name ^ "'")
else name
-let doc_id (Id_aux(i,_)) =
+let string_id (Id_aux(i,_)) =
match i with
- | Id i -> string (fix_id false i)
- | DeIid x -> string (Util.zencode_string ("op " ^ x))
+ | Id i -> fix_id false i
+ | DeIid x -> Util.zencode_string ("op " ^ x)
+
+let doc_id id = string (string_id id)
let doc_id_type (Id_aux(i,_)) =
match i with
@@ -318,7 +325,7 @@ let drop_duplicate_atoms kids ty =
in aux_typ ty
(* TODO: parens *)
-let rec doc_nc ctx (NC_aux (nc,_)) =
+let rec doc_nc_prop ctx (NC_aux (nc,_)) =
match nc with
| NC_equal (ne1, ne2) -> doc_op equals (doc_nexp ctx ne1) (doc_nexp ctx ne2)
| NC_bounded_ge (ne1, ne2) -> doc_op (string ">=") (doc_nexp ctx ne1) (doc_nexp ctx ne2)
@@ -328,11 +335,27 @@ let rec doc_nc ctx (NC_aux (nc,_)) =
separate space [string "In"; doc_var_lem ctx kid;
brackets (separate (string "; ")
(List.map (fun i -> string (Nat_big_num.to_string i)) is))]
- | NC_or (nc1, nc2) -> doc_op (string "\\/") (doc_nc ctx nc1) (doc_nc ctx nc2)
- | NC_and (nc1, nc2) -> doc_op (string "/\\") (doc_nc ctx nc1) (doc_nc ctx nc2)
+ | NC_or (nc1, nc2) -> doc_op (string "\\/") (doc_nc_prop ctx nc1) (doc_nc_prop ctx nc2)
+ | NC_and (nc1, nc2) -> doc_op (string "/\\") (doc_nc_prop ctx nc1) (doc_nc_prop ctx nc2)
| NC_true -> string "True"
| NC_false -> string "False"
+(* TODO: parens *)
+let rec doc_nc_exp ctx (NC_aux (nc,_)) =
+ match nc with
+ | NC_equal (ne1, ne2) -> doc_op (string "=?") (doc_nexp ctx ne1) (doc_nexp ctx ne2)
+ | NC_bounded_ge (ne1, ne2) -> doc_op (string ">=?") (doc_nexp ctx ne1) (doc_nexp ctx ne2)
+ | NC_bounded_le (ne1, ne2) -> doc_op (string "<=?") (doc_nexp ctx ne1) (doc_nexp ctx ne2)
+ | NC_not_equal (ne1, ne2) -> string "negb" ^^ space ^^ parens (doc_op (string "=?") (doc_nexp ctx ne1) (doc_nexp ctx ne2))
+ | NC_set (kid, is) -> (* TODO: is this a good translation? *)
+ separate space [string "member_Z_list"; doc_var_lem ctx kid;
+ brackets (separate (string "; ")
+ (List.map (fun i -> string (Nat_big_num.to_string i)) is))]
+ | NC_or (nc1, nc2) -> doc_op (string "||") (doc_nc_exp ctx nc1) (doc_nc_exp ctx nc2)
+ | NC_and (nc1, nc2) -> doc_op (string "&&") (doc_nc_exp ctx nc1) (doc_nc_exp ctx nc2)
+ | NC_true -> string "true"
+ | NC_false -> string "false"
+
let maybe_expand_range_type (Typ_aux (typ,l) as full_typ) =
match typ with
| Typ_app(Id_aux (Id "range", _), [Typ_arg_aux(Typ_arg_nexp low,_);
@@ -347,7 +370,7 @@ let maybe_expand_range_type (Typ_aux (typ,l) as full_typ) =
let expand_range_type typ = Util.option_default typ (maybe_expand_range_type typ)
let doc_arithfact ctxt nc =
- string "ArithFact" ^^ space ^^ parens (doc_nc ctxt nc)
+ string "ArithFact" ^^ space ^^ parens (doc_nc_prop ctxt nc)
(* When making changes here, check whether they affect lem_tyvars_of_typ *)
let doc_typ, doc_atomic_typ =
@@ -381,7 +404,7 @@ let doc_typ, doc_atomic_typ =
let tpp = match elem_typ with
| Typ_aux (Typ_id (Id_aux (Id "bit",_)),_) ->
string "mword " ^^ doc_nexp ctx (nexp_simp m)
- | _ -> string "list" ^^ space ^^ typ elem_typ in
+ | _ -> string "vec" ^^ space ^^ typ elem_typ ^^ space ^^ doc_nexp ctx (nexp_simp m) in
if atyp_needed then parens tpp else tpp
| Typ_app(Id_aux (Id "register", _), [Typ_arg_aux (Typ_arg_typ etyp, _)]) ->
let tpp = string "register_ref regstate register_value " ^^ typ etyp in
@@ -424,7 +447,7 @@ let doc_typ, doc_atomic_typ =
List.fold_left add_tyvar tpp kids
| None ->
match nc with
- | NC_aux (NC_true,_) -> List.fold_left add_tyvar (string "Z") (List.tl kids)
+(* | NC_aux (NC_true,_) -> List.fold_left add_tyvar (string "Z") (List.tl kids)*)
| _ -> List.fold_left add_tyvar (doc_arithfact ctx nc) kids
end
and doc_typ_arg (Typ_arg_aux(t,_)) = match t with
@@ -537,9 +560,9 @@ let doc_typquant_items ctx delimit (TypQ_aux (tq,_)) =
let doc_typquant_items_separate ctx delimit (TypQ_aux (tq,_)) =
match tq with
| TypQ_tq qis ->
- separate_opt space (doc_quant_item_id ctx delimit) qis,
- separate_opt space (doc_quant_item_constr ctx delimit) qis
- | TypQ_no_forall -> empty, empty
+ Util.map_filter (doc_quant_item_id ctx delimit) qis,
+ Util.map_filter (doc_quant_item_constr ctx delimit) qis
+ | TypQ_no_forall -> [], []
let doc_typquant ctx (TypQ_aux(tq,_)) typ = match tq with
| TypQ_tq ((_ :: _) as qs) ->
@@ -687,14 +710,34 @@ let typ_id_of (Typ_aux (typ, l)) = match typ with
| Typ_app (id, _) -> id
| _ -> raise (Reporting_basic.err_unreachable l "failed to get type id")
+(* TODO: maybe Nexp_exp, division? *)
+(* Evaluation of constant nexp subexpressions, because Coq will be able to do those itself *)
+let rec nexp_const_eval (Nexp_aux (n,l) as nexp) =
+ let binop f re l n1 n2 =
+ match nexp_const_eval n1, nexp_const_eval n2 with
+ | Nexp_aux (Nexp_constant c1,_), Nexp_aux (Nexp_constant c2,_) ->
+ Nexp_aux (Nexp_constant (f c1 c2),l)
+ | n1', n2' -> Nexp_aux (re n1' n2',l)
+ in
+ let unop f re l n1 =
+ match nexp_const_eval n1 with
+ | Nexp_aux (Nexp_constant c1,_) -> Nexp_aux (Nexp_constant (f c1),l)
+ | n1' -> Nexp_aux (re n1',l)
+ in
+ match n with
+ | Nexp_times (n1,n2) -> binop Big_int.mul (fun n1 n2 -> Nexp_times (n1,n2)) l n1 n2
+ | Nexp_sum (n1,n2) -> binop Big_int.add (fun n1 n2 -> Nexp_sum (n1,n2)) l n1 n2
+ | Nexp_minus (n1,n2) -> binop Big_int.sub (fun n1 n2 -> Nexp_minus (n1,n2)) l n1 n2
+ | Nexp_neg n1 -> unop Big_int.negate (fun n -> Nexp_neg n) l n1
+ | _ -> nexp
+
(* Decide whether two nexps used in a vector size are similar; if not
a cast will be inserted *)
-let similar_nexps n1 n2 =
+let similar_nexps env n1 n2 =
let rec same_nexp_shape (Nexp_aux (n1,_)) (Nexp_aux (n2,_)) =
match n1, n2 with
- | Nexp_id _, Nexp_id _
- | Nexp_var _, Nexp_var _
- -> true
+ | Nexp_id _, Nexp_id _ -> true
+ | Nexp_var k1, Nexp_var k2 -> prove env (nc_eq (nvar k1) (nvar k2))
| Nexp_constant c1, Nexp_constant c2 -> Nat_big_num.equal c1 c2
| Nexp_app (f1,args1), Nexp_app (f2,args2) ->
Id.compare f1 f2 == 0 && List.for_all2 same_nexp_shape args1 args2
@@ -706,7 +749,48 @@ let similar_nexps n1 n2 =
| Nexp_neg n1, Nexp_neg n2
-> same_nexp_shape n1 n2
| _ -> false
- in if same_nexp_shape n1 n2 then true else false
+ in if same_nexp_shape (nexp_const_eval n1) (nexp_const_eval n2) then true else false
+
+let constraint_fns = ["Z.leb"; "Z.geb"; "Z.ltb"; "Z.gtb"; "Z.eqb"; "neq_atom"]
+
+let condition_produces_constraint exp =
+ (* Cheat a little - this isn't quite the right environment for subexpressions
+ but will have all of the relevant functions in it. *)
+ let env = env_of exp in
+ Rewriter.fold_exp
+ { (Rewriter.pure_exp_alg false (||)) with
+ Rewriter.e_app = fun (f,bs) ->
+ List.exists (fun x -> x) bs ||
+ (let name = if Env.is_extern f env "coq"
+ then Env.get_extern f env "coq"
+ else string_id f in
+ List.exists (fun id -> String.compare name id == 0) constraint_fns)
+ } exp
+
+(* For most functions whose return types are non-trivial atoms we return a
+ dependent pair with a proof that the result is the expected integer. This
+ is redundant for basic arithmetic functions and functions which we unfold
+ in the constraint solver. *)
+let no_Z_proof_fns = ["Z.add"; "Z.sub"; "Z.opp"; "Z.mul"; "length_mword"; "length"]
+
+let is_no_Z_proof_fn env id =
+ if Env.is_extern id env "coq"
+ then
+ let s = Env.get_extern id env "coq" in
+ List.exists (fun x -> String.compare x s == 0) no_Z_proof_fns
+ else false
+
+let replace_atom_return_type ret_typ =
+ (* TODO: more complex uses of atom *)
+ match ret_typ with
+ | Typ_aux (Typ_app (Id_aux (Id "atom",_), [Typ_arg_aux (Typ_arg_nexp nexp,_)]),l) ->
+ let kid = mk_kid "_retval" in (* TODO: collision avoidance *)
+ true, Typ_aux (Typ_exist ([kid], nc_eq (nvar kid) nexp, atom_typ (nvar kid)),Generated l)
+ | Typ_aux (Typ_id (Id_aux (Id "nat",_)),l) ->
+ let kid = mk_kid "_retval" in
+ true, Typ_aux (Typ_exist ([kid], nc_gteq (nvar kid) (nconstant Nat_big_num.zero), atom_typ (nvar kid)),Generated l)
+ | _ -> false, ret_typ
+
let prefix_recordtype = true
let report = Reporting_basic.err_unreachable
@@ -815,7 +899,7 @@ let doc_exp_lem, doc_let_lem =
| Id_aux (Id "foreach", _) ->
begin
match args with
- | [exp1; exp2; exp3; ord_exp; vartuple; body] ->
+ | [from_exp; to_exp; step_exp; ord_exp; vartuple; body] ->
let loopvar, body = match body with
| E_aux (E_let (LB_aux (LB_val (_, _), _),
E_aux (E_let (LB_aux (LB_val (_, _), _),
@@ -826,13 +910,13 @@ let doc_exp_lem, doc_let_lem =
| (P_aux (P_id id, _))), _), _),
body), _), _), _)), _)), _) -> id, body
| _ -> raise (Reporting_basic.err_unreachable l ("Unable to find loop variable in " ^ string_of_exp body)) in
- let step = match ord_exp with
- | E_aux (E_lit (L_aux (L_false, _)), _) ->
- parens (separate space [string "integerNegate"; expY exp3])
- | _ -> expY exp3
+ let dir = match ord_exp with
+ | E_aux (E_lit (L_aux (L_false, _)), _) -> "_down"
+ | E_aux (E_lit (L_aux (L_true, _)), _) -> "_up"
+ | _ -> raise (Reporting_basic.err_unreachable l ("Unexpected loop direction " ^ string_of_exp ord_exp))
in
- let combinator = if effectful (effect_of body) then "foreachM" else "foreach" in
- let indices_pp = parens (separate space [string "index_list"; expY exp1; expY exp2; step]) in
+ let combinator = if effectful (effect_of body) then "foreach_ZM" else "foreach_Z" in
+ let combinator = combinator ^ dir in
let used_vars_body = find_e_ids body in
let body_lambda =
(* Work around indentation issues in Lem when translating
@@ -840,18 +924,20 @@ let doc_exp_lem, doc_let_lem =
match fst (uncast_exp vartuple) with
| E_aux (E_tuple _, _)
when not (IdSet.mem (mk_id "varstup") used_vars_body)->
- separate space [string "fun"; doc_id loopvar; string "varstup"; bigarrow]
+ separate space [string "fun"; doc_id loopvar; string "_"; string "varstup"; bigarrow]
^^ break 1 ^^
- separate space [string "let"; expY vartuple; string ":= varstup in"]
+ separate space [string "let"; squote ^^ expY vartuple; string ":= varstup in"]
| E_aux (E_lit (L_aux (L_unit, _)), _)
when not (IdSet.mem (mk_id "unit_var") used_vars_body) ->
- separate space [string "fun"; doc_id loopvar; string "unit_var"; bigarrow]
+ separate space [string "fun"; doc_id loopvar; string "_"; string "unit_var"; bigarrow]
| _ ->
- separate space [string "fun"; doc_id loopvar; expY vartuple; bigarrow]
+ separate space [string "fun"; doc_id loopvar; string "_"; expY vartuple; bigarrow]
in
parens (
(prefix 2 1)
- ((separate space) [string combinator; indices_pp; expY vartuple])
+ ((separate space) [string combinator;
+ expY from_exp; expY to_exp; expY step_exp;
+ expY vartuple])
(parens
(prefix 2 1 (group body_lambda) (expN body))
)
@@ -879,7 +965,7 @@ let doc_exp_lem, doc_let_lem =
| E_aux (E_tuple _, _)
when not (IdSet.mem (mk_id "varstup") used_vars_body)->
separate space [string "fun varstup"; bigarrow] ^^ break 1 ^^
- separate space [string "let"; expY varstuple; string ":= varstup in"]
+ separate space [string "let"; squote ^^ expY varstuple; string ":= varstup in"]
| E_aux (E_lit (L_aux (L_unit, _)), _)
when not (IdSet.mem (mk_id "unit_var") used_vars_body) ->
separate space [string "fun unit_var"; bigarrow]
@@ -946,17 +1032,35 @@ let doc_exp_lem, doc_let_lem =
(* TODO: more sophisticated check *)
match destruct_exist env arg_ty, destruct_exist env typ_from_fn with
| Some _, None -> parens (string "projT1 " ^^ arg_pp)
+ (* Usually existentials have already been built elsewhere, but this
+ is useful for (e.g.) ranges *)
+ | None, Some _ -> parens (string "build_ex " ^^ arg_pp)
| _, _ -> arg_pp
in
let epp = hang 2 (flow (break 1) (call :: List.map2 doc_arg args arg_typs)) in
- (* Unpack existential result *)
- let inst = instantiation_of full_exp in
+ (* Decide whether to unpack an existential result, pack one, or cast.
+ To do this we compare the expected type stored in the checked expression
+ with the inferred type. *)
+ let inst =
+ match instantiation_of_without_type full_exp with
+ | x -> x
+ (* Not all function applications can be inferred, so try falling back to the
+ type inferred when we know the target type.
+ TODO: there are probably some edge cases where this won't pick up a need
+ to cast. *)
+ | exception _ -> instantiation_of full_exp
+ in
let inst = KBindings.fold (fun k u m -> KBindings.add (orig_kid k) u m) inst KBindings.empty in
- let ret_typ_inst = subst_unifiers inst ret_typ in
+ let ret_typ_inst =
+ subst_unifiers inst ret_typ
+ in
let unpack,build_ex,autocast =
let ann_typ = Env.expand_synonyms env (typ_of_annot (l,annot)) in
let ann_typ = expand_range_type ann_typ in
let ret_typ_inst = expand_range_type (Env.expand_synonyms env ret_typ_inst) in
+ let ret_typ_inst =
+ if is_no_Z_proof_fn env f then ret_typ_inst
+ else snd (replace_atom_return_type ret_typ_inst) in
let unpack, build_ex, in_typ, out_typ =
match ret_typ_inst, ann_typ with
| Typ_aux (Typ_exist (_,_,t1),_), Typ_aux (Typ_exist (_,_,t2),_) ->
@@ -968,10 +1072,12 @@ let doc_exp_lem, doc_let_lem =
| t1, t2 -> false,false,t1,t2
in
let autocast =
- match destruct_vector env in_typ, destruct_vector env out_typ with
- | Some (n1,_,t1), Some (n2,_,t2)
- when is_bit_typ t1 && is_bit_typ t2 ->
- not (similar_nexps n1 n2)
+ (* Avoid using helper functions which simplify the nexps *)
+ is_bitvector_typ in_typ && is_bitvector_typ out_typ &&
+ match in_typ, out_typ with
+ | Typ_aux (Typ_app (_,[Typ_arg_aux (Typ_arg_nexp n1,_);_;_]),_),
+ Typ_aux (Typ_app (_,[Typ_arg_aux (Typ_arg_nexp n2,_);_;_]),_) ->
+ not (similar_nexps env n1 n2)
| _ -> false
in unpack,build_ex,autocast
in
@@ -1048,13 +1154,33 @@ let doc_exp_lem, doc_let_lem =
(doc_fexp ctxt recordtyp) fexps)) in
if aexp_needed then parens epp else epp
| E_record_update(e,(FES_aux(FES_Fexps(fexps,_),_))) ->
- let recordtyp = match annot with
+ let recordtyp, env = match annot with
| Some (env, Typ_aux (Typ_id tid,_), _)
| Some (env, Typ_aux (Typ_app (tid, _), _), _)
when Env.is_record tid env ->
- tid
+ tid, env
| _ -> raise (report l ("cannot get record type from annot " ^ string_of_annot annot ^ " of exp " ^ string_of_exp full_exp)) in
- enclose_record_update (doc_op (string "with") (expY e) (separate_map semi_sp (doc_fexp ctxt recordtyp) fexps))
+ if List.length fexps > 1 then
+ let _,fields = Env.get_record recordtyp env in
+ let var, let_pp =
+ match e with
+ | E_aux (E_id id,_) -> id, empty
+ | _ -> let v = mk_id "_record" in (* TODO: collision avoid *)
+ v, separate space [string "let "; doc_id v; coloneq; top_exp ctxt true e; string "in"] ^^ break 1
+ in
+ let doc_field (_,id) =
+ match List.find (fun (FE_aux (FE_Fexp (id',_),_)) -> Id.compare id id' == 0) fexps with
+ | fexp -> doc_fexp ctxt recordtyp fexp
+ | exception Not_found ->
+ let fname =
+ if prefix_recordtype && string_of_id recordtyp <> "regstate"
+ then (string (string_of_id recordtyp ^ "_")) ^^ doc_id id
+ else doc_id id in
+ doc_op coloneq fname (doc_id var ^^ dot ^^ parens fname)
+ in let_pp ^^ enclose_record (align (separate_map (semi_sp ^^ break 1)
+ doc_field fields))
+ else
+ enclose_record_update (doc_op (string "with") (expY e) (separate_map semi_sp (doc_fexp ctxt recordtyp) fexps))
| E_vector exps ->
let t = Env.base_typ_of (env_of full_exp) (typ_of full_exp) in
let start, (len, order, etyp) =
@@ -1079,7 +1205,9 @@ let doc_exp_lem, doc_let_lem =
if is_bit_typ etyp then
let bepp = string "vec_of_bits" ^^ space ^^ align epp in
(align (group (prefix 0 1 bepp (doc_tannot_lem ctxt (env_of full_exp) false t))), true)
- else (epp,aexp_needed) in
+ else
+ let vepp = string "vec_of_list_len" ^^ space ^^ align epp in
+ (vepp,aexp_needed) in
if aexp_needed then parens (align epp) else epp
| E_vector_update(v,e1,e2) ->
raise (Reporting_basic.err_unreachable l
@@ -1100,7 +1228,8 @@ let doc_exp_lem, doc_let_lem =
if effectful (effect_of e) then
let try_catch = if ctxt.early_ret then "try_catchR" else "try_catch" in
let epp =
- group ((separate space [string try_catch; expY e; string "(function "]) ^/^
+ (* TODO capture avoidance for __catch_val *)
+ group ((separate space [string try_catch; expY e; string "(fun __catch_val => match __catch_val with "]) ^/^
(separate_map (break 1) (doc_case ctxt exc_typ) pexps) ^/^
(string "end)")) in
if aexp_needed then parens (align epp) else align epp
@@ -1119,24 +1248,37 @@ let doc_exp_lem, doc_let_lem =
| E_var(lexp, eq_exp, in_exp) ->
raise (report l "E_vars should have been removed before pretty-printing")
| E_internal_plet (pat,e1,e2) ->
- let epp =
- let b = match e1 with E_aux (E_if _,_) -> true | _ -> false in
- let middle =
- match fst (untyp_pat pat) with
- | P_aux (P_wild,_) | P_aux (P_typ (_, P_aux (P_wild, _)), _) ->
- string ">>"
- | P_aux (P_id id,_) ->
- separate space [string ">>= fun"; doc_id id; bigarrow]
- | P_aux (P_typ (typ, P_aux (P_id id,_)),_) ->
- separate space [string ">>= fun"; doc_id id; colon; doc_typ ctxt typ; bigarrow]
- | _ ->
- separate space [string ">>= fun"; squote ^^ doc_pat ctxt true (pat, typ_of e1); bigarrow]
- in
- infix 0 1 middle (expV b e1) (expN e2)
- in
- if aexp_needed then parens (align epp) else epp
+ begin
+ match pat, e1 with
+ | (P_aux (P_wild,_) | P_aux (P_typ (_, P_aux (P_wild, _)), _)),
+ (E_aux (E_assert (assert_e1,assert_e2),_)) ->
+ let epp = liftR (separate space [string "assert_exp'"; expY assert_e1; expY assert_e2]) in
+ let epp = infix 0 1 (string ">>= fun _ =>") epp (expN e2) in
+ if aexp_needed then parens (align epp) else align epp
+ | _ ->
+ let epp =
+ let b = match e1 with E_aux (E_if _,_) -> true | _ -> false in
+ let middle =
+ match fst (untyp_pat pat) with
+ | P_aux (P_wild,_) | P_aux (P_typ (_, P_aux (P_wild, _)), _) ->
+ string ">>"
+ | P_aux (P_id id,_) ->
+ separate space [string ">>= fun"; doc_id id; bigarrow]
+ | P_aux (P_typ (typ, P_aux (P_id id,_)),_) ->
+ separate space [string ">>= fun"; doc_id id; colon; doc_typ ctxt typ; bigarrow]
+ | _ ->
+ separate space [string ">>= fun"; squote ^^ doc_pat ctxt true (pat, typ_of e1); bigarrow]
+ in
+ infix 0 1 middle (expV b e1) (expN e2)
+ in
+ if aexp_needed then parens (align epp) else epp
+ end
| E_internal_return (e1) ->
- wrap_parens (align (separate space [string "returnm"; expY e1]))
+ let e1pp = expY e1 in
+ let valpp = if ctxt.build_ex_return
+ then parens (string "build_ex" ^^ space ^^ e1pp)
+ else e1pp in
+ wrap_parens (align (separate space [string "returnm"; valpp]))
| E_sizeof nexp ->
(match nexp_simp nexp with
| Nexp_aux (Nexp_constant i, _) -> doc_lit (L_aux (L_num i, l))
@@ -1153,7 +1295,7 @@ let doc_exp_lem, doc_let_lem =
parens (doc_typ ctxt (typ_of full_exp));
parens (doc_typ ctxt (typ_of r))] in
align (parens (string "early_return" ^//^ expV true r ^//^ ta))
- | E_constraint _ -> string "true"
+ | E_constraint nc -> wrap_parens (doc_nc_exp ctxt nc)
| E_comment _ | E_comment_struc _ -> empty
| E_internal_cast _ | E_internal_exp _ | E_sizeof_internal _
| E_internal_exp_user _ | E_internal_value _ ->
@@ -1168,7 +1310,9 @@ let doc_exp_lem, doc_let_lem =
| _ -> prefix 2 1 (string "else") (top_exp ctxt false e)
in
(prefix 2 1
- (soft_surround 2 1 if_pp (string "sumbool_of_bool" ^^ space ^^ parens (top_exp ctxt true c)) (string "then"))
+ (soft_surround 2 1 if_pp
+ ((if condition_produces_constraint c then string "sumbool_of_bool" ^^ space else empty)
+ ^^ parens (top_exp ctxt true c)) (string "then"))
(top_exp ctxt false t)) ^^
break 1 ^^
else_pp
@@ -1404,17 +1548,28 @@ let demote_as_pattern i (P_aux (_,p_annot) as pat,typ) =
E_aux (E_let (LB_aux (LB_val (pat, E_aux (E_id id, p_annot)),p_annot),e),e_ann)
else (pat,typ), fun e -> e
-(* Ideally we'd remove the duplication between type variables and atom
- arguments, but for now we just add an equality constraint. *)
+(* Add equality constraints between arguments and nexps, except in the case
+ that they've been merged. *)
-let atom_constraint ctxt (pat, typ) =
+let rec atom_constraint ctxt (pat, typ) =
let typ = Env.base_typ_of (pat_env_of pat) typ in
match pat, typ with
| P_aux (P_id id, _),
Typ_aux (Typ_app (Id_aux (Id "atom",_),
- [Typ_arg_aux (Typ_arg_nexp (Nexp_aux (Nexp_var kid,_)),_)]),_) ->
+ [Typ_arg_aux (Typ_arg_nexp nexp,_)]),_) ->
+ (match nexp with
+ (* When the kid is mapped to the id, we don't need a constraint *)
+ | Nexp_aux (Nexp_var kid,_)
+ when (try Id.compare (KBindings.find kid ctxt.kid_id_renames) id == 0 with _ -> false) ->
+ None
+ | _ ->
+ Some (bquote ^^ braces (string "ArithFact" ^^ space ^^
+ parens (doc_op equals (doc_id id) (doc_nexp ctxt nexp)))))
+ | P_aux (P_id id, _),
+ Typ_aux (Typ_id (Id_aux (Id "nat",_)),_) ->
Some (bquote ^^ braces (string "ArithFact" ^^ space ^^
- parens (doc_op equals (doc_id id) (doc_var_lem ctxt kid))))
+ parens (doc_op (string ">=") (doc_id id) (string "0"))))
+ | P_aux (P_typ (_,p),_), _ -> atom_constraint ctxt (p, typ)
| _ -> None
let all_ids pexp =
@@ -1485,14 +1640,13 @@ let merge_kids_atoms pats =
let gone,map,_ = List.fold_left try_eliminate (KidSet.empty, KBindings.empty, KidSet.empty) pats in
gone,map
-let doc_binder ctxt (P_aux (p,ann) as pat, typ) =
- let env = env_of_annot ann in
- let exp_typ = Env.expand_synonyms env typ in
- match p with
- | P_id id
- | P_typ (_,P_aux (P_id id,_)) when Util.is_none (is_auto_decomposed_exist env exp_typ) ->
- parens (separate space [doc_id id; colon; doc_typ ctxt typ])
- | _ -> squote ^^ parens (separate space [doc_pat ctxt true (pat, exp_typ); colon; doc_typ ctxt typ])
+let merge_var_patterns map pats =
+ let map,pats = List.fold_left (fun (map,pats) (pat, typ) ->
+ match pat with
+ | P_aux (P_var (P_aux (P_id id,_), TP_aux (TP_var kid,_)),ann) ->
+ KBindings.add kid id map, (P_aux (P_id id,ann), typ) :: pats
+ | _ -> map, (pat,typ)::pats) (map,[]) pats
+ in map, List.rev pats
let doc_funcl (FCL_aux(FCL_Funcl(id, pexp), annot)) =
let (tq,typ) = Env.get_val_spec_orig id (env_of_annot annot) in
@@ -1500,38 +1654,70 @@ let doc_funcl (FCL_aux(FCL_Funcl(id, pexp), annot)) =
| Typ_aux (Typ_fn (arg_typ, ret_typ, eff),_) -> arg_typ, ret_typ, eff
| _ -> failwith ("Function " ^ string_of_id id ^ " does not have function type")
in
+ let build_ex, ret_typ = replace_atom_return_type ret_typ in
let ids_to_avoid = all_ids pexp in
let kids_used = tyvars_of_typquant tq in
let pat,guard,exp,(l,_) = destruct_pexp pexp in
let pats, bind = untuple_args_pat arg_typ pat in
let pats, binds = List.split (Util.list_mapi demote_as_pattern pats) in
let eliminated_kids, kid_to_arg_rename = merge_kids_atoms pats in
+ let kid_to_arg_rename, pats = merge_var_patterns kid_to_arg_rename pats in
let kids_used = KidSet.diff kids_used eliminated_kids in
let ctxt =
{ early_ret = contains_early_return exp;
kid_renames = mk_kid_renames ids_to_avoid kids_used;
kid_id_renames = kid_to_arg_rename;
- bound_nexps = NexpSet.union (lem_nexps_of_typ typ) (typeclass_nexps typ) } in
+ bound_nexps = NexpSet.union (lem_nexps_of_typ typ) (typeclass_nexps typ);
+ build_ex_return = effectful eff && build_ex;
+ } in
(* Put the constraints after pattern matching so that any type variable that's
been replaced by one of the term-level arguments is bound. *)
let quantspp, constrspp = doc_typquant_items_separate ctxt braces tq in
let exp = List.fold_left (fun body f -> f body) (bind exp) binds in
- let patspp = separate_map space (doc_binder ctxt) pats in
- let atom_constr_pp = separate_opt space (atom_constraint ctxt) pats in
+ let used_a_pattern = ref false in
+ let doc_binder (P_aux (p,ann) as pat, typ) =
+ let env = env_of_annot ann in
+ let exp_typ = Env.expand_synonyms env typ in
+ match p with
+ | P_id id
+ | P_typ (_,P_aux (P_id id,_)) when Util.is_none (is_auto_decomposed_exist env exp_typ) ->
+ parens (separate space [doc_id id; colon; doc_typ ctxt typ])
+ | _ ->
+ (used_a_pattern := true;
+ squote ^^ parens (separate space [doc_pat ctxt true (pat, exp_typ); colon; doc_typ ctxt typ]))
+ in
+ let patspp = separate_map space doc_binder pats in
+ let atom_constrs = Util.map_filter (atom_constraint ctxt) pats in
+ let atom_constr_pp = separate space atom_constrs in
let retpp =
if effectful eff
then string "M" ^^ space ^^ parens (doc_typ ctxt ret_typ)
else doc_typ ctxt ret_typ
in
+ let idpp = doc_id id in
+ (* Work around Coq bug 7975 about pattern binders followed by implicit arguments *)
+ let implicitargs =
+ if !used_a_pattern && List.length constrspp + List.length atom_constrs > 0 then
+ break 1 ^^ separate space
+ ([string "Arguments"; idpp;] @
+ List.map (fun _ -> string "{_}") quantspp @
+ List.map (fun _ -> string "_") pats @
+ List.map (fun _ -> string "{_}") constrspp @
+ List.map (fun _ -> string "{_}") atom_constrs)
+ ^^ dot
+ else empty
+ in
let _ = match guard with
| None -> ()
| _ ->
raise (Reporting_basic.err_unreachable l
"guarded pattern expression should have been rewritten before pretty-printing") in
+ let bodypp = doc_fun_body ctxt exp in
+ let bodypp = if effectful eff || not build_ex then bodypp else string "build_ex" ^^ parens bodypp in
group (prefix 3 1
- (separate space [doc_id id; quantspp; patspp; constrspp; atom_constr_pp] ^/^
- colon ^^ space ^^ retpp ^^ coloneq)
- (doc_fun_body ctxt exp ^^ dot))
+ (separate space ([idpp] @ quantspp @ [patspp] @ constrspp @ [atom_constr_pp]) ^/^
+ separate space [colon; retpp; coloneq])
+ (bodypp ^^ dot)) ^^ implicitargs
let get_id = function
| [] -> failwith "FD_function with empty list"
@@ -1658,9 +1844,16 @@ let doc_axiom_typschm typ_env (TypSchm_aux (TypSchm_ts (tqs,typ),l) as ts) =
| _ -> parens (underscore ^^ string " : " ^^ doc_typ empty_ctxt typ)
in
let arg_typs_pp = separate space (List.map doc_typ' typs) in
+ let _, ret_ty = replace_atom_return_type ret_ty in
let ret_typ_pp = doc_typ empty_ctxt ret_ty in
+ let ret_typ_pp =
+ if effectful eff
+ then string "M" ^^ space ^^ parens ret_typ_pp
+ else ret_typ_pp
+ in
let tyvars_pp, constrs_pp = doc_typquant_items_separate empty_ctxt braces tqs in
- string "forall" ^/^ tyvars_pp ^/^ arg_typs_pp ^/^ constrs_pp ^^ comma ^/^ ret_typ_pp
+ string "forall" ^/^ separate space tyvars_pp ^/^
+ arg_typs_pp ^/^ separate space constrs_pp ^^ comma ^/^ ret_typ_pp
| _ -> doc_typschm empty_ctxt true ts
let doc_val_spec unimplemented (VS_aux (VS_val_spec(tys,id,_,_),ann)) =
@@ -1800,6 +1993,8 @@ try
(fun lib -> separate space [string "Require Import";string lib] ^^ dot) defs_modules;hardline;
string "Import ListNotations.";
hardline;
+ string "Open Scope string."; hardline;
+ string "Open Scope bool."; hardline;
(* Put the body into a Section so that we can define some values with
Let to put them into the local context, where tactics can see them *)
string "Section Content.";
diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml
index c872d420..9897bb7c 100644
--- a/src/pretty_print_lem.ml
+++ b/src/pretty_print_lem.ml
@@ -334,14 +334,14 @@ let replace_typ_size ctxt env (Typ_aux (t,a)) =
let mk_typ nexp =
Some (Typ_aux (Typ_app (id, [Typ_arg_aux (Typ_arg_nexp nexp,Parse_ast.Unknown);ord;typ']),a))
in
- match Type_check.solve env size with
- | Some n -> mk_typ (nconstant n)
- | None ->
- let is_equal nexp =
- prove env (NC_aux (NC_equal (size,nexp),Parse_ast.Unknown))
- in match List.find is_equal (NexpSet.elements ctxt.bound_nexps) with
- | nexp -> mk_typ nexp
- | exception Not_found -> None
+ let is_equal nexp =
+ prove env (NC_aux (NC_equal (size,nexp),Parse_ast.Unknown))
+ in match List.find is_equal (NexpSet.elements ctxt.bound_nexps) with
+ | nexp -> mk_typ nexp
+ | exception Not_found ->
+ match Type_check.solve env size with
+ | Some n -> mk_typ (nconstant n)
+ | None -> None
end
| _ -> None
diff --git a/src/process_file.ml b/src/process_file.ml
index 9603e986..c3e1b510 100644
--- a/src/process_file.ml
+++ b/src/process_file.ml
@@ -398,12 +398,12 @@ let rewrite rewriters defs =
let rewrite_ast = rewrite [("initial", Rewriter.rewrite_defs)]
let rewrite_undefined bitvectors = rewrite [("undefined", fun x -> Rewrites.rewrite_undefined bitvectors x)]
let rewrite_ast_lem = rewrite Rewrites.rewrite_defs_lem
-let rewrite_ast_coq = rewrite Rewrites.rewrite_defs_lem
+let rewrite_ast_coq = rewrite Rewrites.rewrite_defs_coq
let rewrite_ast_ocaml = rewrite Rewrites.rewrite_defs_ocaml
let rewrite_ast_c ast =
ast
|> rewrite Rewrites.rewrite_defs_c
- |> Constant_fold.rewrite_constant_function_calls
+ |> rewrite [("constant_fold", Constant_fold.rewrite_constant_function_calls)]
let rewrite_ast_interpreter = rewrite Rewrites.rewrite_defs_interpreter
let rewrite_ast_check = rewrite Rewrites.rewrite_defs_check
diff --git a/src/rewrites.ml b/src/rewrites.ml
index 214ca571..246a2670 100644
--- a/src/rewrites.ml
+++ b/src/rewrites.ml
@@ -1168,6 +1168,25 @@ let case_exp e t cs =
(* let efr = union_effs (List.map effect_of_pexp ps) in *)
fix_eff_exp (annot_exp (E_case (e,ps)) l env t)
+(* Rewrite guarded patterns into a combination of if-expressions and
+ unguarded pattern matches
+
+ Strategy:
+ - Split clauses into groups where the first pattern subsumes all the
+ following ones
+ - Translate the groups in reverse order, using the next group as a
+ fall-through target, if there is one
+ - Within a group,
+ - translate the sequence of clauses to an if-then-else cascade using the
+ guards as long as the patterns are equivalent modulo substitution, or
+ - recursively translate the remaining clauses to a pattern match if
+ there is a difference in the patterns.
+
+ TODO: Compare this more closely with the algorithm in the CPP'18 paper of
+ Spector-Zabusky et al, who seem to use the opposite grouping and merging
+ strategy to ours: group *mutually exclusive* clauses, and try to merge them
+ into a pattern match first instead of an if-then-else cascade.
+*)
let rewrite_guarded_clauses l cs =
let rec group fallthrough clauses =
let add_clause (pat,cls,annot) c = (pat,cls @ [c],annot) in
@@ -2285,6 +2304,7 @@ let rewrite_type_def_typs rw_typ rw_typquant rw_typschm (TD_aux (td, annot)) =
let rewrite_dec_spec_typs rw_typ (DEC_aux (ds, annot)) =
match ds with
| DEC_reg (typ, id) -> DEC_aux (DEC_reg (rw_typ typ, id), annot)
+ | DEC_config (id, typ, exp) -> DEC_aux (DEC_config (id, rw_typ typ, exp), annot)
| _ -> assert false
(* Remove overload definitions and cast val specs from the
@@ -2606,7 +2626,7 @@ let rewrite_defs_letbind_effects =
match lexp_aux with
| LEXP_id _ -> k lexp
| LEXP_deref exp ->
- n_exp exp (fun exp ->
+ n_exp_name exp (fun exp ->
k (fix_eff_lexp (LEXP_aux (LEXP_deref exp, annot))))
| LEXP_memory (id,es) ->
n_exp_nameL es (fun es ->
@@ -3305,8 +3325,15 @@ let rewrite_defs_mapping_patterns =
in
pexp_rewriters rewrite_pexp
+let rewrite_lit_lem (L_aux (lit, _)) = match lit with
+ | L_num _ | L_string _ | L_hex _ | L_bin _ | L_real _ -> true
+ | _ -> false
+
+let rewrite_no_strings (L_aux (lit, _)) = match lit with
+ | L_string _ -> false
+ | _ -> true
-let rewrite_defs_pat_lits =
+let rewrite_defs_pat_lits rewrite_lit =
let rewrite_pexp (Pat_aux (pexp_aux, annot) as pexp) =
let guards = ref [] in
let counter = ref 0 in
@@ -3314,7 +3341,7 @@ let rewrite_defs_pat_lits =
let rewrite_pat = function
(* HACK: ignore strings for now *)
| P_lit (L_aux (L_string _, _)) as p_aux, p_annot -> P_aux (p_aux, p_annot)
- | P_lit lit, p_annot ->
+ | P_lit lit, p_annot when rewrite_lit lit ->
let env = env_of_annot p_annot in
let typ = typ_of_annot p_annot in
let id = mk_id ("p" ^ string_of_int !counter ^ "#") in
@@ -3513,11 +3540,12 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
let eff = union_eff_exps [c;e1;e2] in
let v = E_aux (E_if (c,e1,e2), (gen_loc el, Some (env, typ, eff))) in
Added_vars (v, tuple_pat (if overwrite then varpats else pat :: varpats))
- | E_case (e1,ps) ->
- (* after rewrite_defs_letbind_effects e1 needs no rewriting *)
+ | E_case (e1,ps) | E_try (e1, ps) ->
+ let is_case = match expaux with E_case _ -> true | _ -> false in
let vars, varpats =
- ps
- |> List.map (fun (Pat_aux ((Pat_exp (_,e)|Pat_when (_,_,e)),_)) -> e)
+ (* for E_case, e1 needs no rewriting after rewrite_defs_letbind_effects *)
+ ((if is_case then [] else [e1]) @
+ List.map (fun (Pat_aux ((Pat_exp (_,e)|Pat_when (_,_,e)),_)) -> e) ps)
|> List.map find_updated_vars
|> List.fold_left IdSet.union IdSet.empty
|> IdSet.inter used_vars
@@ -3528,8 +3556,10 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
Pat_aux (Pat_exp (p,rewrite_var_updates e),a)
| Pat_aux (Pat_when (p,g,e),a) ->
Pat_aux (Pat_when (p,g,rewrite_var_updates e),a)) ps in
- Same_vars (E_aux (E_case (e1,ps),annot))
+ let expaux = if is_case then E_case (e1, ps) else E_try (e1, ps) in
+ Same_vars (E_aux (expaux, annot))
else
+ let e1 = if is_case then e1 else rewrite_var_updates (add_vars overwrite e1 vars) in
let rewrite_pexp (Pat_aux (pexp, (l, _))) = match pexp with
| Pat_exp (pat, exp) ->
let exp = rewrite_var_updates (add_vars overwrite exp vars) in
@@ -3538,10 +3568,12 @@ let rec rewrite_var_updates ((E_aux (expaux,((l,_) as annot))) as exp) =
| Pat_when _ ->
raise (Reporting_basic.err_unreachable l
"Guarded patterns should have been rewritten already") in
+ let ps = List.map rewrite_pexp ps in
+ let expaux = if is_case then E_case (e1, ps) else E_try (e1, ps) in
let typ = match ps with
| Pat_aux ((Pat_exp (_,first)|Pat_when (_,_,first)),_) :: _ -> typ_of first
| _ -> unit_typ in
- let v = fix_eff_exp (annot_exp (E_case (e1, List.map rewrite_pexp ps)) pl env typ) in
+ let v = fix_eff_exp (annot_exp expaux pl env typ) in
Added_vars (v, tuple_pat (if overwrite then varpats else pat :: varpats))
| E_assign (lexp,vexp) ->
let mk_id_pat id = match Env.lookup_id id env with
@@ -3969,6 +4001,263 @@ let rewrite_defs_realise_mappings (Defs defs) =
in
Defs (List.map rewrite_def defs |> List.flatten)
+
+(* Rewrite to make all pattern matches in Coq output exhaustive.
+ Assumes that guards, vector patterns, etc have been rewritten already. *)
+
+let opt_coq_warn_nonexhaustive = ref false
+
+module MakeExhaustive =
+struct
+
+type rlit =
+ | RL_unit
+ | RL_zero
+ | RL_one
+ | RL_true
+ | RL_false
+ | RL_inf
+
+let string_of_rlit = function
+ | RL_unit -> "()"
+ | RL_zero -> "bitzero"
+ | RL_one -> "bitone"
+ | RL_true -> "true"
+ | RL_false -> "false"
+ | RL_inf -> "..."
+
+let rlit_of_lit (L_aux (l,_)) =
+ match l with
+ | L_unit -> RL_unit
+ | L_zero -> RL_zero
+ | L_one -> RL_one
+ | L_true -> RL_true
+ | L_false -> RL_false
+ | L_num _ | L_hex _ | L_bin _ | L_string _ | L_real _ -> RL_inf
+ | L_undef -> assert false
+
+let inv_rlit_of_lit (L_aux (l,_)) =
+ match l with
+ | L_unit -> []
+ | L_zero -> [RL_one]
+ | L_one -> [RL_zero]
+ | L_true -> [RL_false]
+ | L_false -> [RL_true]
+ | L_num _ | L_hex _ | L_bin _ | L_string _ | L_real _ -> [RL_inf]
+ | L_undef -> assert false
+
+type residual_pattern =
+ | RP_any
+ | RP_lit of rlit
+ | RP_enum of id
+ | RP_app of id * residual_pattern list
+ | RP_tup of residual_pattern list
+ | RP_nil
+ | RP_cons of residual_pattern * residual_pattern
+
+let rec string_of_rp = function
+ | RP_any -> "_"
+ | RP_lit rlit -> string_of_rlit rlit
+ | RP_enum id -> string_of_id id
+ | RP_app (f,args) -> string_of_id f ^ "(" ^ String.concat "," (List.map string_of_rp args) ^ ")"
+ | RP_tup rps -> "(" ^ String.concat "," (List.map string_of_rp rps) ^ ")"
+ | RP_nil -> "[| |]"
+ | RP_cons (rp1,rp2) -> string_of_rp rp1 ^ "::" ^ string_of_rp rp2
+
+type ctx = {
+ env : Env.t;
+ enum_to_rest: (residual_pattern list) Bindings.t;
+ constructor_to_rest: (residual_pattern list) Bindings.t
+}
+
+let make_enum_mappings ids m =
+ IdSet.fold (fun id m ->
+ Bindings.add id
+ (List.map (fun e -> RP_enum e) (IdSet.elements (IdSet.remove id ids))) m)
+ ids
+ m
+
+let make_cstr_mappings env ids m =
+ let ids = IdSet.elements ids in
+ let constructors = List.map
+ (fun id ->
+ let _,ty = Env.get_val_spec id env in
+ let args = match ty with
+ | Typ_aux (Typ_fn (Typ_aux (Typ_tup tys,_),_,_),_) -> List.map (fun _ -> RP_any) tys
+ | _ -> [RP_any]
+ in RP_app (id,args)) ids in
+ let rec aux ids acc l =
+ match ids, l with
+ | [], [] -> m
+ | id::ids, rp::t ->
+ let m = aux ids (acc@[rp]) t in
+ Bindings.add id (acc@t) m
+ | _ -> assert false
+ in aux ids [] constructors
+
+let ctx_from_pattern_completeness_ctx env =
+ let ctx = Env.pattern_completeness_ctx env in
+ { env = env;
+ enum_to_rest = Bindings.fold (fun _ ids m -> make_enum_mappings ids m)
+ ctx.Pattern_completeness.enums Bindings.empty;
+ constructor_to_rest = Bindings.fold (fun _ ids m -> make_cstr_mappings env ids m)
+ ctx.Pattern_completeness.variants Bindings.empty
+ }
+
+let printprefix = ref " "
+
+let rec remove_clause_from_pattern ctx (P_aux (rm_pat,ann)) res_pat =
+ let subpats rm_pats res_pats =
+ let res_pats' = List.map2 (remove_clause_from_pattern ctx) rm_pats res_pats in
+ let rec aux acc fixed residual =
+ match fixed, residual with
+ | [], [] -> []
+ | (fh::ft), (rh::rt) ->
+ let rt' = aux (acc@[fh]) ft rt in
+ let newr = List.map (fun x -> acc @ (x::ft)) rh in
+ newr @ rt'
+ | _,_ -> assert false (* impossible because we managed map2 above *)
+ in aux [] res_pats res_pats'
+ in
+ let inconsistent () =
+ raise (Reporting_basic.err_unreachable (fst ann)
+ ("Inconsistency during exhaustiveness analysis with " ^
+ string_of_rp res_pat))
+ in
+ (*let _ = print_endline (!printprefix ^ "pat: " ^string_of_pat (P_aux (rm_pat,ann))) in
+ let _ = print_endline (!printprefix ^ "res_pat: " ^string_of_rp res_pat) in
+ let _ = printprefix := " " ^ !printprefix in*)
+ let rp' =
+ match rm_pat with
+ | P_wild -> []
+ | P_id id when (match Env.lookup_id id ctx.env with Unbound | Local _ -> true | _ -> false) -> []
+ | P_lit lit ->
+ (match res_pat with
+ | RP_any -> List.map (fun l -> RP_lit l) (inv_rlit_of_lit lit)
+ | RP_lit RL_inf -> [res_pat]
+ | RP_lit lit' -> if lit' = rlit_of_lit lit then [] else [res_pat]
+ | _ -> inconsistent ())
+ | P_as (p,_)
+ | P_typ (_,p)
+ | P_var (p,_)
+ -> remove_clause_from_pattern ctx p res_pat
+ | P_id id ->
+ (match Env.lookup_id id ctx.env with
+ | Enum enum ->
+ (match res_pat with
+ | RP_any -> Bindings.find id ctx.enum_to_rest
+ | RP_enum id' -> if Id.compare id id' == 0 then [] else [res_pat]
+ | _ -> inconsistent ())
+ | _ -> assert false)
+ | P_tup rm_pats ->
+ let previous_res_pats =
+ match res_pat with
+ | RP_tup res_pats -> res_pats
+ | RP_any -> List.map (fun _ -> RP_any) rm_pats
+ | _ -> inconsistent ()
+ in
+ let res_pats' = subpats rm_pats previous_res_pats in
+ List.map (fun rps -> RP_tup rps) res_pats'
+ | P_app (id,args) ->
+ (match res_pat with
+ | RP_app (id',residual_args) ->
+ if Id.compare id id' == 0 then
+ let res_pats' = subpats args residual_args in
+ List.map (fun rps -> RP_app (id,rps)) res_pats'
+ else [res_pat]
+ | RP_any ->
+ let res_args = subpats args (List.map (fun _ -> RP_any) args) in
+ (List.map (fun l -> (RP_app (id,l))) res_args) @
+ (Bindings.find id ctx.constructor_to_rest)
+ | _ -> inconsistent ()
+ )
+ | P_list ps ->
+ (match ps with
+ | p1::ptl -> remove_clause_from_pattern ctx (P_aux (P_cons (p1,P_aux (P_list ptl,ann)),ann)) res_pat
+ | [] ->
+ match res_pat with
+ | RP_any -> [RP_cons (RP_any,RP_any)]
+ | RP_cons _ -> [res_pat]
+ | RP_nil -> []
+ | _ -> inconsistent ())
+ | P_cons (p1,p2) -> begin
+ let rp',rps =
+ match res_pat with
+ | RP_cons (rp1,rp2) -> [], Some [rp1;rp2]
+ | RP_any -> [RP_nil], Some [RP_any;RP_any]
+ | RP_nil -> [RP_nil], None
+ | _ -> inconsistent ()
+ in
+ match rps with
+ | None -> rp'
+ | Some rps ->
+ let res_pats = subpats [p1;p2] rps in
+ rp' @ List.map (function [rp1;rp2] -> RP_cons (rp1,rp2) | _ -> assert false) res_pats
+ end
+ | P_record _ ->
+ raise (Reporting_basic.err_unreachable (fst ann)
+ "Record pattern not supported")
+ | P_vector _
+ | P_vector_concat _
+ | P_string_append _ ->
+ raise (Reporting_basic.err_unreachable (fst ann)
+ "Found pattern that should have been rewritten away in earlier stage")
+
+ (*in let _ = printprefix := String.sub (!printprefix) 0 (String.length !printprefix - 2)
+ in let _ = print_endline (!printprefix ^ "res_pats': " ^ String.concat "; " (List.map string_of_rp rp'))*)
+ in rp'
+
+let process_pexp env =
+ let ctx = ctx_from_pattern_completeness_ctx env in
+ fun rps patexp ->
+ (*let _ = print_endline ("res_pats: " ^ String.concat "; " (List.map string_of_rp rps)) in
+ let _ = print_endline ("pat: " ^ string_of_pexp patexp) in*)
+ match patexp with
+ | Pat_aux (Pat_exp (p,_),_) ->
+ List.concat (List.map (remove_clause_from_pattern ctx p) rps)
+ | Pat_aux (Pat_when _,(l,_)) ->
+ raise (Reporting_basic.err_unreachable l
+ "Guarded pattern should have been rewritten away")
+
+let rewrite_case (e,ann) =
+ match e with
+ | E_case (e1,cases) ->
+ begin
+ let env = env_of_annot ann in
+ let rps = List.fold_left (process_pexp env) [RP_any] cases in
+ match rps with
+ | [] -> E_aux (E_case (e1,cases),ann)
+ | (example::_) ->
+
+ let _ =
+ if !opt_coq_warn_nonexhaustive
+ then Reporting_basic.print_err false false
+ (fst ann) "Non-exhaustive matching" ("Example: " ^ string_of_rp example) in
+
+ let l = Parse_ast.Generated Parse_ast.Unknown in
+ let p = P_aux (P_wild, (l, None)) in
+ let ann' = Some (env, typ_of_annot ann, mk_effect [BE_escape]) in
+ (* TODO: use an expression that specifically indicates a failed pattern match *)
+ let b = E_aux (E_exit (E_aux (E_lit (L_aux (L_unit, l)),(l,None))),(l,ann')) in
+ E_aux (E_case (e1,cases@[Pat_aux (Pat_exp (p,b),(l,None))]),ann)
+ end
+ | _ -> E_aux (e,ann)
+
+let rewrite =
+ let alg = { id_exp_alg with e_aux = rewrite_case } in
+ rewrite_defs_base
+ { rewrite_exp = (fun _ -> fold_exp alg)
+ ; rewrite_pat = rewrite_pat
+ ; rewrite_let = rewrite_let
+ ; rewrite_lexp = rewrite_lexp
+ ; rewrite_fun = rewrite_fun
+ ; rewrite_def = rewrite_def
+ ; rewrite_defs = rewrite_defs_base
+ }
+
+
+end
+
let recheck_defs defs = fst (Type_error.check initial_env defs)
let remove_mapping_valspecs (Defs defs) =
@@ -3985,7 +4274,7 @@ let rewrite_defs_lem = [
("remove_mapping_valspecs", remove_mapping_valspecs);
("pat_string_append", rewrite_defs_pat_string_append);
("mapping_builtins", rewrite_defs_mapping_patterns);
- ("pat_lits", rewrite_defs_pat_lits);
+ ("pat_lits", rewrite_defs_pat_lits rewrite_lit_lem);
("vector_concat_assignments", rewrite_vector_concat_assignments);
("tuple_assignments", rewrite_tuple_assignments);
("simple_assignments", rewrite_simple_assignments);
@@ -4017,13 +4306,52 @@ let rewrite_defs_lem = [
("recheck_defs", recheck_defs)
]
+let rewrite_defs_coq = [
+ ("realise_mappings", rewrite_defs_realise_mappings);
+ ("remove_mapping_valspecs", remove_mapping_valspecs);
+ ("pat_string_append", rewrite_defs_pat_string_append);
+ ("mapping_builtins", rewrite_defs_mapping_patterns);
+ ("pat_lits", rewrite_defs_pat_lits rewrite_lit_lem);
+ ("vector_concat_assignments", rewrite_vector_concat_assignments);
+ ("tuple_assignments", rewrite_tuple_assignments);
+ ("simple_assignments", rewrite_simple_assignments);
+ ("remove_vector_concat", rewrite_defs_remove_vector_concat);
+ ("remove_bitvector_pats", rewrite_defs_remove_bitvector_pats);
+ ("remove_numeral_pats", rewrite_defs_remove_numeral_pats);
+ ("guarded_pats", rewrite_defs_guarded_pats);
+ ("bitvector_exps", rewrite_bitvector_exps);
+ (* ("register_ref_writes", rewrite_register_ref_writes); *)
+ ("nexp_ids", rewrite_defs_nexp_ids);
+ ("fix_val_specs", rewrite_fix_val_specs);
+ ("split_execute", rewrite_split_fun_constr_pats "execute");
+ ("recheck_defs", recheck_defs);
+ ("exp_lift_assign", rewrite_defs_exp_lift_assign);
+ (* ("constraint", rewrite_constraint); *)
+ (* ("remove_assert", rewrite_defs_remove_assert); *)
+ ("top_sort_defs", top_sort_defs);
+ ("trivial_sizeof", rewrite_trivial_sizeof);
+ ("sizeof", rewrite_sizeof);
+ ("early_return", rewrite_defs_early_return);
+ ("make_cases_exhaustive", MakeExhaustive.rewrite);
+ ("fix_val_specs", rewrite_fix_val_specs);
+ ("recheck_defs", recheck_defs);
+ ("remove_blocks", rewrite_defs_remove_blocks);
+ ("letbind_effects", rewrite_defs_letbind_effects);
+ ("remove_e_assign", rewrite_defs_remove_e_assign);
+ ("internal_lets", rewrite_defs_internal_lets);
+ ("remove_superfluous_letbinds", rewrite_defs_remove_superfluous_letbinds);
+ ("remove_superfluous_returns", rewrite_defs_remove_superfluous_returns);
+ ("merge function clauses", merge_funcls);
+ ("recheck_defs", recheck_defs)
+ ]
+
let rewrite_defs_ocaml = [
(* ("undefined", rewrite_undefined); *)
("no_effect_check", (fun defs -> opt_no_effects := true; defs));
("realise_mappings", rewrite_defs_realise_mappings);
("pat_string_append", rewrite_defs_pat_string_append);
("mapping_builtins", rewrite_defs_mapping_patterns);
- ("pat_lits", rewrite_defs_pat_lits);
+ ("pat_lits", rewrite_defs_pat_lits rewrite_no_strings);
("vector_concat_assignments", rewrite_vector_concat_assignments);
("tuple_assignments", rewrite_tuple_assignments);
("simple_assignments", rewrite_simple_assignments);
@@ -4045,7 +4373,7 @@ let rewrite_defs_c = [
("realise_mappings", rewrite_defs_realise_mappings);
("pat_string_append", rewrite_defs_pat_string_append);
("mapping_builtins", rewrite_defs_mapping_patterns);
- ("pat_lits", rewrite_defs_pat_lits);
+ ("pat_lits", rewrite_defs_pat_lits rewrite_no_strings);
("vector_concat_assignments", rewrite_vector_concat_assignments);
("tuple_assignments", rewrite_tuple_assignments);
("simple_assignments", rewrite_simple_assignments);
diff --git a/src/rewrites.mli b/src/rewrites.mli
index 70cb75af..7d6bc0b2 100644
--- a/src/rewrites.mli
+++ b/src/rewrites.mli
@@ -66,6 +66,13 @@ val rewrite_defs_interpreter : (string * (tannot defs -> tannot defs)) list
(* Perform rewrites to exclude AST nodes not supported for lem out*)
val rewrite_defs_lem : (string * (tannot defs -> tannot defs)) list
+(* Perform rewrites to exclude AST nodes not supported for coq out*)
+val rewrite_defs_coq : (string * (tannot defs -> tannot defs)) list
+
+(* Warn about matches where we add a default case for Coq because they're not
+ exhaustive *)
+val opt_coq_warn_nonexhaustive : bool ref
+
(* Perform rewrites to exclude AST nodes not supported for C compilation *)
val rewrite_defs_c : (string * (tannot defs -> tannot defs)) list
diff --git a/src/sail.ml b/src/sail.ml
index 944eb9ff..5b7c9dbf 100644
--- a/src/sail.ml
+++ b/src/sail.ml
@@ -118,6 +118,9 @@ let options = Arg.align ([
( "-Oconstant_fold",
Arg.Set Constant_fold.optimize_constant_fold,
" Apply constant folding optimizations");
+ ( "-static",
+ Arg.Set C_backend.opt_static,
+ " Make generated C functions static");
( "-trace",
Arg.Tuple [Arg.Set C_backend.opt_trace; Arg.Set Ocaml_backend.opt_trace_ocaml],
" Instrument ouput with tracing");
@@ -145,6 +148,9 @@ let options = Arg.align ([
( "-dcoq_undef_axioms",
Arg.Set Pretty_print_coq.opt_undef_axioms,
"Generate axioms for functions that are declared but not defined");
+ ( "-dcoq_warn_nonex",
+ Arg.Set Rewrites.opt_coq_warn_nonexhaustive,
+ "Generate warnings for non-exhaustive pattern matches in the Coq backend");
( "-latex_prefix",
Arg.String (fun prefix -> Latex.opt_prefix_latex := prefix),
" set a custom prefix for generated latex command (default sail)");
diff --git a/src/sail_lib.ml b/src/sail_lib.ml
index ab621342..16b1d3cc 100644
--- a/src/sail_lib.ml
+++ b/src/sail_lib.ml
@@ -419,7 +419,7 @@ let get_mem_page p =
try
Mem.find p !mem_pages
with Not_found ->
- let new_page = Bytes.create page_size_bytes in
+ let new_page = Bytes.make page_size_bytes '\000' in
mem_pages := Mem.add p new_page !mem_pages;
new_page
@@ -457,7 +457,7 @@ let write_ram' (data_size, addr, data) =
end
let write_ram (addr_size, data_size, hex_ram, addr, data) =
- write_ram' (data_size, uint addr, data)
+ write_ram' (data_size, uint addr, data); true
let wram addr byte =
let bytes = Bytes.make 1 (char_of_int byte) in
diff --git a/src/state.ml b/src/state.ml
index c591f753..245d450c 100644
--- a/src/state.ml
+++ b/src/state.ml
@@ -367,15 +367,15 @@ let generate_isa_lemmas mwords (Defs defs : tannot defs) =
let id' = remove_trailing_underscores id in
separate_map hardline string [
"lemma liftS_read_reg_" ^ id ^ "[liftState_simp]:";
- " \"liftS (read_reg " ^ id ^ "_ref) = readS (" ^ id' ^ " \\<circ> regstate)\"";
+ " \"\\<lbrakk>read_reg " ^ id ^ "_ref\\<rbrakk>\\<^sub>S = readS (" ^ id' ^ " \\<circ> regstate)\"";
" by (auto simp: liftState_read_reg_readS register_defs)";
"";
"lemma liftS_write_reg_" ^ id ^ "[liftState_simp]:";
- " \"liftS (write_reg " ^ id ^ "_ref v) = updateS (regstate_update (" ^ id' ^ "_update (\\<lambda>_. v)))\"";
+ " \"\\<lbrakk>write_reg " ^ id ^ "_ref v\\<rbrakk>\\<^sub>S = updateS (regstate_update (" ^ id' ^ "_update (\\<lambda>_. v)))\"";
" by (auto simp: liftState_write_reg_updateS register_defs)"
]
in
- string "abbreviation \"liftS \\<equiv> liftState (get_regval, set_regval)\"" ^^
+ string "abbreviation liftS (\"\\<lbrakk>_\\<rbrakk>\\<^sub>S\") where \"liftS \\<equiv> liftState (get_regval, set_regval)\"" ^^
hardline ^^ hardline ^^
register_defs ^^
hardline ^^ hardline ^^
@@ -411,7 +411,7 @@ let rec regval_convs_coq (Typ_aux (t, _) as typ) = match t with
let size = string_of_nexp (nexp_simp size) in
let is_inc = if is_order_inc ord then "true" else "false" in
let etyp_of, of_etyp = regval_convs_coq etyp in
- "(fun v => vector_of_regval " ^ etyp_of ^ " v)",
+ "(fun v => vector_of_regval " ^ size ^ " " ^ etyp_of ^ " v)",
"(fun v => regval_of_vector " ^ of_etyp ^ " " ^ size ^ " " ^ is_inc ^ " v)"
| Typ_app (id, [Typ_arg_aux (Typ_arg_typ etyp, _)])
when string_of_id id = "list" ->
@@ -430,12 +430,12 @@ let rec regval_convs_coq (Typ_aux (t, _) as typ) = match t with
let register_refs_coq registers =
let generic_convs =
separate_map hardline string [
- "Definition vector_of_regval {a} (of_regval : register_value -> option a) (rv : register_value) : option (list a) := match rv with";
- " | Regval_vector (_, _, v) => just_list (List.map of_regval v)";
+ "Definition vector_of_regval {a} n (of_regval : register_value -> option a) (rv : register_value) : option (vec a n) := match rv with";
+ " | Regval_vector (n', _, v) => if n =? n' then map_bind (vec_of_list n) (just_list (List.map of_regval v)) else None";
" | _ => None";
"end.";
"";
- "Definition regval_of_vector {a} (regval_of : a -> register_value) (size : Z) (is_inc : bool) (xs : list a) : register_value := Regval_vector (size, is_inc, List.map regval_of xs).";
+ "Definition regval_of_vector {a} (regval_of : a -> register_value) (size : Z) (is_inc : bool) (xs : vec a size) : register_value := Regval_vector (size, is_inc, List.map regval_of (list_of_vec xs)).";
"";
"Definition list_of_regval {a} (of_regval : register_value -> option a) (rv : register_value) : option (list a) := match rv with";
" | Regval_list v => just_list (List.map of_regval v)";
diff --git a/src/test/lib/run_test_interp.ml b/src/test/lib/run_test_interp.ml
deleted file mode 100644
index 5f2ace1b..00000000
--- a/src/test/lib/run_test_interp.ml
+++ /dev/null
@@ -1,51 +0,0 @@
-(**************************************************************************)
-(* Sail *)
-(* *)
-(* Copyright (c) 2013-2017 *)
-(* Kathyrn Gray *)
-(* Shaked Flur *)
-(* Stephen Kell *)
-(* Gabriel Kerneis *)
-(* Robert Norton-Wright *)
-(* Christopher Pulte *)
-(* Peter Sewell *)
-(* *)
-(* All rights reserved. *)
-(* *)
-(* This software was developed by the University of Cambridge Computer *)
-(* Laboratory as part of the Rigorous Engineering of Mainstream Systems *)
-(* (REMS) project, funded by EPSRC grant EP/K008528/1. *)
-(* *)
-(* Redistribution and use in source and binary forms, with or without *)
-(* modification, are permitted provided that the following conditions *)
-(* are met: *)
-(* 1. Redistributions of source code must retain the above copyright *)
-(* notice, this list of conditions and the following disclaimer. *)
-(* 2. Redistributions in binary form must reproduce the above copyright *)
-(* notice, this list of conditions and the following disclaimer in *)
-(* the documentation and/or other materials provided with the *)
-(* distribution. *)
-(* *)
-(* THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' *)
-(* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED *)
-(* TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A *)
-(* PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR *)
-(* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, *)
-(* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT *)
-(* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF *)
-(* USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND *)
-(* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, *)
-(* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT *)
-(* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF *)
-(* SUCH DAMAGE. *)
-(**************************************************************************)
-
-open Interp_interface ;;
-open Interp_inter_imp ;;
-open Sail_impl_base ;;
-
-let doit () =
- let context = build_context false Test_lem_ast.defs [] [] [] [] [] [] [] None [] in
- translate_address context E_little_endian "run" None (address_of_integer (Nat_big_num.of_int 0));;
-
-doit () ;;
diff --git a/src/type_check.ml b/src/type_check.ml
index 810cd265..c73c7000 100644
--- a/src/type_check.ml
+++ b/src/type_check.ml
@@ -127,28 +127,11 @@ let is_unknown_type = function
| (Typ_aux (Typ_internal_unknown, _)) -> true
| _ -> false
-(* An index_sort is a more general form of range type: it can either
- be IS_int, which represents every natural number, or some set of
- natural numbers given by an IS_prop expression of the form
- {'n. f('n) <= g('n) /\ ...} *)
-type index_sort =
- | IS_int
- | IS_prop of kid * (nexp * nexp) list
-
-let string_of_index_sort = function
- | IS_int -> "INT"
- | IS_prop (kid, constraints) ->
- "{" ^ string_of_kid kid ^ " | "
- ^ string_of_list " & " (fun (x, y) -> string_of_nexp x ^ " <= " ^ string_of_nexp y) constraints
- ^ "}"
-
let is_atom (Typ_aux (typ_aux, _)) =
match typ_aux with
| Typ_app (f, [_]) when string_of_id f = "atom" -> true
| _ -> false
-
-
let rec strip_id = function
| Id_aux (Id x, _) -> Id_aux (Id x, Parse_ast.Unknown)
| Id_aux (DeIid x, _) -> Id_aux (DeIid x, Parse_ast.Unknown)
@@ -367,6 +350,7 @@ module Env : sig
val is_mapping : id -> t -> bool
val add_record : id -> typquant -> (typ * id) list -> t -> t
val is_record : id -> t -> bool
+ val get_record : id -> t -> typquant * (typ * id) list
val get_accessor_fn : id -> id -> t -> typquant * typ
val get_accessor : id -> id -> t -> typquant * typ * typ * effect
val add_local : id -> mut * typ -> t -> t
@@ -929,6 +913,8 @@ end = struct
let is_record id env = Bindings.mem id env.records
+ let get_record id env = Bindings.find id env.records
+
let add_record id typq fields env =
if bound_typ_id env id
then typ_error (id_loc id) ("Cannot create record " ^ string_of_id id ^ ", type name is already bound")
@@ -3311,6 +3297,12 @@ and instantiation_of (E_aux (exp_aux, (l, _)) as exp) =
| E_app (f, xs) -> snd (infer_funapp' l (Env.no_casts env) f (Env.get_val_spec f env) (List.map strip_exp xs) (Some (typ_of exp)))
| _ -> invalid_arg ("instantiation_of expected application, got " ^ string_of_exp exp)
+and instantiation_of_without_type (E_aux (exp_aux, (l, _)) as exp) =
+ let env = env_of exp in
+ match exp_aux with
+ | E_app (f, xs) -> snd (infer_funapp' l (Env.no_casts env) f (Env.get_val_spec f env) (List.map strip_exp xs) None)
+ | _ -> invalid_arg ("instantiation_of expected application, got " ^ string_of_exp exp)
+
and infer_funapp' l env f (typq, f_typ) xs ret_ctx_typ =
let annot_exp exp typ eff = E_aux (exp, (l, Some (env, typ, eff))) in
let switch_annot env typ = function
diff --git a/src/type_check.mli b/src/type_check.mli
index 24443f2e..286c2be9 100644
--- a/src/type_check.mli
+++ b/src/type_check.mli
@@ -133,6 +133,9 @@ module Env : sig
val is_record : id -> t -> bool
+ (** Returns record quantifiers and fields *)
+ val get_record : id -> t -> typquant * (typ * id) list
+
(** Return type is: quantifier, argument type, return type, effect *)
val get_accessor : id -> id -> t -> typquant * typ * typ * effect
@@ -191,6 +194,7 @@ module Env : sig
val empty : t
+ val pattern_completeness_ctx : t -> Pattern_completeness.ctx
end
(** Push all the type variables and constraints from a typquant into
@@ -349,6 +353,10 @@ val alpha_equivalent : Env.t -> typ -> typ -> bool
(** Throws Invalid_argument if the argument is not a E_app expression *)
val instantiation_of : tannot exp -> uvar KBindings.t
+(** Doesn't use the type of the expression when calculating instantiations.
+ May fail if the arguments aren't sufficient to calculate all unifiers. *)
+val instantiation_of_without_type : tannot exp -> uvar KBindings.t
+
(* Type variable instantiations that inference will extract from constraints *)
val instantiate_simple_equations : quant_item list -> uvar KBindings.t
diff --git a/src/value.ml b/src/value.ml
index 41b52720..dccb216e 100644
--- a/src/value.ml
+++ b/src/value.ml
@@ -406,8 +406,8 @@ let value_read_ram = function
let value_write_ram = function
| [v1; v2; v3; v4; v5] ->
- Sail_lib.write_ram (coerce_int v1, coerce_int v2, coerce_bv v3, coerce_bv v4, coerce_bv v5);
- V_unit
+ let b = Sail_lib.write_ram (coerce_int v1, coerce_int v2, coerce_bv v3, coerce_bv v4, coerce_bv v5) in
+ V_bool(b)
| _ -> failwith "value write_ram"
let value_load_raw = function
@@ -561,6 +561,7 @@ let primops =
("write_ram", value_write_ram);
("trace_memory_read", fun _ -> V_unit);
("trace_memory_write", fun _ -> V_unit);
+ ("get_time_ns", fun _ -> V_int (Sail_lib.get_time_ns()));
("load_raw", value_load_raw);
("to_real", value_to_real);
("eq_real", value_eq_real);
diff --git a/src/value2.lem b/src/value2.lem
index 33416503..e8a8262a 100644
--- a/src/value2.lem
+++ b/src/value2.lem
@@ -1,4 +1,4 @@
-(**************************************************************************)
+(*========================================================================*)
(* Sail *)
(* *)
(* Copyright (c) 2013-2017 *)
@@ -46,7 +46,7 @@
(* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT *)
(* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF *)
(* SUCH DAMAGE. *)
-(**************************************************************************)
+(*========================================================================*)
open import Pervasives
open import Assert_extra
@@ -70,17 +70,13 @@ type vl =
| V_record of list (string * vl)
| V_null (* Used for unitialized values and null pointers in C compilation *)
-let primops extern args =
- match (extern, args) with
- | ("and_bool", [V_bool b1; V_bool b2]) -> V_bool (b1 && b2)
- | ("and_bool", [V_nondet; V_bool false]) -> V_bool false
- | ("and_bool", [V_bool false; V_nondet]) -> V_bool false
- | ("and_bool", _) -> V_nondet
- | ("or_bool", [V_bool b1; V_bool b2]) -> V_bool (b1 || b2)
- | ("or_bool", [V_nondet; V_bool true]) -> V_bool true
- | ("or_bool", [V_bool true; V_nondet]) -> V_bool true
- | ("or_bool", _) -> V_nondet
+let value_int_op_int op = function
+ | [V_int v1; V_int v2] -> V_int (op v1 v2)
+ | _ -> V_null
+end
- | _ -> failwith ("Unimplemented primitive operation " ^ extern)
- end
+let value_bool_op_int op = function
+ | [V_int v1; V_int v2] -> V_bool (op v1 v2)
+ | _ -> V_null
+end