diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/Makefile | 4 | ||||
| -rw-r--r-- | src/jib.lem | 140 | ||||
| -rw-r--r-- | src/jib/anf.ml | 4 | ||||
| -rw-r--r-- | src/jib/anf.mli | 4 | ||||
| -rw-r--r-- | src/jib/c_backend.ml | 149 | ||||
| -rw-r--r-- | src/jib/jib_compile.ml | 114 | ||||
| -rw-r--r-- | src/jib/jib_compile.mli | 9 | ||||
| -rw-r--r-- | src/jib/jib_ir.ml | 57 | ||||
| -rw-r--r-- | src/jib/jib_optimize.ml | 160 | ||||
| -rw-r--r-- | src/jib/jib_optimize.mli | 3 | ||||
| -rw-r--r-- | src/jib/jib_smt.ml | 44 | ||||
| -rw-r--r-- | src/jib/jib_smt_fuzz.ml | 2 | ||||
| -rw-r--r-- | src/jib/jib_ssa.ml | 2 | ||||
| -rw-r--r-- | src/jib/jib_util.ml | 144 | ||||
| -rw-r--r-- | src/sail.ml | 3 |
15 files changed, 585 insertions, 254 deletions
diff --git a/src/Makefile b/src/Makefile index a002d4f3..020d6813 100644 --- a/src/Makefile +++ b/src/Makefile @@ -74,9 +74,6 @@ full: sail lib doc ast.lem: ../language/sail.ott ott -sort false -generate_aux_rules true -o ast.lem -picky_multiple_parses true ../language/sail.ott -jib.lem: ../language/jib.ott ast.lem - ott -sort false -generate_aux_rules true -o jib.lem -picky_multiple_parses true ../language/jib.ott - ast.ml: ast.lem lem -ocaml ast.lem sed -i.bak -f ast.sed ast.ml @@ -137,7 +134,6 @@ clean: -rm -f ast.lem -rm -f ast.ml.bak -rm -f jib.ml - -rm -f jib.lem -rm -f jib.ml.bak -rm -f manifest.ml diff --git a/src/jib.lem b/src/jib.lem new file mode 100644 index 00000000..c8959d10 --- /dev/null +++ b/src/jib.lem @@ -0,0 +1,140 @@ +(* generated by Ott 0.29 from: ../language/jib.ott *) +open import Pervasives + + +open import Ast +open import Value2 + + + +type name = + | Name of id * nat + | Have_exception of nat + | Current_exception of nat + | Return of nat + + +type ctyp = (* C type *) + | CT_lint + | CT_fint of nat + | CT_constant of integer + | CT_lbits of bool + | CT_sbits of nat * bool + | CT_fbits of nat * bool + | CT_unit + | CT_bool + | CT_bit + | CT_string + | CT_real + | CT_tup of list ctyp + | CT_enum of id * list id + | CT_struct of id * list (uid * ctyp) + | CT_variant of id * list (uid * ctyp) + | CT_vector of bool * ctyp + | CT_list of ctyp + | CT_ref of ctyp + | CT_poly + +and uid = id * list ctyp + +type op = + | Bnot + | Bor + | Band + | List_hd + | List_tl + | Bit_to_bool + | Eq + | Neq + | Ilt + | Ilteq + | Igt + | Igteq + | Iadd + | Isub + | Unsigned of nat + | Signed of nat + | Bvnot + | Bvor + | Bvand + | Bvxor + | Bvadd + | Bvsub + | Bvaccess + | Concat + | Zero_extend of nat + | Sign_extend of nat + | Slice of nat + | Sslice of nat + | Set_slice + | Replicate of nat + + +type clexp = + | CL_id of name * ctyp + | CL_rmw of name * name * ctyp + | CL_field of clexp * uid + | CL_addr of clexp + | CL_tuple of clexp * nat + | CL_void + + +type cval = + | V_id of name * ctyp + | V_ref of name * ctyp + | V_lit of vl * ctyp + | V_struct of list (uid * cval) * ctyp + | V_ctor_kind of cval * id * list ctyp * ctyp + | V_ctor_unwrap of id * cval * list ctyp * ctyp + | V_tuple_member of cval * nat * nat + | V_call of op * list cval + | V_field of cval * uid + | V_poly of cval * ctyp + + +type iannot = nat * nat * nat + + +type instr_aux = + | I_decl of ctyp * name + | I_init of ctyp * name * cval + | I_jump of cval * string + | I_goto of string + | I_label of string + | I_funcall of clexp * bool * uid * list cval + | I_copy of clexp * cval + | I_clear of ctyp * name + | I_undefined of ctyp + | I_match_failure + | I_end of name + | I_if of cval * list instr * list instr * ctyp + | I_block of list instr + | I_try_block of list instr + | I_throw of cval + | I_comment of string + | I_raw of string + | I_return of cval + | I_reset of ctyp * name + | I_reinit of ctyp * name * cval + +and instr = + | I_aux of instr_aux * iannot + + +type ctype_def = (* C type definition *) + | CTD_enum of id * list id + | CTD_struct of id * list (uid * ctyp) + | CTD_variant of id * list (uid * ctyp) + + +type cdef = + | CDEF_reg_dec of id * ctyp * list instr + | CDEF_type of ctype_def + | CDEF_let of nat * list (id * ctyp) * list instr + | CDEF_spec of id * list ctyp * ctyp + | CDEF_fundef of id * maybe id * list id * list instr + | CDEF_startup of id * list instr + | CDEF_finish of id * list instr + + + diff --git a/src/jib/anf.ml b/src/jib/anf.ml index 52c4584c..6bc447c0 100644 --- a/src/jib/anf.ml +++ b/src/jib/anf.ml @@ -468,7 +468,7 @@ and pp_aval = function | AV_tuple avals -> parens (separate_map (comma ^^ space) pp_aval avals) | AV_ref (id, lvar) -> string "ref" ^^ space ^^ pp_lvar lvar (pp_id id) | AV_cval (cval, typ) -> - pp_annot typ (string (Jib_ir.string_of_cval cval |> Util.cyan |> Util.clear)) + pp_annot typ (string (string_of_cval cval |> Util.cyan |> Util.clear)) | AV_vector (avals, typ) -> pp_annot typ (string "[" ^^ separate_map (comma ^^ space) pp_aval avals ^^ string "]") | AV_list (avals, typ) -> @@ -482,7 +482,7 @@ let ae_lit lit typ = AE_val (AV_lit (lit, typ)) let is_dead_aexp (AE_aux (_, env, _)) = prove __POS__ env nc_false -let (gensym, _) = symbol_generator "anf" +let (gensym, reset_anf_counter) = symbol_generator "ga" let rec split_block l = function | [exp] -> [], exp diff --git a/src/jib/anf.mli b/src/jib/anf.mli index f28cf420..d01fe146 100644 --- a/src/jib/anf.mli +++ b/src/jib/anf.mli @@ -129,6 +129,10 @@ and 'a aval = | AV_record of ('a aval) Bindings.t * 'a | AV_cval of cval * 'a +(** When ANF translation has to introduce new bindings it uses a +counter to ensure uniqueness. This function resets that counter. *) +val reset_anf_counter : unit -> unit + (** {2 Functions for transforming ANF expressions} *) val aval_typ : typ aval -> typ diff --git a/src/jib/c_backend.ml b/src/jib/c_backend.ml index bf22c5d2..a97b06ef 100644 --- a/src/jib/c_backend.ml +++ b/src/jib/c_backend.ml @@ -61,8 +61,6 @@ open Anf module Big_int = Nat_big_num -let c_verbosity = ref 0 - let opt_static = ref false let opt_no_main = ref false let opt_memo_cache = ref false @@ -93,16 +91,16 @@ let optimize_fixed_bits = ref false let (gensym, _) = symbol_generator "cb" let ngensym () = name (gensym ()) -let c_debug str = - if !c_verbosity > 0 then prerr_endline (Lazy.force str) else () - let c_error ?loc:(l=Parse_ast.Unknown) message = raise (Reporting.err_general l ("\nC backend: " ^ message)) -let zencode_id = function - | Id_aux (Id str, l) -> Id_aux (Id (Util.zencode_string str), l) - | Id_aux (Operator str, l) -> Id_aux (Id (Util.zencode_string ("op " ^ str)), l) +let zencode_id id = Util.zencode_string (string_of_id id) +let zencode_uid (id, ctyps) = + match ctyps with + | [] -> Util.zencode_string (string_of_id id) + | _ -> Util.zencode_string (string_of_id id ^ "#" ^ Util.string_of_list "_" string_of_ctyp ctyps) + (**************************************************************************) (* 2. Converting sail types to C types *) (**************************************************************************) @@ -172,8 +170,8 @@ let rec ctyp_of_typ ctx typ = | Typ_app (id, [A_aux (A_typ typ, _)]) when string_of_id id = "register" -> CT_ref (ctyp_of_typ ctx typ) - | Typ_id id | Typ_app (id, _) when Bindings.mem id ctx.records -> CT_struct (id, Bindings.find id ctx.records |> Bindings.bindings) - | Typ_id id | Typ_app (id, _) when Bindings.mem id ctx.variants -> CT_variant (id, Bindings.find id ctx.variants |> Bindings.bindings) + | Typ_id id | Typ_app (id, _) when Bindings.mem id ctx.records -> CT_struct (id, Bindings.find id ctx.records |> UBindings.bindings) + | Typ_id id | Typ_app (id, _) when Bindings.mem id ctx.variants -> CT_variant (id, Bindings.find id ctx.variants |> UBindings.bindings) | Typ_id id when Bindings.mem id ctx.enums -> CT_enum (id, Bindings.find id ctx.enums |> IdSet.elements) | Typ_tup typs -> CT_tup (List.map (ctyp_of_typ ctx) typs) @@ -221,7 +219,7 @@ let is_sbits_typ ctx typ = | CT_sbits _ -> true | _ -> false -let ctor_bindings = List.fold_left (fun map (id, ctyp) -> Bindings.add id ctyp map) Bindings.empty +let ctor_bindings = List.fold_left (fun map (id, ctyp) -> UBindings.add id ctyp map) UBindings.empty (**************************************************************************) (* 3. Optimization of primitives and literals *) @@ -407,8 +405,6 @@ let analyze_primop' ctx id args typ = let v_one = V_lit (VL_int (Big_int.of_int 1), CT_fint 64) in let v_int n = V_lit (VL_int (Big_int.of_int n), CT_fint 64) in - c_debug (lazy ("Analyzing primop " ^ extern ^ "(" ^ Util.string_of_list ", " (fun aval -> Pretty_print_sail.to_string (pp_aval aval)) args ^ ")")); - match extern, args with | "eq_bits", [AV_cval (v1, _); AV_cval (v2, _)] when ctyp_equal (cval_ctyp v1) (cval_ctyp v2) -> begin match cval_ctyp v1 with @@ -541,7 +537,6 @@ let analyze_primop' ctx id args typ = AE_val (AV_cval (V_lit (VL_bool false, CT_bool), typ)) | _, _ -> - c_debug (lazy ("No optimization routine found")); no_change let analyze_primop ctx id args typ = @@ -549,7 +544,6 @@ let analyze_primop ctx id args typ = if !optimize_primops then try analyze_primop' ctx id args typ with | Failure str -> - (c_debug (lazy ("Analyze primop failed for id " ^ string_of_id id ^ " reason: " ^ str))); no_change else no_change @@ -702,7 +696,6 @@ let hoist_id () = let hoist_allocations recursive_functions = function | CDEF_fundef (function_id, _, _, _) as cdef when IdSet.mem function_id recursive_functions -> - c_debug (lazy (Printf.sprintf "skipping recursive function %s" (string_of_id function_id))); [cdef] | CDEF_fundef (function_id, heap_return, args, body) -> @@ -1028,13 +1021,19 @@ let optimize recursive_functions cdefs = (**************************************************************************) let sgen_id id = Util.zencode_string (string_of_id id) +let sgen_uid uid = zencode_uid uid let sgen_name id = string_of_name id let codegen_id id = string (sgen_id id) +let codegen_uid id = string (sgen_uid id) let sgen_function_id id = let str = Util.zencode_string (string_of_id id) in !opt_prefix ^ String.sub str 1 (String.length str - 1) +let sgen_function_uid uid = + let str = zencode_uid uid in + !opt_prefix ^ String.sub str 1 (String.length str - 1) + let codegen_function_id id = string (sgen_function_id id) let rec sgen_ctyp = function @@ -1095,30 +1094,40 @@ let sgen_mask n = else failwith "Tried to create a mask literal for a vector greater than 64 bits." +let sgen_value = function + | VL_bits ([], _) -> "UINT64_C(0)" + | VL_bits (bs, true) -> "UINT64_C(" ^ Sail2_values.show_bitlist bs ^ ")" + | VL_bits (bs, false) -> "UINT64_C(" ^ Sail2_values.show_bitlist (List.rev bs) ^ ")" + | VL_int i -> Big_int.to_string i ^ "l" + | VL_bool true -> "true" + | VL_bool false -> "false" + | VL_null -> "NULL" + | VL_unit -> "UNIT" + | VL_bit Sail2_values.B0 -> "UINT64_C(0)" + | VL_bit Sail2_values.B1 -> "UINT64_C(1)" + | VL_bit Sail2_values.BU -> failwith "Undefined bit found in value" + | VL_real str -> str + | VL_string str -> "\"" ^ str ^ "\"" + let rec sgen_cval = function | V_id (id, ctyp) -> string_of_name id | V_ref (id, _) -> "&" ^ string_of_name id - | V_lit (vl, ctyp) -> string_of_value vl + | V_lit (vl, ctyp) -> sgen_value vl | V_call (op, cvals) -> sgen_call op cvals | V_field (f, field) -> - Printf.sprintf "%s.%s" (sgen_cval f) field + Printf.sprintf "%s.%s" (sgen_cval f) (sgen_uid field) | V_tuple_member (f, _, n) -> Printf.sprintf "%s.ztup%d" (sgen_cval f) n - | V_ctor_kind (f, ctor, [], _) -> - sgen_cval f ^ ".kind" - ^ " != Kind_" ^ Util.zencode_string (string_of_id ctor) | V_ctor_kind (f, ctor, unifiers, _) -> sgen_cval f ^ ".kind" - ^ " != Kind_" ^ Util.zencode_string (string_of_id ctor ^ "_" ^ Util.string_of_list "_" string_of_ctyp unifiers) - | V_ctor_unwrap (ctor, f, [], _) -> - Printf.sprintf "%s.%s" (sgen_cval f) (Util.zencode_string (string_of_id ctor)) + ^ " != Kind_" ^ zencode_uid (ctor, unifiers) | V_struct (fields, _) -> Printf.sprintf "{%s}" - (Util.string_of_list ", " (fun (field, cval) -> string_of_id field ^ " = " ^ sgen_cval cval) fields) + (Util.string_of_list ", " (fun (field, cval) -> zencode_uid field ^ " = " ^ sgen_cval cval) fields) | V_ctor_unwrap (ctor, f, unifiers, _) -> Printf.sprintf "%s.%s" (sgen_cval f) - (Util.zencode_string (string_of_id ctor ^ "_" ^ Util.string_of_list "_" string_of_ctyp unifiers)) + (sgen_uid (ctor, unifiers)) | V_poly (f, _) -> sgen_cval f and sgen_call op cvals = @@ -1299,7 +1308,7 @@ let rec sgen_clexp = function | CL_id (Current_exception _, _) -> "current_exception" | CL_id (Return _, _) -> assert false | CL_id (Name (id, _), _) -> "&" ^ sgen_id id - | CL_field (clexp, field) -> "&((" ^ sgen_clexp clexp ^ ")->" ^ Util.zencode_string field ^ ")" + | CL_field (clexp, field) -> "&((" ^ sgen_clexp clexp ^ ")->" ^ zencode_uid field ^ ")" | CL_tuple (clexp, n) -> "&((" ^ sgen_clexp clexp ^ ")->ztup" ^ string_of_int n ^ ")" | CL_addr clexp -> "(*(" ^ sgen_clexp clexp ^ "))" | CL_void -> assert false @@ -1310,7 +1319,7 @@ let rec sgen_clexp_pure = function | CL_id (Current_exception _, _) -> "current_exception" | CL_id (Return _, _) -> assert false | CL_id (Name (id, _), _) -> sgen_id id - | CL_field (clexp, field) -> sgen_clexp_pure clexp ^ "." ^ Util.zencode_string field + | CL_field (clexp, field) -> sgen_clexp_pure clexp ^ "." ^ zencode_uid field | CL_tuple (clexp, n) -> sgen_clexp_pure clexp ^ ".ztup" ^ string_of_int n | CL_addr clexp -> "(*(" ^ sgen_clexp_pure clexp ^ "))" | CL_void -> assert false @@ -1394,14 +1403,14 @@ let rec codegen_instr fid ctx (I_aux (instr, (_, l))) = | I_funcall (x, extern, f, args) -> let c_args = Util.string_of_list ", " sgen_cval args in let ctyp = clexp_ctyp x in - let is_extern = Env.is_extern f ctx.tc_env "c" || extern in + let is_extern = Env.is_extern (fst f) ctx.tc_env "c" || extern in let fname = - if Env.is_extern f ctx.tc_env "c" then - Env.get_extern f ctx.tc_env "c" + if Env.is_extern (fst f) ctx.tc_env "c" then + Env.get_extern (fst f) ctx.tc_env "c" else if extern then - string_of_id f + string_of_id (fst f) else - sgen_function_id f + sgen_function_uid f in let fname = match fname, ctyp with @@ -1506,9 +1515,9 @@ let rec codegen_instr fid ctx (I_aux (instr, (_, l))) = ^ Util.string_of_list ", " (fun x -> x) inits ^ " };"] @ prev | CT_struct (id, ctors) when is_stack_ctyp ctyp -> let gs = ngensym () in - let fold (inits, prev) (id, ctyp) = + let fold (inits, prev) (uid, ctyp) = let init, prev' = codegen_exn_return ctyp in - Printf.sprintf ".%s = %s" (sgen_id id) init :: inits, prev @ prev' + Printf.sprintf ".%s = %s" (sgen_uid uid) init :: inits, prev @ prev' in let inits, prev = List.fold_left fold ([], []) ctors in sgen_name gs, @@ -1560,36 +1569,34 @@ let codegen_type_def ctx = function | CTD_struct (id, ctors) -> let struct_ctyp = CT_struct (id, ctors) in - c_debug (lazy (Printf.sprintf "Generating struct for %s" (full_string_of_ctyp struct_ctyp))); - (* Generate a set_T function for every struct T *) let codegen_set (id, ctyp) = if is_stack_ctyp ctyp then - string (Printf.sprintf "rop->%s = op.%s;" (sgen_id id) (sgen_id id)) + string (Printf.sprintf "rop->%s = op.%s;" (sgen_uid id) (sgen_uid id)) else - string (Printf.sprintf "COPY(%s)(&rop->%s, op.%s);" (sgen_ctyp_name ctyp) (sgen_id id) (sgen_id id)) + string (Printf.sprintf "COPY(%s)(&rop->%s, op.%s);" (sgen_ctyp_name ctyp) (sgen_uid id) (sgen_uid id)) in let codegen_setter id ctors = string (let n = sgen_id id in Printf.sprintf "static void COPY(%s)(struct %s *rop, const struct %s op)" n n n) ^^ space ^^ surround 2 0 lbrace - (separate_map hardline codegen_set (Bindings.bindings ctors)) + (separate_map hardline codegen_set (UBindings.bindings ctors)) rbrace in (* Generate an init/clear_T function for every struct T *) let codegen_field_init f (id, ctyp) = if not (is_stack_ctyp ctyp) then - [string (Printf.sprintf "%s(%s)(&op->%s);" f (sgen_ctyp_name ctyp) (sgen_id id))] + [string (Printf.sprintf "%s(%s)(&op->%s);" f (sgen_ctyp_name ctyp) (sgen_uid id))] else [] in let codegen_init f id ctors = string (let n = sgen_id id in Printf.sprintf "static void %s(%s)(struct %s *op)" f n n) ^^ space ^^ surround 2 0 lbrace - (separate hardline (Bindings.bindings ctors |> List.map (codegen_field_init f) |> List.concat)) + (separate hardline (UBindings.bindings ctors |> List.map (codegen_field_init f) |> List.concat)) rbrace in let codegen_eq = let codegen_eq_test (id, ctyp) = - string (Printf.sprintf "EQUAL(%s)(op1.%s, op2.%s)" (sgen_ctyp_name ctyp) (sgen_id id) (sgen_id id)) + string (Printf.sprintf "EQUAL(%s)(op1.%s, op2.%s)" (sgen_ctyp_name ctyp) (sgen_uid id) (sgen_uid id)) in string (Printf.sprintf "static bool EQUAL(%s)(struct %s op1, struct %s op2)" (sgen_id id) (sgen_id id) (sgen_id id)) ^^ space @@ -1601,7 +1608,7 @@ let codegen_type_def ctx = function in (* Generate the struct and add the generated functions *) let codegen_ctor (id, ctyp) = - string (sgen_ctyp ctyp) ^^ space ^^ codegen_id id + string (sgen_ctyp ctyp) ^^ space ^^ codegen_uid id in string (Printf.sprintf "// struct %s" (string_of_id id)) ^^ hardline ^^ string "struct" ^^ space ^^ codegen_id id ^^ space @@ -1623,17 +1630,17 @@ let codegen_type_def ctx = function | CTD_variant (id, tus) -> let codegen_tu (ctor_id, ctyp) = - separate space [string "struct"; lbrace; string (sgen_ctyp ctyp); codegen_id ctor_id ^^ semi; rbrace] + separate space [string "struct"; lbrace; string (sgen_ctyp ctyp); codegen_uid ctor_id ^^ semi; rbrace] in (* Create an if, else if, ... block that does something for each constructor *) let rec each_ctor v f = function | [] -> string "{}" | [(ctor_id, ctyp)] -> - string (Printf.sprintf "if (%skind == Kind_%s)" v (sgen_id ctor_id)) ^^ lbrace ^^ hardline + string (Printf.sprintf "if (%skind == Kind_%s)" v (sgen_uid ctor_id)) ^^ lbrace ^^ hardline ^^ jump 0 2 (f ctor_id ctyp) ^^ hardline ^^ rbrace | (ctor_id, ctyp) :: ctors -> - string (Printf.sprintf "if (%skind == Kind_%s) " v (sgen_id ctor_id)) ^^ lbrace ^^ hardline + string (Printf.sprintf "if (%skind == Kind_%s) " v (sgen_uid ctor_id)) ^^ lbrace ^^ hardline ^^ jump 0 2 (f ctor_id ctyp) ^^ hardline ^^ rbrace ^^ string " else " ^^ each_ctor v f ctors in @@ -1643,9 +1650,9 @@ let codegen_type_def ctx = function string (Printf.sprintf "static void CREATE(%s)(struct %s *op)" n n) ^^ hardline ^^ surround 2 0 lbrace - (string (Printf.sprintf "op->kind = Kind_%s;" (sgen_id ctor_id)) ^^ hardline + (string (Printf.sprintf "op->kind = Kind_%s;" (sgen_uid ctor_id)) ^^ hardline ^^ if not (is_stack_ctyp ctyp) then - string (Printf.sprintf "CREATE(%s)(&op->%s);" (sgen_ctyp_name ctyp) (sgen_id ctor_id)) + string (Printf.sprintf "CREATE(%s)(&op->%s);" (sgen_ctyp_name ctyp) (sgen_uid ctor_id)) else empty) rbrace in @@ -1657,7 +1664,7 @@ let codegen_type_def ctx = function if is_stack_ctyp ctyp then string (Printf.sprintf "/* do nothing */") else - string (Printf.sprintf "KILL(%s)(&%s->%s);" (sgen_ctyp_name ctyp) v (sgen_id ctor_id)) + string (Printf.sprintf "KILL(%s)(&%s->%s);" (sgen_ctyp_name ctyp) v (sgen_uid ctor_id)) in let codegen_clear = let n = sgen_id id in @@ -1676,16 +1683,16 @@ let codegen_type_def ctx = function in Printf.sprintf "%s op" (sgen_ctyp ctyp), empty, empty in - string (Printf.sprintf "static void %s(%sstruct %s *rop, %s)" (sgen_function_id ctor_id) (extra_params ()) (sgen_id id) ctor_args) ^^ hardline + string (Printf.sprintf "static void %s(%sstruct %s *rop, %s)" (sgen_function_uid ctor_id) (extra_params ()) (sgen_id id) ctor_args) ^^ hardline ^^ surround 2 0 lbrace (tuple ^^ each_ctor "rop->" (clear_field "rop") tus ^^ hardline - ^^ string ("rop->kind = Kind_" ^ sgen_id ctor_id) ^^ semi ^^ hardline + ^^ string ("rop->kind = Kind_" ^ sgen_uid ctor_id) ^^ semi ^^ hardline ^^ if is_stack_ctyp ctyp then - string (Printf.sprintf "rop->%s = op;" (sgen_id ctor_id)) + string (Printf.sprintf "rop->%s = op;" (sgen_uid ctor_id)) else - string (Printf.sprintf "CREATE(%s)(&rop->%s);" (sgen_ctyp_name ctyp) (sgen_id ctor_id)) ^^ hardline - ^^ string (Printf.sprintf "COPY(%s)(&rop->%s, op);" (sgen_ctyp_name ctyp) (sgen_id ctor_id)) ^^ hardline + string (Printf.sprintf "CREATE(%s)(&rop->%s);" (sgen_ctyp_name ctyp) (sgen_uid ctor_id)) ^^ hardline + ^^ string (Printf.sprintf "COPY(%s)(&rop->%s, op);" (sgen_ctyp_name ctyp) (sgen_uid ctor_id)) ^^ hardline ^^ tuple_cleanup) rbrace in @@ -1693,10 +1700,10 @@ let codegen_type_def ctx = function let n = sgen_id id in let set_field ctor_id ctyp = if is_stack_ctyp ctyp then - string (Printf.sprintf "rop->%s = op.%s;" (sgen_id ctor_id) (sgen_id ctor_id)) + string (Printf.sprintf "rop->%s = op.%s;" (sgen_uid ctor_id) (sgen_uid ctor_id)) else - string (Printf.sprintf "CREATE(%s)(&rop->%s);" (sgen_ctyp_name ctyp) (sgen_id ctor_id)) - ^^ string (Printf.sprintf " COPY(%s)(&rop->%s, op.%s);" (sgen_ctyp_name ctyp) (sgen_id ctor_id) (sgen_id ctor_id)) + string (Printf.sprintf "CREATE(%s)(&rop->%s);" (sgen_ctyp_name ctyp) (sgen_uid ctor_id)) + ^^ string (Printf.sprintf " COPY(%s)(&rop->%s, op.%s);" (sgen_ctyp_name ctyp) (sgen_uid ctor_id) (sgen_uid ctor_id)) in string (Printf.sprintf "static void COPY(%s)(struct %s *rop, struct %s op)" n n n) ^^ hardline ^^ surround 2 0 lbrace @@ -1709,12 +1716,12 @@ let codegen_type_def ctx = function in let codegen_eq = let codegen_eq_test ctor_id ctyp = - string (Printf.sprintf "return EQUAL(%s)(op1.%s, op2.%s);" (sgen_ctyp_name ctyp) (sgen_id ctor_id) (sgen_id ctor_id)) + string (Printf.sprintf "return EQUAL(%s)(op1.%s, op2.%s);" (sgen_ctyp_name ctyp) (sgen_uid ctor_id) (sgen_uid ctor_id)) in let rec codegen_eq_tests = function | [] -> string "return false;" | (ctor_id, ctyp) :: ctors -> - string (Printf.sprintf "if (op1.kind == Kind_%s && op2.kind == Kind_%s) " (sgen_id ctor_id) (sgen_id ctor_id)) ^^ lbrace ^^ hardline + string (Printf.sprintf "if (op1.kind == Kind_%s && op2.kind == Kind_%s) " (sgen_uid ctor_id) (sgen_uid ctor_id)) ^^ lbrace ^^ hardline ^^ jump 0 2 (codegen_eq_test ctor_id ctyp) ^^ hardline ^^ rbrace ^^ string " else " ^^ codegen_eq_tests ctors in @@ -1726,7 +1733,7 @@ let codegen_type_def ctx = function ^^ string "enum" ^^ space ^^ string ("kind_" ^ sgen_id id) ^^ space ^^ separate space [ lbrace; - separate_map (comma ^^ space) (fun id -> string ("Kind_" ^ sgen_id id)) (List.map fst tus); + separate_map (comma ^^ space) (fun id -> string ("Kind_" ^ sgen_uid id)) (List.map fst tus); rbrace ^^ semi ] ^^ twice hardline ^^ string "struct" ^^ space ^^ codegen_id id ^^ space @@ -1784,12 +1791,12 @@ let codegen_tup ctx ctyps = empty else begin - let _, fields = List.fold_left (fun (n, fields) ctyp -> n + 1, Bindings.add (mk_id ("tup" ^ string_of_int n)) ctyp fields) - (0, Bindings.empty) + let _, fields = List.fold_left (fun (n, fields) ctyp -> n + 1, UBindings.add (mk_id ("tup" ^ string_of_int n), []) ctyp fields) + (0, UBindings.empty) ctyps in generated := IdSet.add id !generated; - codegen_type_def ctx (CTD_struct (id, Bindings.bindings fields)) ^^ twice hardline + codegen_type_def ctx (CTD_struct (id, UBindings.bindings fields)) ^^ twice hardline end let codegen_node id ctyp = @@ -1997,15 +2004,12 @@ let codegen_def' ctx = function string (Printf.sprintf "%svoid %s(%s%s *rop, %s);" static (sgen_function_id id) (extra_params ()) (sgen_ctyp ret_ctyp) (Util.string_of_list ", " sgen_ctyp arg_ctyps)) | CDEF_fundef (id, ret_arg, args, instrs) as def -> - (* Extract type information about the function from the environment. *) - let quant, Typ_aux (fn_typ, _) = Env.get_val_spec id ctx.tc_env in - let arg_typs, ret_typ = match fn_typ with - | Typ_fn (arg_typs, ret_typ, _) -> arg_typs, ret_typ - | _ -> assert false + let arg_ctyps, ret_ctyp = match Bindings.find_opt id ctx.valspecs with + | Some vs -> vs + | None -> + c_error ~loc:(id_loc id) ("No valspec found for " ^ string_of_id id) in - let ctx' = { ctx with local_env = add_typquant (id_loc id) quant ctx.local_env } in - let arg_ctyps, ret_ctyp = List.map (ctyp_of_typ ctx') arg_typs, ctyp_of_typ ctx' ret_typ in - + (* Check that the function has the correct arity at this point. *) if List.length arg_ctyps <> List.length args then c_error ~loc:(id_loc id) ("function arguments " @@ -2175,7 +2179,6 @@ let jib_of_ast env ast = let compile_ast env output_chan c_includes ast = try - c_debug (lazy (Util.log_line __MODULE__ __LINE__ "Identifying recursive functions")); let recursive_functions = Spec_analysis.top_sort_defs ast |> get_recursive_functions in let cdefs, ctx = jib_of_ast env ast in diff --git a/src/jib/jib_compile.ml b/src/jib/jib_compile.ml index 1a0411aa..b178f0e2 100644 --- a/src/jib/jib_compile.ml +++ b/src/jib/jib_compile.ml @@ -57,7 +57,6 @@ open Value2 open Anf -let opt_debug_flow_graphs = ref false let opt_memo_cache = ref false let optimize_aarch64_fast_struct = ref false @@ -138,7 +137,7 @@ let iblock1 = function | [instr] -> instr | instrs -> iblock instrs -let ctor_bindings = List.fold_left (fun map (id, ctyp) -> Bindings.add id ctyp map) Bindings.empty +let ctor_bindings = List.fold_left (fun map (uid, ctyp) -> UBindings.add uid ctyp map) UBindings.empty (** The context type contains two type-checking environments. ctx.local_env contains the closest typechecking @@ -148,9 +147,10 @@ let ctor_bindings = List.fold_left (fun map (id, ctyp) -> Bindings.add id ctyp m in ctx.locals, so we know when their type changes due to flow typing. *) type ctx = - { records : (ctyp Bindings.t) Bindings.t; + { records : (ctyp UBindings.t) Bindings.t; enums : IdSet.t Bindings.t; - variants : (ctyp Bindings.t) Bindings.t; + variants : (ctyp UBindings.t) Bindings.t; + valspecs : (ctyp list * ctyp) Bindings.t; tc_env : Env.t; local_env : Env.t; locals : (mut * ctyp) Bindings.t; @@ -168,6 +168,7 @@ let initial_ctx ~convert_typ:convert_typ ~optimize_anf:optimize_anf env = { records = Bindings.empty; enums = Bindings.empty; variants = Bindings.empty; + valspecs = Bindings.empty; tc_env = env; local_env = env; locals = Bindings.empty; @@ -269,7 +270,7 @@ let rec compile_aval l ctx = function let compile_fields (id, aval) = let field_setup, cval, field_cleanup = compile_aval l ctx aval in field_setup, - (id, cval), + ((id, []), cval), field_cleanup in let field_triples = List.map compile_fields (Bindings.bindings fields) in @@ -286,7 +287,7 @@ let rec compile_aval l ctx = function let compile_fields (id, aval) = let field_setup, cval, field_cleanup = compile_aval l ctx aval in field_setup - @ [icopy l (CL_field (CL_id (gs, ctyp), string_of_id id)) cval] + @ [icopy l (CL_field (CL_id (gs, ctyp), (id, []))) cval] @ field_cleanup in [idecl ctyp gs] @@ -302,7 +303,7 @@ let rec compile_aval l ctx = function | _ -> let gs = ngensym () in [idecl vector_ctyp gs; - iextern (CL_id (gs, vector_ctyp)) (mk_id "internal_vector_init") [V_lit (VL_int Big_int.zero, CT_fint 64)]], + iextern (CL_id (gs, vector_ctyp)) (mk_id "internal_vector_init", []) [V_lit (VL_int Big_int.zero, CT_fint 64)]], V_id (gs, vector_ctyp), [iclear vector_ctyp gs] end @@ -333,7 +334,7 @@ let rec compile_aval l ctx = function let gs = ngensym () in [iinit (CT_lbits true) gs (V_lit (first_chunk, CT_fbits (len mod 64, true)))] @ List.map (fun chunk -> ifuncall (CL_id (gs, CT_lbits true)) - (mk_id "append_64") + (mk_id "append_64", []) [V_id (gs, CT_lbits true); V_lit (chunk, CT_fbits (64, true))]) chunks, V_id (gs, CT_lbits true), [iclear (CT_lbits true) gs] @@ -381,12 +382,12 @@ let rec compile_aval l ctx = function let setup, cval, cleanup = compile_aval l ctx aval in setup @ [iextern (CL_id (gs, vector_ctyp)) - (mk_id "internal_vector_update") + (mk_id "internal_vector_update", []) [V_id (gs, vector_ctyp); V_lit (VL_int (Big_int.of_int i), CT_fint 64); cval]] @ cleanup in [idecl vector_ctyp gs; - iextern (CL_id (gs, vector_ctyp)) (mk_id "internal_vector_init") [V_lit (VL_int (Big_int.of_int len), CT_fint 64)]] + iextern (CL_id (gs, vector_ctyp)) (mk_id "internal_vector_init", []) [V_lit (VL_int (Big_int.of_int len), CT_fint 64)]] @ List.concat (List.mapi aval_set (if direction then List.rev avals else avals)), V_id (gs, vector_ctyp), [iclear vector_ctyp gs] @@ -402,7 +403,7 @@ let rec compile_aval l ctx = function let gs = ngensym () in let mk_cons aval = let setup, cval, cleanup = compile_aval l ctx aval in - setup @ [ifuncall (CL_id (gs, CT_list ctyp)) (mk_id ("cons#" ^ string_of_ctyp ctyp)) [cval; V_id (gs, CT_list ctyp)]] @ cleanup + setup @ [ifuncall (CL_id (gs, CT_list ctyp)) (mk_id ("cons#" ^ string_of_ctyp ctyp), []) [cval; V_id (gs, CT_list ctyp)]] @ cleanup in [idecl (CT_list ctyp) gs] @ List.concat (List.map mk_cons (List.rev avals)), @@ -439,8 +440,8 @@ let optimize_call l ctx clexp id args arg_ctyps ret_ctyp = iclear ret_ctyp gs] @ !cleanup in - if not ctx.specialize_calls && Env.is_extern id ctx.tc_env "c" then - let extern = Env.get_extern id ctx.tc_env "c" in + if not ctx.specialize_calls && Env.is_extern (fst id) ctx.tc_env "c" then + let extern = Env.get_extern (fst id) ctx.tc_env "c" in begin match extern, List.map cval_ctyp args, clexp_ctyp clexp with | "slice", [CT_fbits _; CT_lint; _], CT_fbits (n, _) -> let start = ngensym () in @@ -487,7 +488,7 @@ let compile_funcall l ctx id args = List.rev !setup, begin fun clexp -> - iblock1 (optimize_call l ctx clexp id setup_args arg_ctyps ret_ctyp) + iblock1 (optimize_call l ctx clexp (id, []) setup_args arg_ctyps ret_ctyp) end, !cleanup @@ -547,7 +548,7 @@ let rec compile_match ctx (AP_aux (apat_aux, env, l)) cval case_label = | AP_app (ctor, apat, variant_typ) -> begin match ctyp with | CT_variant (_, ctors) -> - let ctor_ctyp = Bindings.find ctor (ctor_bindings ctors) in + let ctor_ctyp = UBindings.find (ctor, []) (ctor_bindings ctors) in let pat_ctyp = apat_ctyp ctx apat in (* These should really be the same, something has gone wrong if they are not. *) if ctyp_equal ctor_ctyp (ctyp_of_typ ctx variant_typ) then @@ -569,7 +570,7 @@ let rec compile_match ctx (AP_aux (apat_aux, env, l)) cval case_label = raise (Reporting.err_general l (Printf.sprintf "Variant constructor %s : %s matching against non-variant type %s : %s" (string_of_id ctor) (string_of_typ variant_typ) - (Jib_ir.string_of_cval cval) + (string_of_cval cval) (string_of_ctyp ctyp))) end @@ -716,14 +717,14 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) = | AE_record_update (aval, fields, typ) -> let ctyp = ctyp_of_typ ctx typ in let ctors = match ctyp with - | CT_struct (_, ctors) -> List.fold_left (fun m (k, v) -> Bindings.add k v m) Bindings.empty ctors + | CT_struct (_, ctors) -> List.fold_left (fun m (k, v) -> UBindings.add k v m) UBindings.empty ctors | _ -> raise (Reporting.err_general l "Cannot perform record update for non-record type") in let gs = ngensym () in let compile_fields (id, aval) = let field_setup, cval, field_cleanup = compile_aval l ctx aval in field_setup - @ [icopy l (CL_field (CL_id (gs, ctyp), string_of_id id)) cval] + @ [icopy l (CL_field (CL_id (gs, ctyp), (id, []))) cval] @ field_cleanup in let setup, cval, cleanup = compile_aval l ctx aval in @@ -769,7 +770,7 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) = let compile_fields (field_id, aval) = let field_setup, cval, field_cleanup = compile_aval l ctx aval in field_setup - @ [icopy l (CL_field (CL_id (name id, ctyp_of_typ ctx typ), string_of_id field_id)) cval] + @ [icopy l (CL_field (CL_id (name id, ctyp_of_typ ctx typ), (field_id, []))) cval] @ field_cleanup in List.concat (List.map compile_fields (Bindings.bindings fields)), @@ -877,7 +878,7 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) = let setup, cval, cleanup = compile_aval l ctx aval in let ctyp = match cval_ctyp cval with | CT_struct (struct_id, fields) -> - begin match Util.assoc_compare_opt Id.compare id fields with + begin match Util.assoc_compare_opt UId.compare (id, []) fields with | Some ctyp -> ctyp | None -> raise (Reporting.err_unreachable l __POS__ @@ -893,12 +894,8 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) = else [], ctyp in - let field_str = match unifiers with - | [] -> Util.zencode_string (string_of_id id) - | _ -> Util.zencode_string (string_of_id id ^ "_" ^ Util.string_of_list "_" string_of_ctyp unifiers) - in setup, - (fun clexp -> icopy l clexp (V_field (cval, field_str))), + (fun clexp -> icopy l clexp (V_field (cval, (id, unifiers)))), cleanup | AE_for (loop_var, loop_from, loop_to, loop_step, Ord_aux (ord, _), body) -> @@ -976,9 +973,9 @@ let compile_type_def ctx (TD_aux (type_def, (l, _))) = | TD_record (id, typq, ctors, _) -> let record_ctx = { ctx with local_env = add_typquant l typq ctx.local_env } in let ctors = - List.fold_left (fun ctors (typ, id) -> Bindings.add id (fast_int (ctyp_of_typ record_ctx typ)) ctors) Bindings.empty ctors + List.fold_left (fun ctors (typ, id) -> UBindings.add (id, []) (fast_int (ctyp_of_typ record_ctx typ)) ctors) UBindings.empty ctors in - CTD_struct (id, Bindings.bindings ctors), + CTD_struct (id, UBindings.bindings ctors), { ctx with records = Bindings.add id ctors ctx.records } | TD_variant (id, typq, tus, _) -> @@ -987,8 +984,8 @@ let compile_type_def ctx (TD_aux (type_def, (l, _))) = let ctx = { ctx with local_env = add_typquant (id_loc id) typq ctx.local_env } in ctyp_of_typ ctx typ, id in - let ctus = List.fold_left (fun ctus (ctyp, id) -> Bindings.add id ctyp ctus) Bindings.empty (List.map compile_tu tus) in - CTD_variant (id, Bindings.bindings ctus), + let ctus = List.fold_left (fun ctus (ctyp, id) -> UBindings.add (id, []) ctyp ctus) UBindings.empty (List.map compile_tu tus) in + CTD_variant (id, UBindings.bindings ctus), { ctx with variants = Bindings.add id ctus ctx.variants } (* Will be re-written before here, see bitfield.ml *) @@ -1048,7 +1045,7 @@ let fix_exception_block ?return:(return=None) ctx instrs = @ [igoto end_block_label] @ rewrite_exception (historic @ before) after | before, (I_aux (I_funcall (x, _, f, args), _) as funcall) :: after -> - let effects = match Env.get_val_spec f ctx.tc_env with + let effects = match Env.get_val_spec (fst f) ctx.tc_env with | _, Typ_aux (Typ_fn (_, _, effects), _) -> effects | exception (Type_error _) -> no_effect (* nullary union constructor, so no val spec *) | _ -> assert false (* valspec must have function type *) @@ -1242,24 +1239,11 @@ let compile_funcl ctx id pat guard exp = let instrs = fix_early_return (CL_id (return, ret_ctyp)) instrs in let instrs = fix_exception ~return:(Some ret_ctyp) ctx instrs in - if !opt_debug_flow_graphs then - begin - let instrs = Jib_optimize.(instrs |> optimize_unit |> flatten_instrs) in - let root, _, cfg = Jib_ssa.control_flow_graph instrs in - let idom = Jib_ssa.immediate_dominators cfg root in - let _, cfg = Jib_ssa.ssa instrs in - let out_chan = open_out (Util.zencode_string (string_of_id id) ^ ".gv") in - Jib_ssa.make_dot out_chan cfg; - close_out out_chan; - let out_chan = open_out (Util.zencode_string (string_of_id id) ^ ".dom.gv") in - Jib_ssa.make_dominators_dot out_chan idom cfg; - close_out out_chan; - end; - [CDEF_fundef (id, None, List.map fst compiled_args, instrs)], orig_ctx (** Compile a Sail toplevel definition into an IR definition **) let rec compile_def n total ctx def = + reset_anf_counter (); reset_gensym_counter (); match def with | DEF_fundef (FD_aux (FD_function (_, _, _, [FCL_aux (FCL_Funcl (id, _), _)]), _)) @@ -1314,7 +1298,8 @@ and compile_def' n total ctx = function in let ctx' = { ctx with local_env = add_typquant (id_loc id) quant ctx.local_env } in let arg_ctyps, ret_ctyp = List.map (ctyp_of_typ ctx') arg_typs, ctyp_of_typ ctx' ret_typ in - [CDEF_spec (id, arg_ctyps, ret_ctyp)], ctx + [CDEF_spec (id, arg_ctyps, ret_ctyp)], + { ctx with valspecs = Bindings.add id (arg_ctyps, ret_ctyp) ctx.valspecs } | DEF_fundef (FD_aux (FD_function (_, _, _, [FCL_aux (FCL_Funcl (id, Pat_aux (Pat_exp (pat, exp), _)), _)]), _)) -> Util.progress "Compiling " (string_of_id id) n total; @@ -1389,7 +1374,7 @@ and compile_def' n total ctx = function raise (Reporting.err_general Parse_ast.Unknown ("Could not compile:\n" ^ Pretty_print_sail.to_string (Pretty_print_sail.doc_def def))) let rec specialize_variants ctx prior = - let unifications = ref (Bindings.empty) in + let unifications = ref (UBindings.empty) in let fix_variant_ctyp var_id new_ctors = function | CT_variant (id, ctors) when Id.compare id var_id = 0 -> CT_variant (id, new_ctors) @@ -1403,12 +1388,14 @@ let rec specialize_variants ctx prior = (* specialize_constructor is called on all instructions when we find a constructor in a union type that is polymorphic. It's job is to record all uses of that constructor so we can monomorphise it. *) - let specialize_constructor ctx ctor_id ctyp = function - | I_aux (I_funcall (clexp, extern, id, [cval]), ((_, l) as aux)) as instr when Id.compare id ctor_id = 0 -> + let specialize_constructor ctx (ctor_id, existing_unifiers) ctyp = + assert (existing_unifiers = []); + function + | I_aux (I_funcall (clexp, extern, id, [cval]), ((_, l) as aux)) as instr when UId.compare id (ctor_id, []) = 0 -> (* Work out how each call to a constructor in instantiated and add that to unifications *) let unifiers = List.map ctyp_suprema (ctyp_unify ctyp (cval_ctyp cval)) in - let mono_id = append_id ctor_id ("_" ^ Util.string_of_list "_" (fun ctyp -> string_of_ctyp ctyp) unifiers) in - unifications := Bindings.add mono_id (ctyp_suprema (cval_ctyp cval)) !unifications; + let mono_id = (ctor_id, unifiers) in + unifications := UBindings.add mono_id (ctyp_suprema (cval_ctyp cval)) !unifications; (* We need to cast each cval to it's ctyp_suprema in order to put it in the most general constructor *) let setup, cval, cleanup = @@ -1432,13 +1419,12 @@ let rec specialize_variants ctx prior = mk_funcall (I_aux (I_funcall (clexp, extern, mono_id, [cval]), aux)) - | I_aux (I_funcall (clexp, extern, id, cvals), ((_, l) as aux)) as instr when Id.compare id ctor_id = 0 -> + | I_aux (I_funcall (clexp, extern, id, cvals), ((_, l) as aux)) as instr when UId.compare id (ctor_id, []) = 0 -> Reporting.unreachable l __POS__ "Multiple argument constructor found" - (* We have to be careful this is the only place where an F_ctor_kind can appear before calling specialize variants *) + (* We have to be careful this is the only place where an V_ctor_kind can appear before calling specialize variants *) | I_aux (I_jump (V_ctor_kind (_, id, unifiers, pat_ctyp), _), _) as instr when Id.compare id ctor_id = 0 -> - let mono_id = append_id ctor_id ("_" ^ Util.string_of_list "_" (fun ctyp -> string_of_ctyp ctyp) unifiers) in - unifications := Bindings.add mono_id (ctyp_suprema pat_ctyp) !unifications; + unifications := UBindings.add (ctor_id, unifiers) (ctyp_suprema pat_ctyp) !unifications; instr | instr -> instr @@ -1446,12 +1432,14 @@ let rec specialize_variants ctx prior = (* specialize_field performs the same job as specialize_constructor, but for struct fields rather than union constructors. *) - let specialize_field ctx field_id ctyp = function - | I_aux (I_copy (CL_field (clexp, field_str), cval), aux) when string_of_id field_id = field_str -> + let specialize_field ctx (field_id, existing_unifiers) ctyp = + assert (existing_unifiers = []); + function + | I_aux (I_copy (CL_field (clexp, field), cval), aux) when UId.compare (field_id, []) field = 0 -> let unifiers = List.map ctyp_suprema (ctyp_unify ctyp (cval_ctyp cval)) in - let mono_id = append_id field_id ("_" ^ Util.string_of_list "_" (fun ctyp -> string_of_ctyp ctyp) unifiers) in - unifications := Bindings.add mono_id (ctyp_suprema (cval_ctyp cval)) !unifications; - I_aux (I_copy (CL_field (clexp, string_of_id mono_id), cval), aux) + let mono_id = (field_id, unifiers) in + unifications := UBindings.add mono_id (ctyp_suprema (cval_ctyp cval)) !unifications; + I_aux (I_copy (CL_field (clexp, mono_id), cval), aux) | instr -> instr in @@ -1467,12 +1455,12 @@ let rec specialize_variants ctx prior = in let monomorphic_ctors = List.filter (fun (_, ctyp) -> not (is_polymorphic ctyp)) ctors in - let specialized_ctors = Bindings.bindings !unifications in + let specialized_ctors = UBindings.bindings !unifications in let new_ctors = monomorphic_ctors @ specialized_ctors in let ctx = { ctx with variants = Bindings.add var_id - (List.fold_left (fun m (id, ctyp) -> Bindings.add id ctyp m) !unifications monomorphic_ctors) + (List.fold_left (fun m (uid, ctyp) -> UBindings.add uid ctyp m) !unifications monomorphic_ctors) ctx.variants } in @@ -1491,12 +1479,12 @@ let rec specialize_variants ctx prior = in let monomorphic_fields = List.filter (fun (_, ctyp) -> not (is_polymorphic ctyp)) fields in - let specialized_fields = Bindings.bindings !unifications in + let specialized_fields = UBindings.bindings !unifications in let new_fields = monomorphic_fields @ specialized_fields in let ctx = { ctx with records = Bindings.add struct_id - (List.fold_left (fun m (id, ctyp) -> Bindings.add id ctyp m) !unifications monomorphic_fields) + (List.fold_left (fun m (uid, ctyp) -> UBindings.add uid ctyp m) !unifications monomorphic_fields) ctx.records } in diff --git a/src/jib/jib_compile.mli b/src/jib/jib_compile.mli index d4e67daa..273e9e03 100644 --- a/src/jib/jib_compile.mli +++ b/src/jib/jib_compile.mli @@ -56,10 +56,6 @@ open Ast_util open Jib open Type_check -(** Output a dataflow graph for each generated function in Graphviz - (dot) format. *) -val opt_debug_flow_graphs : bool ref - (** This forces all integer struct fields to be represented as int64_t. Specifically intended for the various TLB structs in the ARM v8.5 spec. *) @@ -73,9 +69,10 @@ val optimize_aarch64_fast_struct : bool ref well as a function that optimizes ANF expressions (which can just be the identity function) *) type ctx = - { records : (ctyp Bindings.t) Bindings.t; + { records : (ctyp Jib_util.UBindings.t) Bindings.t; enums : IdSet.t Bindings.t; - variants : (ctyp Bindings.t) Bindings.t; + variants : (ctyp Jib_util.UBindings.t) Bindings.t; + valspecs : (ctyp list * ctyp) Bindings.t; tc_env : Env.t; local_env : Env.t; locals : (mut * ctyp) Bindings.t; diff --git a/src/jib/jib_ir.ml b/src/jib/jib_ir.ml index 1449b070..df60db1c 100644 --- a/src/jib/jib_ir.ml +++ b/src/jib/jib_ir.ml @@ -70,49 +70,9 @@ let string_of_name = | Current_exception n -> "current_exception" ^ ssa_num n -let string_of_value = function - | VL_bits ([], _) -> "empty" - | VL_bits (bs, true) -> Sail2_values.show_bitlist bs - | VL_bits (bs, false) -> Sail2_values.show_bitlist (List.rev bs) - | VL_int i -> Big_int.to_string i - | VL_bool true -> "true" - | VL_bool false -> "false" - | VL_null -> "NULL" - | VL_unit -> "()" - | VL_bit Sail2_values.B0 -> "bitzero" - | VL_bit Sail2_values.B1 -> "bitone" - | VL_bit Sail2_values.BU -> "bitundef" - | VL_real str -> str - | VL_string str -> "\"" ^ str ^ "\"" - -let rec string_of_cval = function - | V_id (id, ctyp) -> string_of_name id - | V_ref (id, _) -> "&" ^ string_of_name id - | V_lit (vl, ctyp) -> string_of_value vl - | V_call (op, cvals) -> - Printf.sprintf "%s(%s)" (string_of_op op) (Util.string_of_list ", " string_of_cval cvals) - | V_field (f, field) -> - Printf.sprintf "%s.%s" (string_of_cval f) field - | V_tuple_member (f, _, n) -> - Printf.sprintf "%s.ztup%d" (string_of_cval f) n - | V_ctor_kind (f, ctor, [], _) -> - string_of_cval f ^ " is " ^ Util.zencode_string (string_of_id ctor) - | V_ctor_kind (f, ctor, unifiers, _) -> - string_of_cval f ^ " is " ^ Util.zencode_string (string_of_id ctor ^ "_" ^ Util.string_of_list "_" string_of_ctyp unifiers) - | V_ctor_unwrap (ctor, f, [], _) -> - Printf.sprintf "%s as %s" (string_of_cval f) (string_of_id ctor) - | V_ctor_unwrap (ctor, f, unifiers, _) -> - Printf.sprintf "%s as %s" - (string_of_cval f) - (Util.zencode_string (string_of_id ctor ^ "_" ^ Util.string_of_list "_" string_of_ctyp unifiers)) - | V_struct (fields, _) -> - Printf.sprintf "{%s}" - (Util.string_of_list ", " (fun (field, cval) -> zencode_id field ^ " = " ^ string_of_cval cval) fields) - | V_poly (f, _) -> string_of_cval f - let rec string_of_clexp = function | CL_id (id, ctyp) -> string_of_name id - | CL_field (clexp, field) -> string_of_clexp clexp ^ "." ^ field + | CL_field (clexp, field) -> string_of_clexp clexp ^ "." ^ string_of_uid field | CL_addr clexp -> string_of_clexp clexp ^ "*" | CL_tuple (clexp, n) -> string_of_clexp clexp ^ "." ^ string_of_int n | CL_void -> "void" @@ -159,9 +119,9 @@ module Ir_formatter = struct | I_copy (clexp, cval) -> add_instr n buf indent (string_of_clexp clexp ^ " = " ^ C.value cval) | I_funcall (clexp, false, id, args) -> - add_instr n buf indent (string_of_clexp clexp ^ " = " ^ zencode_id id ^ "(" ^ Util.string_of_list ", " C.value args ^ ")") + add_instr n buf indent (string_of_clexp clexp ^ " = " ^ string_of_uid id ^ "(" ^ Util.string_of_list ", " C.value args ^ ")") | I_funcall (clexp, true, id, args) -> - add_instr n buf indent (string_of_clexp clexp ^ " = $" ^ zencode_id id ^ "(" ^ Util.string_of_list ", " C.value args ^ ")") + add_instr n buf indent (string_of_clexp clexp ^ " = $" ^ string_of_uid id ^ "(" ^ Util.string_of_list ", " C.value args ^ ")") | I_return cval -> add_instr n buf indent (C.keyword "return" ^ " " ^ C.value cval) | I_comment str -> @@ -185,6 +145,9 @@ module Ir_formatter = struct let id_ctyp (id, ctyp) = sprintf "%s: %s" (zencode_id id) (C.typ ctyp) + let uid_ctyp (uid, ctyp) = + sprintf "%s: %s" (string_of_uid uid) (C.typ ctyp) + let output_def buf = function | CDEF_reg_dec (id, ctyp, _) -> Buffer.add_string buf (sprintf "%s %s : %s" (C.keyword "register") (zencode_id id) (C.typ ctyp)) @@ -203,9 +166,9 @@ module Ir_formatter = struct | CDEF_type (CTD_enum (id, ids)) -> Buffer.add_string buf (sprintf "%s %s {\n %s\n}" (C.keyword "enum") (zencode_id id) (Util.string_of_list ",\n " zencode_id ids)) | CDEF_type (CTD_struct (id, ids)) -> - Buffer.add_string buf (sprintf "%s %s {\n %s\n}" (C.keyword "struct") (zencode_id id) (Util.string_of_list ",\n " id_ctyp ids)) + Buffer.add_string buf (sprintf "%s %s {\n %s\n}" (C.keyword "struct") (zencode_id id) (Util.string_of_list ",\n " uid_ctyp ids)) | CDEF_type (CTD_variant (id, ids)) -> - Buffer.add_string buf (sprintf "%s %s {\n %s\n}" (C.keyword "union") (zencode_id id) (Util.string_of_list ",\n " id_ctyp ids)) + Buffer.add_string buf (sprintf "%s %s {\n %s\n}" (C.keyword "union") (zencode_id id) (Util.string_of_list ",\n " uid_ctyp ids)) | CDEF_let (_, bindings, instrs) -> let instrs = C.modify_instrs instrs in let label_map = C.make_label_map instrs in @@ -220,8 +183,7 @@ module Ir_formatter = struct output_def buf def; Buffer.add_string buf "\n\n"; output_defs buf defs - | [] -> - Buffer.add_char buf '\n' + | [] -> () end end @@ -249,6 +211,7 @@ module Flat_ir_config : Ir_formatter.Config = struct let modify_instrs instrs = let open Jib_optimize in + reset_flat_counter (); instrs |> flatten_instrs |> remove_clear diff --git a/src/jib/jib_optimize.ml b/src/jib/jib_optimize.ml index e7cb70da..323f3cd0 100644 --- a/src/jib/jib_optimize.ml +++ b/src/jib/jib_optimize.ml @@ -50,6 +50,7 @@ open Ast_util open Jib +open Jib_compile open Jib_util let optimize_unit instrs = @@ -81,6 +82,9 @@ let optimize_unit instrs = filter_instrs non_pointless_copy (map_instr_list unit_instr instrs) let flat_counter = ref 0 + +let reset_flat_counter () = flat_counter := 0 + let flat_id orig_id = let id = mk_id (string_of_name ~zencode:false orig_id ^ "_l#" ^ string_of_int !flat_counter) in incr flat_counter; @@ -295,8 +299,8 @@ let inline cdefs should_inline instrs = in let rec inline_instr = function - | I_aux (I_funcall (clexp, false, function_id, args), aux) as instr when should_inline function_id -> - begin match find_function function_id cdefs with + | I_aux (I_funcall (clexp, false, function_id, args), aux) as instr when should_inline (fst function_id) -> + begin match find_function (fst function_id) cdefs with | Some (None, ids, body) -> incr inlines; incr label_count; @@ -386,3 +390,155 @@ let rec remove_clear = function | I_aux (I_clear _, _) :: instrs -> remove_clear instrs | instr :: instrs -> instr :: remove_clear instrs | [] -> [] + +let remove_tuples cdefs ctx = + let already_removed = ref CTSet.empty in + let rec all_tuples = function + | CT_tup ctyps as ctyp -> + CTSet.add ctyp (List.fold_left CTSet.union CTSet.empty (List.map all_tuples ctyps)) + | CT_struct (_, id_ctyps) | CT_variant (_, id_ctyps) -> + List.fold_left (fun cts (_, ctyp) -> CTSet.union (all_tuples ctyp) cts) CTSet.empty id_ctyps + | CT_list ctyp | CT_vector (_, ctyp) | CT_ref ctyp -> + all_tuples ctyp + | CT_lint | CT_fint _ | CT_lbits _ | CT_sbits _ | CT_fbits _ | CT_constant _ + | CT_unit | CT_bool | CT_real | CT_bit | CT_poly | CT_string | CT_enum _ -> + CTSet.empty + in + let rec tuple_depth = function + | CT_tup ctyps as ctyp -> + 1 + List.fold_left (fun d ctyp -> max d (tuple_depth ctyp)) 0 ctyps + | CT_struct (_, id_ctyps) | CT_variant (_, id_ctyps) -> + List.fold_left (fun d (_, ctyp) -> max (tuple_depth ctyp) d) 0 id_ctyps + | CT_list ctyp | CT_vector (_, ctyp) | CT_ref ctyp -> + tuple_depth ctyp + | CT_lint | CT_fint _ | CT_lbits _ | CT_sbits _ | CT_fbits _ | CT_constant _ + | CT_unit | CT_bool | CT_real | CT_bit | CT_poly | CT_string | CT_enum _ -> + 0 + in + let rec fix_tuples = function + | CT_tup ctyps -> + let ctyps = List.map fix_tuples ctyps in + let name = "tuple#" ^ Util.string_of_list "_" string_of_ctyp ctyps in + CT_struct (mk_id name, List.mapi (fun n ctyp -> (mk_id (name ^ string_of_int n), []), ctyp) ctyps) + | CT_struct (id, id_ctyps) -> + CT_struct (id, List.map (fun (id, ctyp) -> id, fix_tuples ctyp) id_ctyps) + | CT_variant (id, id_ctyps) -> + CT_variant (id, List.map (fun (id, ctyp) -> id, fix_tuples ctyp) id_ctyps) + | CT_list ctyp -> CT_list (fix_tuples ctyp) + | CT_vector (d, ctyp) -> CT_vector (d, fix_tuples ctyp) + | CT_ref ctyp -> CT_ref (fix_tuples ctyp) + | (CT_lint | CT_fint _ | CT_lbits _ | CT_sbits _ | CT_fbits _ | CT_constant _ + | CT_unit | CT_bool | CT_real | CT_bit | CT_poly | CT_string | CT_enum _) as ctyp -> + ctyp + in + let rec fix_cval = function + | V_id (id, ctyp) -> V_id (id, ctyp) + | V_ref (id, ctyp) -> V_ref (id, ctyp) + | V_lit (vl, ctyp) -> V_lit (vl, ctyp) + | V_ctor_kind (cval, id, unifiers, ctyp) -> + V_ctor_kind (fix_cval cval, id, unifiers, ctyp) + | V_ctor_unwrap (id, cval, unifiers, ctyp) -> + V_ctor_unwrap (id, fix_cval cval, unifiers, ctyp) + | V_tuple_member (cval, _, n) -> + let ctyp = fix_tuples (cval_ctyp cval) in + let cval = fix_cval cval in + let field = match ctyp with + | CT_struct (id, _) -> + mk_id (string_of_id id ^ string_of_int n) + | _ -> assert false + in + V_field (cval, (field, [])) + | V_call (op, cvals) -> + V_call (op, List.map (fix_cval) cvals) + | V_field (cval, field) -> + V_field (fix_cval cval, field) + | V_struct (fields, ctyp) -> V_struct (List.map (fun (id, cval) -> id, fix_cval cval) fields, ctyp) + | V_poly (cval, ctyp) -> V_poly (fix_cval cval, ctyp) + in + let rec fix_clexp = function + | CL_id (id, ctyp) -> CL_id (id, ctyp) + | CL_addr clexp -> CL_addr (fix_clexp clexp) + | CL_tuple (clexp, n) -> + let ctyp = fix_tuples (clexp_ctyp clexp) in + let clexp = fix_clexp clexp in + let field = match ctyp with + | CT_struct (id, _) -> + mk_id (string_of_id id ^ string_of_int n) + | _ -> assert false + in + CL_field (clexp, (field, [])) + | CL_field (clexp, field) -> CL_field (fix_clexp clexp, field) + | CL_void -> CL_void + | CL_rmw (read, write, ctyp) -> CL_rmw (read, write, ctyp) + in + let rec fix_instr_aux = function + | I_funcall (clexp, extern, id, args) -> + I_funcall (fix_clexp clexp, extern, id, List.map fix_cval args) + | I_copy (clexp, cval) -> I_copy (fix_clexp clexp, fix_cval cval) + | I_init (ctyp, id, cval) -> I_init (ctyp, id, fix_cval cval) + | I_reinit (ctyp, id, cval) -> I_reinit (ctyp, id, fix_cval cval) + | I_jump (cval, label) -> I_jump (fix_cval cval, label) + | I_throw cval -> I_throw (fix_cval cval) + | I_return cval -> I_return (fix_cval cval) + | I_if (cval, then_instrs, else_instrs, ctyp) -> + I_if (fix_cval cval, List.map fix_instr then_instrs, List.map fix_instr else_instrs, ctyp) + | I_block instrs -> I_block (List.map fix_instr instrs) + | I_try_block instrs -> I_try_block (List.map fix_instr instrs) + | (I_goto _ | I_label _ | I_decl _ | I_clear _ | I_end _ | I_comment _ + | I_reset _ | I_undefined _ | I_match_failure | I_raw _) as instr -> instr + and fix_instr (I_aux (instr, aux)) = I_aux (fix_instr_aux instr, aux) + in + let fix_conversions = function + | I_aux (I_copy (clexp, cval), ((_, l) as aux)) as instr -> + begin match clexp_ctyp clexp, cval_ctyp cval with + | CT_tup lhs_ctyps, CT_tup rhs_ctyps when List.length lhs_ctyps = List.length rhs_ctyps -> + let elems = List.length lhs_ctyps in + if List.for_all2 ctyp_equal lhs_ctyps rhs_ctyps then + [instr] + else + List.mapi (fun n _ -> icopy l (CL_tuple (clexp, n)) (V_tuple_member (cval, elems, n))) lhs_ctyps + | _ -> [instr] + end + | instr -> [instr] + in + let fix_ctx ctx = + { ctx with + records = Bindings.map (UBindings.map fix_tuples) ctx.records; + variants = Bindings.map (UBindings.map fix_tuples) ctx.variants; + valspecs = Bindings.map (fun (ctyps, ctyp) -> List.map fix_tuples ctyps, fix_tuples ctyp) ctx.valspecs; + locals = Bindings.map (fun (mut, ctyp) -> mut, fix_tuples ctyp) ctx.locals + } + in + let to_struct = function + | CT_tup ctyps -> + let ctyps = List.map fix_tuples ctyps in + let name = "tuple#" ^ Util.string_of_list "_" string_of_ctyp ctyps in + CDEF_type (CTD_struct (mk_id name, List.mapi (fun n ctyp -> (mk_id (name ^ string_of_int n), []), ctyp) ctyps)) + | _ -> assert false + in + let rec go acc = function + | cdef :: cdefs -> + let tuples = CTSet.fold (fun ctyp -> CTSet.union (all_tuples ctyp)) (cdef_ctyps cdef) CTSet.empty in + let tuples = CTSet.diff tuples !already_removed in + (* In the case where we have ((x, y), z) and (x, y) we need to + generate (x, y) first, so we sort by the depth of nesting in + the tuples (note we build acc in reverse order) *) + let sorted_tuples = + CTSet.elements tuples + |> List.map (fun ctyp -> tuple_depth ctyp, ctyp) + |> List.sort (fun (d1, _) (d2, _) -> compare d2 d1) + |> List.map snd + in + let structs = List.map to_struct sorted_tuples in + already_removed := CTSet.union tuples !already_removed; + let cdef = + cdef + |> cdef_concatmap_instr fix_conversions + |> cdef_map_instr fix_instr + |> cdef_map_ctyp fix_tuples + in + go (cdef :: structs @ acc) cdefs + | [] -> List.rev acc + in + go [] cdefs, + fix_ctx ctx diff --git a/src/jib/jib_optimize.mli b/src/jib/jib_optimize.mli index a69b45b7..7dae53a9 100644 --- a/src/jib/jib_optimize.mli +++ b/src/jib/jib_optimize.mli @@ -60,6 +60,7 @@ val optimize_unit : instr list -> instr list instructions, prodcing a flat list of instructions. *) val flatten_instrs : instr list -> instr list val flatten_cdef : cdef -> cdef +val reset_flat_counter : unit -> unit val unique_per_function_ids : cdef list -> cdef list @@ -75,3 +76,5 @@ val remove_unused_labels : instr list -> instr list val remove_dead_after_goto : instr list -> instr list val remove_dead_code : instr list -> instr list + +val remove_tuples : cdef list -> Jib_compile.ctx -> cdef list * Jib_compile.ctx diff --git a/src/jib/jib_smt.ml b/src/jib/jib_smt.ml index 7a827ece..e9ec8260 100644 --- a/src/jib/jib_smt.ml +++ b/src/jib/jib_smt.ml @@ -62,7 +62,9 @@ module IntMap = Map.Make(struct type t = int let compare = compare end) let zencode_upper_id id = Util.zencode_upper_string (string_of_id id) let zencode_id id = Util.zencode_string (string_of_id id) let zencode_name id = string_of_name ~deref_current_exception:false ~zencode:true id - +let zencode_uid (id, ctyps) = + Util.zencode_string (string_of_id id ^ "#" ^ Util.string_of_list "_" string_of_ctyp ctyps) + let opt_ignore_overflow = ref false let opt_auto = ref false @@ -158,9 +160,9 @@ let rec smt_ctyp ctx = function | CT_enum (id, elems) -> mk_enum (zencode_upper_id id) (List.map zencode_id elems) | CT_struct (id, fields) -> - mk_record (zencode_upper_id id) (List.map (fun (id, ctyp) -> (zencode_id id, smt_ctyp ctx ctyp)) fields) + mk_record (zencode_upper_id id) (List.map (fun (uid, ctyp) -> (zencode_uid uid, smt_ctyp ctx ctyp)) fields) | CT_variant (id, ctors) -> - mk_variant (zencode_upper_id id) (List.map (fun (id, ctyp) -> (zencode_id id, smt_ctyp ctx ctyp)) ctors) + mk_variant (zencode_upper_id id) (List.map (fun (uid, ctyp) -> (zencode_uid uid, smt_ctyp ctx ctyp)) ctors) | CT_tup ctyps -> ctx.tuple_sizes := IntSet.add (List.length ctyps) !(ctx.tuple_sizes); Tuple (List.map (smt_ctyp ctx) ctyps) @@ -293,14 +295,14 @@ let rec smt_cval ctx cval = | V_field (union, field) -> begin match cval_ctyp union with | CT_struct (struct_id, _) -> - Fn (zencode_upper_id struct_id ^ "_" ^ field, [smt_cval ctx union]) + Fn (zencode_upper_id struct_id ^ "_" ^ zencode_uid field, [smt_cval ctx union]) | _ -> failwith "Field for non-struct type" end | V_struct (fields, ctyp) -> begin match ctyp with | CT_struct (struct_id, field_ctyps) -> let set_field (field, cval) = - match Util.assoc_compare_opt Id.compare field field_ctyps with + match Util.assoc_compare_opt UId.compare field field_ctyps with | None -> failwith "Field type not found" | Some ctyp when ctyp_equal (cval_ctyp cval) ctyp -> smt_cval ctx cval @@ -324,7 +326,7 @@ let rec smt_cval ctx cval = end | _ -> assert false end - | cval -> failwith ("Unrecognised cval " ^ Jib_ir.string_of_cval cval) + | cval -> failwith ("Unrecognised cval " ^ string_of_cval cval) let add_event ctx ev smt = let stack = event_stack ctx ev in @@ -1184,12 +1186,12 @@ let smt_ctype_def ctx = function | CTD_struct (id, fields) -> [declare_datatypes (mk_record (zencode_upper_id id) - (List.map (fun (field, ctyp) -> zencode_upper_id id ^ "_" ^ zencode_id field, smt_ctyp ctx ctyp) fields))] + (List.map (fun (field, ctyp) -> zencode_upper_id id ^ "_" ^ zencode_uid field, smt_ctyp ctx ctyp) fields))] | CTD_variant (id, ctors) -> [declare_datatypes (mk_variant (zencode_upper_id id) - (List.map (fun (ctor, ctyp) -> zencode_id ctor, smt_ctyp ctx ctyp) ctors))] + (List.map (fun (ctor, ctyp) -> zencode_uid ctor, smt_ctyp ctx ctyp) ctors))] let rec generate_ctype_defs ctx = function | CDEF_type ctd :: cdefs -> smt_ctype_def ctx ctd :: generate_ctype_defs ctx cdefs @@ -1274,8 +1276,8 @@ let rec ctyp_of_typ ctx typ = | Typ_app (id, [A_aux (A_typ typ, _)]) when string_of_id id = "register" -> CT_ref (ctyp_of_typ ctx typ) - | Typ_id id | Typ_app (id, _) when Bindings.mem id ctx.records -> CT_struct (id, Bindings.find id ctx.records |> Bindings.bindings) - | Typ_id id | Typ_app (id, _) when Bindings.mem id ctx.variants -> CT_variant (id, Bindings.find id ctx.variants |> Bindings.bindings) + | Typ_id id | Typ_app (id, _) when Bindings.mem id ctx.records -> CT_struct (id, Bindings.find id ctx.records |> UBindings.bindings) + | Typ_id id | Typ_app (id, _) when Bindings.mem id ctx.variants -> CT_variant (id, Bindings.find id ctx.variants |> UBindings.bindings) | Typ_id id when Bindings.mem id ctx.enums -> CT_enum (id, Bindings.find id ctx.enums |> IdSet.elements) | Typ_tup typs -> CT_tup (List.map (ctyp_of_typ ctx) typs) @@ -1524,10 +1526,10 @@ let rmw_modify smt = function begin match ctyp with | CT_struct (struct_id, fields) -> let set_field (field', _) = - if Util.zencode_string field = zencode_id field' then + if UId.compare field field' = 0 then smt else - Fn (zencode_upper_id struct_id ^ "_" ^ zencode_id field', [Var (rmw_read clexp)]) + Fn (zencode_upper_id struct_id ^ "_" ^ zencode_uid field', [Var (rmw_read clexp)]) in Fn (zencode_upper_id struct_id, List.map set_field fields) | _ -> @@ -1557,8 +1559,8 @@ let smt_instr ctx = let open Type_check in function | I_aux (I_funcall (CL_id (id, ret_ctyp), extern, function_id, args), (_, l)) -> - if Env.is_extern function_id ctx.tc_env "c" && not extern then - let name = Env.get_extern function_id ctx.tc_env "c" in + if Env.is_extern (fst function_id) ctx.tc_env "c" && not extern then + let name = Env.get_extern (fst function_id) ctx.tc_env "c" in if name = "sqrt_real" then begin match args with | [v] -> builtin_sqrt_real ctx (zencode_name id) v @@ -1609,9 +1611,9 @@ let smt_instr ctx = else let value = smt_builtin ctx name args ret_ctyp in [define_const ctx id ret_ctyp value] - else if extern && string_of_id function_id = "internal_vector_init" then + else if extern && string_of_id (fst function_id) = "internal_vector_init" then [declare_const ctx id ret_ctyp] - else if extern && string_of_id function_id = "internal_vector_update" then + else if extern && string_of_id (fst function_id) = "internal_vector_update" then begin match args with | [vec; i; x] -> let sz = int_size ctx (cval_ctyp i) in @@ -1620,7 +1622,7 @@ let smt_instr ctx = | _ -> Reporting.unreachable l __POS__ "Bad arguments for internal_vector_update" end - else if string_of_id function_id = "sail_assert" then + else if string_of_id (fst function_id) = "sail_assert" then begin match args with | [assertion; _] -> let smt = smt_cval ctx assertion in @@ -1629,7 +1631,7 @@ let smt_instr ctx = | _ -> Reporting.unreachable l __POS__ "Bad arguments for assertion" end - else if string_of_id function_id = "sail_assume" then + else if string_of_id (fst function_id) = "sail_assume" then begin match args with | [assumption] -> let smt = smt_cval ctx assumption in @@ -1640,9 +1642,9 @@ let smt_instr ctx = end else if not extern then let smt_args = List.map (smt_cval ctx) args in - [define_const ctx id ret_ctyp (Ctor (zencode_id function_id, smt_args))] + [define_const ctx id ret_ctyp (Ctor (zencode_id (fst function_id), smt_args))] else - failwith ("Unrecognised function " ^ string_of_id function_id) + failwith ("Unrecognised function " ^ string_of_id (fst function_id)) | I_aux (I_copy (CL_addr (CL_id (_, _)), _), (_, l)) -> Reporting.unreachable l __POS__ "Register reference write should be re-written by now" @@ -1850,7 +1852,7 @@ let smt_header ctx cdefs = let expand_reg_deref env register_map = function | I_aux (I_funcall (clexp, false, function_id, [reg_ref]), (_, l)) as instr -> let open Type_check in - begin match (if Env.is_extern function_id env "smt" then Some (Env.get_extern function_id env "smt") else None) with + begin match (if Env.is_extern (fst function_id) env "smt" then Some (Env.get_extern (fst function_id) env "smt") else None) with | Some "reg_deref" -> begin match cval_ctyp reg_ref with | CT_ref reg_ctyp -> diff --git a/src/jib/jib_smt_fuzz.ml b/src/jib/jib_smt_fuzz.ml index 28ec40b9..846d0178 100644 --- a/src/jib/jib_smt_fuzz.ml +++ b/src/jib/jib_smt_fuzz.ml @@ -187,7 +187,7 @@ let fuzz_cdef ctx all_cdefs = function let jib = let gs = ngensym () in - [ifuncall (CL_id (gs, ret_ctyp)) id (List.map snd values)] + [ifuncall (CL_id (gs, ret_ctyp)) (id, []) (List.map snd values)] @ gen_assertion ret_ctyp result gs @ [iend ()] in diff --git a/src/jib/jib_ssa.ml b/src/jib/jib_ssa.ml index 840dea97..ba4337b0 100644 --- a/src/jib/jib_ssa.ml +++ b/src/jib/jib_ssa.ml @@ -722,7 +722,7 @@ let string_of_ssainstr = function | Phi (id, ctyp, args) -> string_of_name id ^ " : " ^ string_of_ctyp ctyp ^ " = φ(" ^ Util.string_of_list ", " string_of_name args ^ ")" | Pi cvals -> - "π(" ^ Util.string_of_list ", " (fun v -> String.escaped (Jib_ir.string_of_cval v)) cvals ^ ")" + "π(" ^ Util.string_of_list ", " (fun v -> String.escaped (string_of_cval v)) cvals ^ ")" let string_of_phis = function | [] -> "" diff --git a/src/jib/jib_util.ml b/src/jib/jib_util.ml index a1dac297..caef7ecb 100644 --- a/src/jib/jib_util.ml +++ b/src/jib/jib_util.ml @@ -257,20 +257,6 @@ let rec instr_rename from_id to_id (I_aux (instr, aux)) = (* 1. Instruction pretty printer *) (**************************************************************************) -let string_of_value = function - | VL_bits ([], _) -> "UINT64_C(0)" - | VL_bits (bs, true) -> "UINT64_C(" ^ Sail2_values.show_bitlist bs ^ ")" - | VL_bits (bs, false) -> "UINT64_C(" ^ Sail2_values.show_bitlist (List.rev bs) ^ ")" - | VL_int i -> Big_int.to_string i ^ "l" - | VL_bool true -> "true" - | VL_bool false -> "false" - | VL_null -> "NULL" - | VL_unit -> "UNIT" - | VL_bit Sail2_values.B0 -> "UINT64_C(0)" - | VL_bit Sail2_values.B1 -> "UINT64_C(1)" - | VL_bit Sail2_values.BU -> failwith "Undefined bit found in value" - | VL_real str -> str - | VL_string str -> "\"" ^ str ^ "\"" let string_of_name ?deref_current_exception:(dce=true) ?zencode:(zencode=true) = let ssa_num n = if n = -1 then "" else ("/" ^ string_of_int n) in @@ -341,6 +327,11 @@ let rec string_of_ctyp = function | CT_ref ctyp -> "&(" ^ string_of_ctyp ctyp ^ ")" | CT_poly -> "*" +and string_of_uid (id, ctyps) = + match ctyps with + | [] -> Util.zencode_string (string_of_id id) + | _ -> Util.zencode_string (string_of_id id ^ "#" ^ Util.string_of_list "_" string_of_ctyp ctyps) + (** This function is like string_of_ctyp, but recursively prints all constructors in variants and structs. Used for debug output. *) and full_string_of_ctyp = function @@ -348,12 +339,12 @@ and full_string_of_ctyp = function | CT_struct (id, ctors) -> "struct " ^ string_of_id id ^ "{" - ^ Util.string_of_list ", " (fun (id, ctyp) -> string_of_id id ^ " : " ^ full_string_of_ctyp ctyp) ctors + ^ Util.string_of_list ", " (fun ((id, _), ctyp) -> string_of_id id ^ " : " ^ full_string_of_ctyp ctyp) ctors ^ "}" | CT_variant (id, ctors) -> "union " ^ string_of_id id ^ "{" - ^ Util.string_of_list ", " (fun (id, ctyp) -> string_of_id id ^ " : " ^ full_string_of_ctyp ctyp) ctors + ^ Util.string_of_list ", " (fun ((id, _), ctyp) -> string_of_id id ^ " : " ^ full_string_of_ctyp ctyp) ctors ^ "}" | CT_vector (true, ctyp) -> "vector(dec, " ^ full_string_of_ctyp ctyp ^ ")" | CT_vector (false, ctyp) -> "vector(inc, " ^ full_string_of_ctyp ctyp ^ ")" @@ -361,6 +352,46 @@ and full_string_of_ctyp = function | CT_ref ctyp -> "ref(" ^ full_string_of_ctyp ctyp ^ ")" | ctyp -> string_of_ctyp ctyp +let string_of_value = function + | VL_bits ([], _) -> "empty" + | VL_bits (bs, true) -> Sail2_values.show_bitlist bs + | VL_bits (bs, false) -> Sail2_values.show_bitlist (List.rev bs) + | VL_int i -> Big_int.to_string i + | VL_bool true -> "true" + | VL_bool false -> "false" + | VL_null -> "NULL" + | VL_unit -> "()" + | VL_bit Sail2_values.B0 -> "bitzero" + | VL_bit Sail2_values.B1 -> "bitone" + | VL_bit Sail2_values.BU -> "bitundef" + | VL_real str -> str + | VL_string str -> "\"" ^ str ^ "\"" + +let rec string_of_cval = function + | V_id (id, ctyp) -> string_of_name id + | V_ref (id, _) -> "&" ^ string_of_name id + | V_lit (vl, ctyp) -> string_of_value vl + | V_call (op, cvals) -> + Printf.sprintf "%s(%s)" (string_of_op op) (Util.string_of_list ", " string_of_cval cvals) + | V_field (f, field) -> + Printf.sprintf "%s.%s" (string_of_cval f) (string_of_uid field) + | V_tuple_member (f, _, n) -> + Printf.sprintf "%s.ztup%d" (string_of_cval f) n + | V_ctor_kind (f, ctor, [], _) -> + string_of_cval f ^ " is " ^ Util.zencode_string (string_of_id ctor) + | V_ctor_kind (f, ctor, unifiers, _) -> + string_of_cval f ^ " is " ^ Util.zencode_string (string_of_id ctor ^ "_" ^ Util.string_of_list "_" string_of_ctyp unifiers) + | V_ctor_unwrap (ctor, f, [], _) -> + Printf.sprintf "%s as %s" (string_of_cval f) (string_of_id ctor) + | V_ctor_unwrap (ctor, f, unifiers, _) -> + Printf.sprintf "%s as %s" + (string_of_cval f) + (Util.zencode_string (string_of_id ctor ^ "_" ^ Util.string_of_list "_" string_of_ctyp unifiers)) + | V_struct (fields, _) -> + Printf.sprintf "{%s}" + (Util.string_of_list ", " (fun (field, cval) -> string_of_uid field ^ " = " ^ string_of_cval cval) fields) + | V_poly (f, _) -> string_of_cval f + let rec map_ctyp f = function | (CT_lint | CT_fint _ | CT_constant _ | CT_lbits _ | CT_fbits _ | CT_sbits _ | CT_bit | CT_unit | CT_bool | CT_real | CT_string | CT_poly | CT_enum _) as ctyp -> f ctyp @@ -368,8 +399,10 @@ let rec map_ctyp f = function | CT_ref ctyp -> f (CT_ref (map_ctyp f ctyp)) | CT_vector (direction, ctyp) -> f (CT_vector (direction, map_ctyp f ctyp)) | CT_list ctyp -> f (CT_list (map_ctyp f ctyp)) - | CT_struct (id, ctors) -> f (CT_struct (id, List.map (fun (id, ctyp) -> id, map_ctyp f ctyp) ctors)) - | CT_variant (id, ctors) -> f (CT_variant (id, List.map (fun (id, ctyp) -> id, map_ctyp f ctyp) ctors)) + | CT_struct (id, ctors) -> + f (CT_struct (id, List.map (fun ((id, ctyps), ctyp) -> (id, List.map (map_ctyp f) ctyps), map_ctyp f ctyp) ctors)) + | CT_variant (id, ctors) -> + f (CT_variant (id, List.map (fun ((id, ctyps), ctyp) -> (id, List.map (map_ctyp f) ctyps), map_ctyp f ctyp) ctors)) let rec ctyp_equal ctyp1 ctyp2 = match ctyp1, ctyp2 with @@ -469,6 +502,23 @@ end module CTSet = Set.Make(CT) module CTMap = Map.Make(CT) +module UId = struct + type t = uid + let lex_ord c1 c2 = if c1 = 0 then c2 else c1 + let rec compare_ctyps ctyps1 ctyps2 = + match ctyps1, ctyps2 with + | (ctyp1 :: ctyps1), (ctyp2 :: ctyps2) -> + lex_ord (CT.compare ctyp1 ctyp2) (compare_ctyps ctyps1 ctyps2) + | [], [] -> 0 + | [], _ -> 1 + | _, [] -> -1 + let compare (id1, ctyps1) (id2, ctyps2) = + let lex_ord c1 c2 = if c1 = 0 then c2 else c1 in + lex_ord (Id.compare id1 id2) (compare_ctyps ctyps1 ctyps2) +end + +module UBindings = Map.Make(UId) + let rec ctyp_unify ctyp1 ctyp2 = match ctyp1, ctyp2 with | CT_tup ctyps1, CT_tup ctyps2 when List.length ctyps1 = List.length ctyps2 -> @@ -609,7 +659,7 @@ let instr_typed_writes (I_aux (aux, _)) = let rec map_clexp_ctyp f = function | CL_id (id, ctyp) -> CL_id (id, f ctyp) | CL_rmw (read, write, ctyp) -> CL_rmw (read, write, f ctyp) - | CL_field (clexp, field) -> CL_field (map_clexp_ctyp f clexp, field) + | CL_field (clexp, (id, ctyps)) -> CL_field (map_clexp_ctyp f clexp, (id, List.map f ctyps)) | CL_tuple (clexp, n) -> CL_tuple (map_clexp_ctyp f clexp, n) | CL_addr clexp -> CL_addr (map_clexp_ctyp f clexp) | CL_void -> CL_void @@ -626,9 +676,10 @@ let rec map_cval_ctyp f = function V_tuple_member (map_cval_ctyp f cval, i, j) | V_call (op, cvals) -> V_call (op, List.map (map_cval_ctyp f) cvals) - | V_field (cval, field) -> - V_field (map_cval_ctyp f cval, field) - | V_struct (fields, ctyp) -> V_struct (List.map (fun (id, cval) -> id, map_cval_ctyp f cval) fields, f ctyp) + | V_field (cval, (id, ctyps)) -> + V_field (map_cval_ctyp f cval, (id, List.map f ctyps)) + | V_struct (fields, ctyp) -> + V_struct (List.map (fun ((id, ctyps), cval) -> (id, List.map f ctyps), map_cval_ctyp f cval) fields, f ctyp) | V_poly (cval, ctyp) -> V_poly (map_cval_ctyp f cval, f ctyp) let rec map_instr_ctyp f (I_aux (instr, aux)) = @@ -638,8 +689,8 @@ let rec map_instr_ctyp f (I_aux (instr, aux)) = | I_if (cval, then_instrs, else_instrs, ctyp) -> I_if (map_cval_ctyp f cval, List.map (map_instr_ctyp f) then_instrs, List.map (map_instr_ctyp f) else_instrs, f ctyp) | I_jump (cval, label) -> I_jump (map_cval_ctyp f cval, label) - | I_funcall (clexp, extern, id, cvals) -> - I_funcall (map_clexp_ctyp f clexp, extern, id, List.map (map_cval_ctyp f) cvals) + | I_funcall (clexp, extern, (id, ctyps), cvals) -> + I_funcall (map_clexp_ctyp f clexp, extern, (id, List.map f ctyps), List.map (map_cval_ctyp f) cvals) | I_copy (clexp, cval) -> I_copy (map_clexp_ctyp f clexp, map_cval_ctyp f cval) | I_clear (ctyp, id) -> I_clear (f ctyp, id) | I_return cval -> I_return (map_cval_ctyp f cval) @@ -669,6 +720,21 @@ let rec map_instr f (I_aux (instr, aux)) = in f (I_aux (instr, aux)) +(** Map over each instruction within an instruction, bottom-up *) +let rec concatmap_instr f (I_aux (instr, aux)) = + let instr = match instr with + | I_decl _ | I_init _ | I_reset _ | I_reinit _ + | I_funcall _ | I_copy _ | I_clear _ | I_jump _ | I_throw _ | I_return _ + | I_comment _ | I_label _ | I_goto _ | I_raw _ | I_match_failure | I_undefined _ | I_end _ -> instr + | I_if (cval, instrs1, instrs2, ctyp) -> + I_if (cval, List.concat (List.map (concatmap_instr f) instrs1), List.concat (List.map (concatmap_instr f) instrs2), ctyp) + | I_block instrs -> + I_block (List.concat (List.map (concatmap_instr f) instrs)) + | I_try_block instrs -> + I_try_block (List.concat (List.map (concatmap_instr f) instrs)) + in + f (I_aux (instr, aux)) + (** Iterate over each instruction within an instruction, bottom-up *) let rec iter_instr f (I_aux (instr, aux)) = match instr with @@ -691,10 +757,25 @@ let cdef_map_instr f = function | CDEF_spec (id, ctyps, ctyp) -> CDEF_spec (id, ctyps, ctyp) | CDEF_type tdef -> CDEF_type tdef +(** Map over each instruction in a cdef using concatmap_instr *) +let cdef_concatmap_instr f = function + | CDEF_reg_dec (id, ctyp, instrs) -> + CDEF_reg_dec (id, ctyp, List.concat (List.map (concatmap_instr f) instrs)) + | CDEF_let (n, bindings, instrs) -> + CDEF_let (n, bindings, List.concat (List.map (concatmap_instr f) instrs)) + | CDEF_fundef (id, heap_return, args, instrs) -> + CDEF_fundef (id, heap_return, args, List.concat (List.map (concatmap_instr f) instrs)) + | CDEF_startup (id, instrs) -> + CDEF_startup (id, List.concat (List.map (concatmap_instr f) instrs)) + | CDEF_finish (id, instrs) -> + CDEF_finish (id, List.concat (List.map (concatmap_instr f) instrs)) + | CDEF_spec (id, ctyps, ctyp) -> CDEF_spec (id, ctyps, ctyp) + | CDEF_type tdef -> CDEF_type tdef + let ctype_def_map_ctyp f = function | CTD_enum (id, ids) -> CTD_enum (id, ids) - | CTD_struct (id, ctors) -> CTD_struct (id, List.map (fun (field, ctyp) -> (field, f ctyp)) ctors) - | CTD_variant (id, ctors) -> CTD_variant (id, List.map (fun (field, ctyp) -> (field, f ctyp)) ctors) + | CTD_struct (id, ctors) -> CTD_struct (id, List.map (fun ((id, ctyps), ctyp) -> ((id, List.map f ctyps), f ctyp)) ctors) + | CTD_variant (id, ctors) -> CTD_variant (id, List.map (fun ((id, ctyps), ctyp) -> ((id, List.map f ctyps), f ctyp)) ctors) (** Map over each ctyp in a cdef using map_instr_ctyp *) let cdef_map_ctyp f = function @@ -834,8 +915,8 @@ and cval_ctyp = function begin match cval_ctyp cval with | CT_struct (id, ctors) -> begin - try snd (List.find (fun (id, ctyp) -> Util.zencode_string (string_of_id id) = field) ctors) with - | Not_found -> failwith ("Struct type " ^ string_of_id id ^ " does not have a constructor " ^ field) + try snd (List.find (fun (uid, ctyp) -> UId.compare uid field = 0) ctors) with + | Not_found -> failwith ("Struct type " ^ string_of_id id ^ " does not have a constructor " ^ string_of_uid field) end | ctyp -> Reporting.unreachable Parse_ast.Unknown __POS__ ("Inavlid type for V_field " ^ full_string_of_ctyp ctyp) end @@ -849,8 +930,8 @@ let rec clexp_ctyp = function begin match clexp_ctyp clexp with | CT_struct (id, ctors) -> begin - try snd (List.find (fun (id, ctyp) -> string_of_id id = field) ctors) with - | Not_found -> failwith ("Struct type " ^ string_of_id id ^ " does not have a constructor " ^ field) + try snd (List.find (fun (uid, ctyp) -> UId.compare uid field = 0) ctors) with + | Not_found -> failwith ("Struct type " ^ string_of_id id ^ " does not have a constructor " ^ string_of_uid field) end | ctyp -> failwith ("Bad ctyp for CL_field " ^ string_of_ctyp ctyp) end @@ -880,8 +961,9 @@ let rec instr_ctyps (I_aux (instr, aux)) = CTSet.union (instrs_ctyps instrs1) (instrs_ctyps instrs2) |> CTSet.add (cval_ctyp cval) |> CTSet.add ctyp - | I_funcall (clexp, _, _, cvals) -> + | I_funcall (clexp, _, (_, ctyps), cvals) -> List.fold_left (fun m ctyp -> CTSet.add ctyp m) CTSet.empty (List.map cval_ctyp cvals) + |> CTSet.union (CTSet.of_list ctyps) |> CTSet.add (clexp_ctyp clexp) | I_copy (clexp, cval) -> CTSet.add (clexp_ctyp clexp) (CTSet.singleton (cval_ctyp cval)) diff --git a/src/sail.ml b/src/sail.ml index da711e8d..e792e652 100644 --- a/src/sail.ml +++ b/src/sail.ml @@ -349,9 +349,6 @@ let options = Arg.align ([ ( "-ddump_rewrite_ast", Arg.String (fun l -> opt_ddump_rewrite_ast := Some (l, 0); Specialize.opt_ddump_spec_ast := Some (l, 0)), "<prefix> (debug) dump the ast after each rewriting step to <prefix>_<i>.lem"); - ( "-ddump_flow_graphs", - Arg.Set Jib_compile.opt_debug_flow_graphs, - " (debug) dump flow analysis for Sail functions when compiling to C"); ( "-ddump_smt_graphs", Arg.Set Jib_smt.opt_debug_graphs, " (debug) dump flow analysis for properties when generating SMT"); |
