diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/jib/c_backend.ml | 6 | ||||
| -rw-r--r-- | src/jib/c_codegen.ml | 187 | ||||
| -rw-r--r-- | src/jib/jib_compile.ml | 12 | ||||
| -rw-r--r-- | src/jib/jib_compile.mli | 2 | ||||
| -rw-r--r-- | src/jib/jib_smt.ml | 2 | ||||
| -rw-r--r-- | src/rewrites.ml | 16 | ||||
| -rw-r--r-- | src/rewrites.mli | 1 | ||||
| -rw-r--r-- | src/type_check.ml | 53 |
8 files changed, 218 insertions, 61 deletions
diff --git a/src/jib/c_backend.ml b/src/jib/c_backend.ml index 9dedb830..b2447125 100644 --- a/src/jib/c_backend.ml +++ b/src/jib/c_backend.ml @@ -320,21 +320,21 @@ module C_config(Opts : sig val branch_coverage : out_channel option end) : Confi (* We need to check that id's type hasn't changed due to flow typing *) let _, ctyp' = Bindings.find id ctx.locals in if ctyp_equal ctyp ctyp' then - AV_cval (V_id (name id, ctyp), typ) + AV_cval (V_id (name_or_global ctx id, ctyp), typ) else (* id's type changed due to flow typing, so it's really still heap allocated! *) v with (* Hack: Assuming global letbindings don't change from flow typing... *) - Not_found -> AV_cval (V_id (name id, ctyp), typ) + Not_found -> AV_cval (V_id (name_or_global ctx id, ctyp), typ) end else v | Register (_, _, typ) -> let ctyp = convert_typ ctx typ in if is_stack_ctyp ctyp && not (never_optimize ctyp) then - AV_cval (V_id (name id, ctyp), typ) + AV_cval (V_id (global id, ctyp), typ) else v | _ -> v diff --git a/src/jib/c_codegen.ml b/src/jib/c_codegen.ml index d5015318..6ff27303 100644 --- a/src/jib/c_codegen.ml +++ b/src/jib/c_codegen.ml @@ -96,6 +96,10 @@ let rec is_stack_ctyp ctyp = match ctyp with | CT_poly -> true | CT_constant n -> Big_int.less_equal (min_int 64) n && Big_int.greater_equal n (max_int 64) +(** For now, the types that can be used in the state API are the types that fits on the stack. + In the future, this can be expanded to support more complex types if needed. *) +let is_api_ctyp = is_stack_ctyp + let v_mask_lower i = V_lit (VL_bits (Util.list_init i (fun _ -> Sail2_values.B1), true), CT_fbits (i, true)) type codegen_options = { @@ -230,12 +234,15 @@ let sgen_id id = | Some export -> export | None -> mangle (string_of_id id) -let sgen_name = - function +let sgen_name ctyp id = + match id with | Name (id, _) -> sgen_id id | Global (id, _) -> - sprintf "(state->%s)" (sgen_id id) + if is_api_ctyp ctyp then + sprintf "(state_get_%s(state))" (sgen_id id) + else + sprintf "(state->%s)" (sgen_id id) | Have_exception _ -> "(state->have_exception)" | Return _ -> @@ -343,11 +350,15 @@ let rec sgen_value = function let rec sgen_cval ctx = function | V_id (id, ctyp) -> - sgen_name id + sgen_name ctyp id | V_lit (vl, ctyp) -> sgen_value vl | V_call (op, cvals) -> sgen_call ctx op cvals - | V_field (f, field) -> - sprintf "%s.%s" (sgen_cval ctx f) (sgen_uid field) + | V_field (f, field) -> + begin match f with + | V_id (Global (id, _), ctyp) when is_api_ctyp ctyp -> + sprintf "state_get_%s_in_%s(state)" (sgen_uid field) (sgen_id id) + | _ -> sprintf "%s.%s" (sgen_cval ctx f) (sgen_uid field) + end | V_tuple_member (f, _, n) -> sprintf "%s.%s" (sgen_cval ctx f) (mangle ("tup" ^ string_of_int n)) | V_ctor_kind (f, ctor, unifiers, _) -> @@ -532,25 +543,70 @@ let sgen_cval_param ctx cval = | _ -> sgen_cval ctx cval -let rec sgen_clexp ctx = function +let rec sgen_clexp_state_api = function + | CL_id (Global (id, _), _) -> sgen_id id + | CL_field (clexp, field) -> sgen_uid field ^ "_in_" ^ sgen_clexp_state_api clexp + | _ -> assert false + +let sgen_cval_state_api = function + | V_id (Global (id, _), ctyp) -> sgen_id id + | _ -> assert false + +let rec is_state_api_cval = function + | V_id (Global (id, _), ctyp) -> + begin match ctyp with + | CT_vector (_, vctyp) -> is_api_ctyp vctyp + | CT_fvector (_, _, vctyp) -> is_api_ctyp vctyp + | _ when is_api_ctyp ctyp -> true + | _ -> false + end + | V_field (clexp, field) -> is_state_api_cval clexp + | _ -> false + +let rec is_state_api_clexp = function + | CL_id (Global (id, _), ctyp) -> + begin match ctyp with + | CT_vector (_, vctyp) -> is_api_ctyp vctyp + | CT_fvector (_, _, vctyp) -> is_api_ctyp vctyp + | _ when is_api_ctyp ctyp -> true + | _ -> false + end + | CL_field (clexp, field) -> is_state_api_clexp clexp + | _ -> false + +let rec sgen_clexp ctx clexp = + match clexp with | CL_id (Have_exception _, _) -> "(state->have_exception)" | CL_id (Current_exception _, _) -> "(state->current_exception)" | CL_id (Throw_location _, _) -> "(state->throw_location)" | CL_id (Return _, _) -> assert false - | CL_id (Global (id, _), _) -> "&(state->" ^ sgen_id id ^ ")" + | CL_id (Global (id, _), _) -> + let ctyp = clexp_ctyp clexp in + if is_api_ctyp ctyp then + "(state_get_" ^ sgen_id id ^ "(state))" + else + "&(state->" ^ sgen_id id ^ ")" | CL_id (Name (id, _), _) -> "&" ^ sgen_id id - | CL_field (clexp, field) -> "&((" ^ sgen_clexp ctx clexp ^ ")->" ^ sgen_uid field ^ ")" + | CL_field (clexp, field) -> + begin match clexp with + | CL_id (Global (id, _), ctyp) when is_api_ctyp ctyp -> "(state_get_ " ^ sgen_uid field ^ "_in_" ^ sgen_clexp_state_api clexp ^ "(state))" + | _ -> "&((" ^ sgen_clexp ctx clexp ^ ")->" ^ sgen_uid field ^ ")" + end | CL_tuple (clexp, n) -> "&((" ^ sgen_clexp ctx clexp ^ ")->" ^ mangle ("tup" ^ string_of_int n) ^ ")" | CL_addr clexp -> "(*(" ^ sgen_clexp ctx clexp ^ "))" | CL_void -> assert false | CL_rmw _ -> assert false -let rec sgen_clexp_pure ctx = function +let rec sgen_clexp_pure ctx clexp = + match clexp with | CL_id (Have_exception _, _) -> "(state->have_exception)" | CL_id (Current_exception _, _) -> "(state->current_exception)" | CL_id (Throw_location _, _) -> "(state->throw_location)" | CL_id (Return _, _) -> assert false - | CL_id (Global (id, _), _) -> "state->" ^ sgen_id id + | CL_id (Global (id, _), _) -> + let ctyp = clexp_ctyp clexp in + assert (not (is_api_ctyp ctyp)); + "state->" ^ sgen_id id | CL_id (Name (id, _), _) -> sgen_id id | CL_field (clexp, field) -> sgen_clexp_pure ctx clexp ^ "." ^ sgen_uid field | CL_tuple (clexp, n) -> sgen_clexp_pure ctx clexp ^ "." ^ mangle ("tup" ^ string_of_int n) @@ -569,7 +625,10 @@ let rec codegen_conversion l ctx clexp cval = (* When both types are equal, we don't need any conversion. *) | _, _ when ctyp_equal ctyp_to ctyp_from -> if is_stack_ctyp ctyp_to then - ksprintf string " %s = %s;" (sgen_clexp_pure ctx clexp) (sgen_cval ctx cval) + if is_state_api_clexp clexp then + ksprintf string " state_set_%s(state, %s);" (sgen_clexp_state_api clexp) (sgen_cval ctx cval) + else + ksprintf string " %s = %s;" (sgen_clexp_pure ctx clexp) (sgen_cval ctx cval) else ksprintf string " COPY(%s)(%s, %s);" (sgen_ctyp_name ctyp_to) (sgen_clexp ctx clexp) (sgen_cval ctx cval) @@ -590,6 +649,10 @@ let rec codegen_conversion l ctx clexp cval = (* For anything not special cased, just try to call a appropriate CONVERT_OF function. *) | _, _ when is_stack_ctyp (clexp_ctyp clexp) -> + if is_state_api_clexp clexp then + ksprintf string " state_set_%s(state, CONVERT_OF(%s, %s)(%s));" + (sgen_clexp_state_api clexp) (sgen_ctyp_name ctyp_to) (sgen_ctyp_name ctyp_from) (sgen_cval_param ctx cval) + else ksprintf string " %s = CONVERT_OF(%s, %s)(%s);" (sgen_clexp_pure ctx clexp) (sgen_ctyp_name ctyp_to) (sgen_ctyp_name ctyp_from) (sgen_cval_param ctx cval) | _, _ -> @@ -607,10 +670,10 @@ let extra_arguments is_extern = let rec codegen_instr fid ctx (I_aux (instr, (_, l))) = match instr with | I_decl (ctyp, id) when is_stack_ctyp ctyp -> - ksprintf string " %s %s;" (sgen_ctyp ctyp) (sgen_name id) + ksprintf string " %s %s;" (sgen_ctyp ctyp) (sgen_name ctyp id) | I_decl (ctyp, id) -> - ksprintf string " %s %s;" (sgen_ctyp ctyp) (sgen_name id) ^^ hardline - ^^ ksprintf string " CREATE(%s)(&%s);" (sgen_ctyp_name ctyp) (sgen_name id) + ksprintf string " %s %s;" (sgen_ctyp ctyp) (sgen_name ctyp id) ^^ hardline + ^^ ksprintf string " CREATE(%s)(&%s);" (sgen_ctyp_name ctyp) (sgen_name ctyp id) | I_copy (clexp, cval) -> codegen_conversion l ctx clexp cval @@ -650,11 +713,11 @@ let rec codegen_instr fid ctx (I_aux (instr, (_, l))) = else (sgen_function_uid f, false) in - let sail_state_arg = - if is_extern && StringSet.mem fname O.opts.state_primops then + let sail_state_arg : string ref = ref + (if is_extern && StringSet.mem fname O.opts.state_primops then "sail_state *state, " else - "" + "") in let fname = match fname, ctyp with @@ -677,6 +740,7 @@ let rec codegen_instr fid ctx (I_aux (instr, (_, l))) = | "vector_access", CT_bit -> "bitvector_access" | "vector_access", _ -> begin match args with + | cval :: _ when is_state_api_cval cval -> sail_state_arg := "state, "; sprintf "state_vector_access_%s" (sgen_cval_state_api cval) | cval :: _ -> sprintf "vector_access_%s" (sgen_ctyp_name (cval_ctyp cval)) | _ -> Reporting.unreachable l __POS__ "vector access function with bad arity." end @@ -684,7 +748,11 @@ let rec codegen_instr fid ctx (I_aux (instr, (_, l))) = | "vector_subrange", _ -> sprintf "vector_subrange_%s" (sgen_ctyp_name ctyp) | "vector_update", CT_fbits _ -> "update_fbits" | "vector_update", CT_lbits _ -> "update_lbits" - | "vector_update", _ -> sprintf "vector_update_%s" (sgen_ctyp_name ctyp) + | "vector_update", _ -> if is_state_api_clexp x then ( + sail_state_arg := "state, "; + sprintf "state_vector_update_%s" (sgen_clexp_state_api x)) + else + sprintf "vector_update_%s" (sgen_ctyp_name ctyp) | "string_of_bits", _ -> begin match cval_ctyp (List.nth args 0) with | CT_fbits _ -> "string_of_fbits" @@ -713,14 +781,17 @@ let rec codegen_instr fid ctx (I_aux (instr, (_, l))) = ksprintf string " COPY(%s)(&%s, *(%s));" (sgen_ctyp_name ctyp) (sgen_clexp_pure ctx x) c_args else if is_stack_ctyp ctyp then - ksprintf string " %s = %s(%s%s%s);" (sgen_clexp_pure ctx x) fname sail_state_arg (extra_arguments is_extern) c_args + if is_state_api_clexp x then + ksprintf string " state_set_%s(state, %s(%s%s%s));" (sgen_clexp_state_api x) fname !sail_state_arg (extra_arguments is_extern) c_args + else + ksprintf string " %s = %s(%s%s%s);" (sgen_clexp_pure ctx x) fname !sail_state_arg (extra_arguments is_extern) c_args else - ksprintf string " %s(%s%s%s, %s);" fname sail_state_arg (extra_arguments is_extern) (sgen_clexp ctx x) c_args + ksprintf string " %s(%s%s%s, %s);" fname !sail_state_arg (extra_arguments is_extern) (sgen_clexp ctx x) c_args | I_clear (ctyp, id) when is_stack_ctyp ctyp -> empty | I_clear (ctyp, id) -> - ksprintf string " KILL(%s)(&%s);" (sgen_ctyp_name ctyp) (sgen_name id) + ksprintf string " KILL(%s)(&%s);" (sgen_ctyp_name ctyp) (sgen_name ctyp id) | I_init (ctyp, id, cval) -> codegen_instr fid ctx (idecl ctyp id) ^^ hardline @@ -731,9 +802,9 @@ let rec codegen_instr fid ctx (I_aux (instr, (_, l))) = ^^ codegen_conversion Parse_ast.Unknown ctx (CL_id (id, ctyp)) cval | I_reset (ctyp, id) when is_stack_ctyp ctyp -> - ksprintf string " %s %s;" (sgen_ctyp ctyp) (sgen_name id) + ksprintf string " %s %s;" (sgen_ctyp ctyp) (sgen_name ctyp id) | I_reset (ctyp, id) -> - ksprintf string " RECREATE(%s)(&%s);" (sgen_ctyp_name ctyp) (sgen_name id) + ksprintf string " RECREATE(%s)(&%s);" (sgen_ctyp_name ctyp) (sgen_name ctyp id) | I_return cval -> ksprintf string " return %s;" (sgen_cval ctx cval) @@ -760,8 +831,8 @@ let rec codegen_instr fid ctx (I_aux (instr, (_, l))) = sprintf ".%s = %s" (mangle ("tup" ^ string_of_int n)) init :: inits, prev @ prev' in let inits, prev = List.fold_left fold ([], []) (List.mapi (fun i x -> (i, x)) ctyps) in - sgen_name gs, - [sprintf "struct %s %s = { " (sgen_ctyp_name ctyp) (sgen_name gs) + sgen_name ctyp gs, + [sprintf "struct %s %s = { " (sgen_ctyp_name ctyp) (sgen_name ctyp gs) ^ Util.string_of_list ", " (fun x -> x) inits ^ " };"] @ prev | CT_struct (id, ctors) when is_stack_ctyp ctyp -> let gs = ngensym () in @@ -770,8 +841,8 @@ let rec codegen_instr fid ctx (I_aux (instr, (_, l))) = sprintf ".%s = %s" (sgen_uid uid) init :: inits, prev @ prev' in let inits, prev = List.fold_left fold ([], []) ctors in - sgen_name gs, - [sprintf "struct %s %s = { " (sgen_ctyp_name ctyp) (sgen_name gs) + sgen_name ctyp gs, + [sprintf "struct %s %s = { " (sgen_ctyp_name ctyp) (sgen_name ctyp gs) ^ Util.string_of_list ", " (fun x -> x) inits ^ " };"] @ prev | ctyp -> Reporting.unreachable l __POS__ ("Cannot create undefined value for type: " ^ string_of_ctyp ctyp) in @@ -1154,7 +1225,13 @@ let codegen_vector_header ctx id (direction, ctyp) = string (Printf.sprintf "struct %s {\n size_t len;\n %s *data;\n};\n" (sgen_id id) (sgen_ctyp ctyp)) ^^ string (Printf.sprintf "typedef struct %s %s;" (sgen_id id) (sgen_id id)) in - vector_typedef ^^ twice hardline + vector_typedef ^^ hardline ^^ + string (Printf.sprintf "void vector_update_%s(%s *rop, %s op, sail_int n, %s elem);" (sgen_id id) (sgen_id id) (sgen_id id) (sgen_ctyp ctyp)) ^^ hardline ^^ + if is_stack_ctyp ctyp then + string (Printf.sprintf "%s vector_access_%s(%s op, sail_int n);" (sgen_ctyp ctyp) (sgen_id id) (sgen_id id)) + else + string (Printf.sprintf "static void vector_access_%s(%s *rop, %s op, sail_int n);" (sgen_id id) (sgen_ctyp ctyp) (sgen_id id)) + ^^ twice hardline let codegen_vector_body ctx id (direction, ctyp) = let vector_init = @@ -1184,7 +1261,7 @@ let codegen_vector_body ctx id (direction, ctyp) = ^^ string "}" in let vector_update = - string (Printf.sprintf "static void vector_update_%s(%s *rop, %s op, sail_int n, %s elem) {\n" (sgen_id id) (sgen_id id) (sgen_id id) (sgen_ctyp ctyp)) + string (Printf.sprintf "void vector_update_%s(%s *rop, %s op, sail_int n, %s elem) {\n" (sgen_id id) (sgen_id id) (sgen_id id) (sgen_ctyp ctyp)) ^^ string " int m = sail_int_get_ui(n);\n" ^^ string " if (rop->data == op.data) {\n" ^^ string (if is_stack_ctyp ctyp then @@ -1210,7 +1287,7 @@ let codegen_vector_body ctx id (direction, ctyp) = in let vector_access = if is_stack_ctyp ctyp then - string (Printf.sprintf "static %s vector_access_%s(%s op, sail_int n) {\n" (sgen_ctyp ctyp) (sgen_id id) (sgen_id id)) + string (Printf.sprintf "%s vector_access_%s(%s op, sail_int n) {\n" (sgen_ctyp ctyp) (sgen_id id) (sgen_id id)) ^^ string " int m = sail_int_get_ui(n);\n" ^^ string " return op.data[m];\n" ^^ string "}" @@ -1291,13 +1368,13 @@ let is_decl = function let codegen_decl = function | I_aux (I_decl (ctyp, id), _) -> - string (Printf.sprintf "%s %s;" (sgen_ctyp ctyp) (sgen_name id)) + string (Printf.sprintf "%s %s;" (sgen_ctyp ctyp) (sgen_name ctyp id)) | _ -> assert false let codegen_alloc = function | I_aux (I_decl (ctyp, id), _) when is_stack_ctyp ctyp -> empty | I_aux (I_decl (ctyp, id), _) -> - string (Printf.sprintf " CREATE(%s)(&%s);" (sgen_ctyp_name ctyp) (sgen_name id)) + string (Printf.sprintf " CREATE(%s)(&%s);" (sgen_ctyp_name ctyp) (sgen_name ctyp id)) | _ -> assert false let add_local_labels instrs = @@ -1465,6 +1542,47 @@ let codegen_state_struct_def ctx = function | _ -> empty +let codegen_state_api_struct id ctyp = + match ctyp with + | CT_struct (_, fields) -> + let codegen_state_api_field ((fid, _), ctyp) = + if is_api_ctyp ctyp then + ksprintf string "static inline void state_set_%s_in_%s(sail_state* st, %s u) { st->%s.%s = u; }" + (sgen_id fid) (sgen_id id) (sgen_ctyp ctyp) (sgen_id id) (sgen_id fid) + ^^ hardline + ^^ ksprintf string "static inline %s state_get_%s_in_%s(sail_state* st) { return st->%s.%s; }" + (sgen_ctyp ctyp) (sgen_id fid) (sgen_id id) (sgen_id id) (sgen_id fid) + else + empty + in + separate_map hardline codegen_state_api_field fields ^^ hardline + | _ -> empty + +let codegen_state_api_vector id ctyp vctyp = + string (Printf.sprintf "static inline void state_vector_update_%s(sail_state* st, %s *rop, %s op, sail_int n, %s elem) { vector_update_%s(rop, op, n, elem); }" + (sgen_id id) (sgen_ctyp ctyp) (sgen_ctyp ctyp) (sgen_ctyp vctyp) (sgen_ctyp ctyp)) ^^ hardline ^^ + string (Printf.sprintf "static inline %s state_vector_access_%s(sail_state* st, %s op, sail_int n) { return vector_access_%s(op, n); }" + (sgen_ctyp vctyp) (sgen_id id) (sgen_ctyp ctyp) (sgen_ctyp ctyp)) ^^ hardline + +let codegen_state_api_reg_dec id ctyp = + begin match ctyp with + | _ when is_api_ctyp ctyp -> + ksprintf string "static inline %s state_get_%s(sail_state* st) { return st->%s; }" (sgen_ctyp ctyp) (sgen_id id) (sgen_id id) ^^ hardline ^^ + ksprintf string "static inline void state_set_%s(sail_state* st, %s n) { st->%s = n; }" (sgen_id id) (sgen_ctyp ctyp) (sgen_id id) ^^ hardline ^^ + codegen_state_api_struct id ctyp + | CT_vector (_, vctyp) when is_api_ctyp vctyp -> codegen_state_api_vector id ctyp vctyp + | CT_fvector (_, _, vctyp) when is_api_ctyp vctyp -> codegen_state_api_vector id ctyp vctyp + | _ -> empty + end + +let codegen_state_api ctx = function +| CDEF_reg_dec (id, ctyp, _) -> codegen_state_api_reg_dec id ctyp +| CDEF_let (_, [], _) -> empty +| CDEF_let (_, bindings, _) -> + separate_map hardline (fun (id, ctyp) -> codegen_state_api_reg_dec id ctyp) bindings + ^^ hardline +| _ -> empty + let codegen_state_struct ctx cdefs = string "struct sail_state {" ^^ hardline ^^ concat_map (codegen_state_struct_def ctx) cdefs @@ -1477,7 +1595,8 @@ let codegen_state_struct ctx cdefs = ^^ string " sail_string *throw_location;" ^^ hardline )) ^^ concat_map (fun str -> string (" " ^ str) ^^ hardline) O.opts.extra_state - ^^ string "};" + ^^ string "};" ^^ hardline ^^ hardline + ^^ concat_map (codegen_state_api ctx) cdefs let is_cdef_startup = function | CDEF_startup _ -> true diff --git a/src/jib/jib_compile.ml b/src/jib/jib_compile.ml index 5e234eae..5bf53009 100644 --- a/src/jib/jib_compile.ml +++ b/src/jib/jib_compile.ml @@ -182,6 +182,12 @@ module type Config = sig val track_throw : bool end +let name_or_global ctx id = + if Env.is_register id ctx.local_env || IdSet.mem id (Env.get_toplevel_lets ctx.local_env) then + global id + else + name id + module Make(C: Config) = struct let ctyp_of_typ ctx typ = C.convert_typ ctx typ @@ -191,12 +197,6 @@ let rec chunkify n xs = | xs, [] -> [xs] | xs, ys -> xs :: chunkify n ys -let name_or_global ctx id = - if Env.is_register id ctx.local_env || IdSet.mem id (Env.get_toplevel_lets ctx.local_env) then - global id - else - name id - let coverage_branch_count = ref 0 let coverage_loc_args l = diff --git a/src/jib/jib_compile.mli b/src/jib/jib_compile.mli index 3756e58a..30f379d8 100644 --- a/src/jib/jib_compile.mli +++ b/src/jib/jib_compile.mli @@ -134,3 +134,5 @@ end convert several Sail language features, these are sail_assert, sail_exit, and sail_cons. *) val add_special_functions : Env.t -> Env.t + +val name_or_global : ctx -> id -> name diff --git a/src/jib/jib_smt.ml b/src/jib/jib_smt.ml index 359704b0..6cbd1b87 100644 --- a/src/jib/jib_smt.ml +++ b/src/jib/jib_smt.ml @@ -1086,7 +1086,7 @@ let builtin_count_leading_zeros ctx v ret_ctyp = bvint ret_sz (Big_int.zero)) else ( assert (sz land (sz - 1) = 0); - let hsz = sz /2 in + let hsz = sz / 2 in Ite (Fn ("=", [Extract (sz - 1, hsz, smt); bvzero hsz]), Fn ("bvadd", [bvint ret_sz (Big_int.of_int hsz); lzcnt hsz (Extract (hsz - 1, 0, smt))]), lzcnt hsz (Extract (sz - 1, hsz, smt))) diff --git a/src/rewrites.ml b/src/rewrites.ml index 03a70730..b84f328b 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -4878,6 +4878,11 @@ let if_mwords f env defs = let if_mwords_env f env defs = if !Pretty_print_lem.opt_mwords then f env defs else if_mono_env f env defs +let if_flag flag f env defs = + if !flag then f env defs else defs +let if_flag_env flag f env defs = + if !flag then f env defs else defs, env + type rewriter = | Basic_rewriter of (Env.t -> tannot defs -> tannot defs) | Checking_rewriter of (Env.t -> tannot defs -> tannot defs * Env.t) @@ -4888,6 +4893,7 @@ type rewriter = type rewriter_arg = | If_mono_arg | If_mwords_arg + | If_flag of bool ref | Bool_arg of bool | String_arg of string | Literal_arg of string @@ -4904,8 +4910,10 @@ let instantiate_rewrite rewriter args = match rewriter, arg with | Basic_rewriter rw, If_mono_arg -> Basic_rewriter (if_mono rw) | Basic_rewriter rw, If_mwords_arg -> Basic_rewriter (if_mwords rw) + | Basic_rewriter rw, If_flag flag -> Basic_rewriter (if_flag flag rw) | Checking_rewriter rw, If_mono_arg -> Checking_rewriter (if_mono_env rw) | Checking_rewriter rw, If_mwords_arg -> Checking_rewriter (if_mwords_env rw) + | Checking_rewriter rw, If_flag flag -> Checking_rewriter (if_flag_env flag rw) | Bool_rewriter rw, Bool_arg b -> rw b | String_rewriter rw, String_arg str -> rw str | Literal_rewriter rw, Literal_arg selector -> rw (selector_function selector) @@ -4979,8 +4987,8 @@ let rewrites_lem = [ ("toplevel_string_append", []); ("pat_string_append", []); ("mapping_builtins", []); - ("mono_rewrites", []); - ("recheck_defs", [If_mono_arg]); + ("mono_rewrites", [If_flag opt_mono_rewrites]); + ("recheck_defs", [If_flag opt_mono_rewrites]); ("undefined", [Bool_arg false]); ("toplevel_consts", [String_arg "lem"; If_mwords_arg]); ("toplevel_nexps", [If_mono_arg]); @@ -5103,8 +5111,8 @@ let rewrites_c = [ ("toplevel_string_append", []); ("pat_string_append", []); ("mapping_builtins", []); - ("mono_rewrites", [If_mono_arg]); - ("recheck_defs", [If_mono_arg]); + ("mono_rewrites", [If_flag opt_mono_rewrites]); + ("recheck_defs", [If_flag opt_mono_rewrites]); ("toplevel_nexps", [If_mono_arg]); ("monomorphise", [String_arg "c"; If_mono_arg]); ("atoms_to_singletons", [String_arg "c"; If_mono_arg]); diff --git a/src/rewrites.mli b/src/rewrites.mli index 3b572d51..43a2e057 100644 --- a/src/rewrites.mli +++ b/src/rewrites.mli @@ -85,6 +85,7 @@ val rewrite_lit_lem : lit -> bool type rewriter_arg = | If_mono_arg | If_mwords_arg + | If_flag of bool ref | Bool_arg of bool | String_arg of string | Literal_arg of string diff --git a/src/type_check.ml b/src/type_check.ml index 4f0d90bc..1d6566ef 100644 --- a/src/type_check.ml +++ b/src/type_check.ml @@ -590,6 +590,10 @@ end = struct || Bindings.mem id env.enums || Bindings.mem id builtin_typs + let bound_ctor_fn env id = + Bindings.mem id env.top_val_specs + || Bindings.mem id env.union_ids + let get_overloads id env = try Bindings.find id env.overloads with | Not_found -> [] @@ -941,9 +945,14 @@ end = struct | Not_found -> typ_error env (id_loc id) ("No val spec found for " ^ string_of_id id) let add_union_id id bind env = - typ_print (lazy (adding ^ "union identifier " ^ string_of_id id ^ " : " ^ string_of_bind bind)); - { env with union_ids = Bindings.add id bind env.union_ids } - + if bound_ctor_fn env id + then typ_error env (id_loc id) ("A union constructor or function already exists with name " ^ string_of_id id ) + else + begin + typ_print (lazy (adding ^ "union identifier " ^ string_of_id id ^ " : " ^ string_of_bind bind)); + { env with union_ids = Bindings.add id bind env.union_ids } + end + let get_union_id id env = try let bind = Bindings.find id env.union_ids in @@ -1147,19 +1156,29 @@ end = struct let add_toplevel_lets ids env = { env with top_letbinds = IdSet.union ids env.top_letbinds } - + let get_toplevel_lets env = env.top_letbinds let add_variant id variant env = - typ_print (lazy (adding ^ "variant " ^ string_of_id id)); - { env with variants = Bindings.add id variant env.variants } - + if bound_typ_id env id + then typ_error env (id_loc id) ("Cannot create variant " ^ string_of_id id ^ ", type name is already bound") + else + begin + typ_print (lazy (adding ^ "variant " ^ string_of_id id)); + { env with variants = Bindings.add id variant env.variants } + end + let add_scattered_variant id typq env = - typ_print (lazy (adding ^ "scattered variant " ^ string_of_id id)); - { env with - variants = Bindings.add id (typq, []) env.variants; - scattered_variant_envs = Bindings.add id env env.scattered_variant_envs - } + if bound_typ_id env id + then typ_error env (id_loc id) ("Cannot create scattered variant " ^ string_of_id id ^ ", type name is already bound") + else + begin + typ_print (lazy (adding ^ "scattered variant " ^ string_of_id id)); + { env with + variants = Bindings.add id (typq, []) env.variants; + scattered_variant_envs = Bindings.add id env env.scattered_variant_envs + } + end let add_variant_clause id tu env = match Bindings.find_opt id env.variants with @@ -4890,11 +4909,18 @@ and propagate_lexp_effect_aux = function (* 7. Checking toplevel definitions *) (**************************************************************************) +let check_duplicate_letbinding l pat env = + match IdSet.choose_opt (IdSet.inter (pat_ids pat) (Env.get_toplevel_lets env)) with + | Some id -> + typ_error env l ("Duplicate toplevel let binding " ^ string_of_id id) + | None -> () + let check_letdef orig_env (LB_aux (letbind, (l, _))) = - typ_print (lazy "\nChecking top-level let"); + typ_print (lazy ("\nChecking top-level let" |> cyan |> clear)); begin match letbind with | LB_val (P_aux (P_typ (typ_annot, _), _) as pat, bind) -> + check_duplicate_letbinding l pat orig_env; let checked_bind = propagate_exp_effect (crule check_exp orig_env (strip_exp bind) typ_annot) in let tpat, env = bind_pat_no_guard orig_env (strip_pat pat) typ_annot in if (BESet.is_empty (effect_set (effect_of checked_bind)) || !opt_no_effects) @@ -4902,6 +4928,7 @@ let check_letdef orig_env (LB_aux (letbind, (l, _))) = [DEF_val (LB_aux (LB_val (tpat, checked_bind), (l, None)))], Env.add_toplevel_lets (pat_ids tpat) env else typ_error env l ("Top-level definition with effects " ^ string_of_effect (effect_of checked_bind)) | LB_val (pat, bind) -> + check_duplicate_letbinding l pat orig_env; let inferred_bind = propagate_exp_effect (irule infer_exp orig_env (strip_exp bind)) in let tpat, env = bind_pat_no_guard orig_env (strip_pat pat) (typ_of inferred_bind) in if (BESet.is_empty (effect_set (effect_of inferred_bind)) || !opt_no_effects) |
