diff options
| author | Alasdair | 2019-10-26 04:59:49 +0100 |
|---|---|---|
| committer | Alasdair | 2019-10-28 13:38:10 +0000 |
| commit | 7f9371921cfcec819d9e0c778f8b817fb1566bce (patch) | |
| tree | 0bdf7c59c3192884f0e706baa46805abcece6fb3 | |
| parent | 5bcbe357c72382c1076ea7fd7c3ca6ea9f2f035c (diff) | |
Some C backend refactoring
Make it so that jib_compile.ml never relies on specific string encodings
for various constructs in C. Previously this happened when
monomorphisation occured for union constructors and fields, i.e.
x.foo -> x.zfoo_bitsz632z7
Now identifiers that can be modified are represented as (id, ctyp list)
tuples, so we can keep the types
x.foo -> x.foo::<bits(32)>
This then enables us to do jib IR -> jib IR rewrites that modify types
In particular there is now a rewrite that removes tuples as an IR->IR
pass rather than doing it ad-hoc in the C code generation, although this
is not on by default
Note that this change seems to have triggered an Ott bug so jib.lem is
now checked in and not generated from Ott
| -rw-r--r-- | .gitignore | 1 | ||||
| -rw-r--r-- | language/jib.ott | 10 | ||||
| -rw-r--r-- | src/Makefile | 4 | ||||
| -rw-r--r-- | src/jib.lem | 140 | ||||
| -rw-r--r-- | src/jib/anf.ml | 4 | ||||
| -rw-r--r-- | src/jib/anf.mli | 4 | ||||
| -rw-r--r-- | src/jib/c_backend.ml | 149 | ||||
| -rw-r--r-- | src/jib/jib_compile.ml | 114 | ||||
| -rw-r--r-- | src/jib/jib_compile.mli | 9 | ||||
| -rw-r--r-- | src/jib/jib_ir.ml | 57 | ||||
| -rw-r--r-- | src/jib/jib_optimize.ml | 160 | ||||
| -rw-r--r-- | src/jib/jib_optimize.mli | 3 | ||||
| -rw-r--r-- | src/jib/jib_smt.ml | 44 | ||||
| -rw-r--r-- | src/jib/jib_smt_fuzz.ml | 2 | ||||
| -rw-r--r-- | src/jib/jib_ssa.ml | 2 | ||||
| -rw-r--r-- | src/jib/jib_util.ml | 144 | ||||
| -rw-r--r-- | src/sail.ml | 3 |
17 files changed, 592 insertions, 258 deletions
@@ -38,7 +38,6 @@ lib/hol/sail-heap /src/sail.docdir /src/ast.lem /src/ast.ml -/src/jib.lem /src/jib.ml /src/manifest.ml diff --git a/language/jib.ott b/language/jib.ott index dfda3bbc..1d097c80 100644 --- a/language/jib.ott +++ b/language/jib.ott @@ -106,6 +106,10 @@ cval :: 'V_' ::= %%% IR types +uid :: 'UId_' ::= + {{ lem id * list ctyp }} + {{ ocaml id * ctyp list }} + ctyp :: 'CT_' ::= {{ com C type }} % Integer types @@ -144,7 +148,7 @@ ctyp :: 'CT_' ::= % need to be encoded. | enum id ( id0 , ... , idn ) :: :: enum | struct id ( id0 * ctyp0 , ... , idn * ctypn ) :: :: struct - | variant id ( id0 * ctyp0 , ... , idn * ctypn ) :: :: variant + | variant id ( uid0 * ctyp0 , ... , uidn * ctypn ) :: :: variant % A vector type for non-bit vectors, and a (linked) list type. | vector ( bool , ctyp ) :: :: vector @@ -175,9 +179,9 @@ ctype_def :: 'CTD_' ::= {{ com C type definition }} | enum id = id0 '|' ... '|' idn :: :: enum | struct id = { id0 : ctyp0 , ... , idn : ctypn } :: :: struct - | variant id = { id0 : ctyp0 , ... , idn : ctypn } :: :: variant + | variant id = { uid0 : ctyp0 , ... , uidn : ctypn } :: :: variant -iannot :: 'IA_' ::= +iannot :: '' ::= {{ lem nat * nat * nat }} {{ ocaml int * int * int }} diff --git a/src/Makefile b/src/Makefile index a002d4f3..020d6813 100644 --- a/src/Makefile +++ b/src/Makefile @@ -74,9 +74,6 @@ full: sail lib doc ast.lem: ../language/sail.ott ott -sort false -generate_aux_rules true -o ast.lem -picky_multiple_parses true ../language/sail.ott -jib.lem: ../language/jib.ott ast.lem - ott -sort false -generate_aux_rules true -o jib.lem -picky_multiple_parses true ../language/jib.ott - ast.ml: ast.lem lem -ocaml ast.lem sed -i.bak -f ast.sed ast.ml @@ -137,7 +134,6 @@ clean: -rm -f ast.lem -rm -f ast.ml.bak -rm -f jib.ml - -rm -f jib.lem -rm -f jib.ml.bak -rm -f manifest.ml diff --git a/src/jib.lem b/src/jib.lem new file mode 100644 index 00000000..c8959d10 --- /dev/null +++ b/src/jib.lem @@ -0,0 +1,140 @@ +(* generated by Ott 0.29 from: ../language/jib.ott *) +open import Pervasives + + +open import Ast +open import Value2 + + + +type name = + | Name of id * nat + | Have_exception of nat + | Current_exception of nat + | Return of nat + + +type ctyp = (* C type *) + | CT_lint + | CT_fint of nat + | CT_constant of integer + | CT_lbits of bool + | CT_sbits of nat * bool + | CT_fbits of nat * bool + | CT_unit + | CT_bool + | CT_bit + | CT_string + | CT_real + | CT_tup of list ctyp + | CT_enum of id * list id + | CT_struct of id * list (uid * ctyp) + | CT_variant of id * list (uid * ctyp) + | CT_vector of bool * ctyp + | CT_list of ctyp + | CT_ref of ctyp + | CT_poly + +and uid = id * list ctyp + +type op = + | Bnot + | Bor + | Band + | List_hd + | List_tl + | Bit_to_bool + | Eq + | Neq + | Ilt + | Ilteq + | Igt + | Igteq + | Iadd + | Isub + | Unsigned of nat + | Signed of nat + | Bvnot + | Bvor + | Bvand + | Bvxor + | Bvadd + | Bvsub + | Bvaccess + | Concat + | Zero_extend of nat + | Sign_extend of nat + | Slice of nat + | Sslice of nat + | Set_slice + | Replicate of nat + + +type clexp = + | CL_id of name * ctyp + | CL_rmw of name * name * ctyp + | CL_field of clexp * uid + | CL_addr of clexp + | CL_tuple of clexp * nat + | CL_void + + +type cval = + | V_id of name * ctyp + | V_ref of name * ctyp + | V_lit of vl * ctyp + | V_struct of list (uid * cval) * ctyp + | V_ctor_kind of cval * id * list ctyp * ctyp + | V_ctor_unwrap of id * cval * list ctyp * ctyp + | V_tuple_member of cval * nat * nat + | V_call of op * list cval + | V_field of cval * uid + | V_poly of cval * ctyp + + +type iannot = nat * nat * nat + + +type instr_aux = + | I_decl of ctyp * name + | I_init of ctyp * name * cval + | I_jump of cval * string + | I_goto of string + | I_label of string + | I_funcall of clexp * bool * uid * list cval + | I_copy of clexp * cval + | I_clear of ctyp * name + | I_undefined of ctyp + | I_match_failure + | I_end of name + | I_if of cval * list instr * list instr * ctyp + | I_block of list instr + | I_try_block of list instr + | I_throw of cval + | I_comment of string + | I_raw of string + | I_return of cval + | I_reset of ctyp * name + | I_reinit of ctyp * name * cval + +and instr = + | I_aux of instr_aux * iannot + + +type ctype_def = (* C type definition *) + | CTD_enum of id * list id + | CTD_struct of id * list (uid * ctyp) + | CTD_variant of id * list (uid * ctyp) + + +type cdef = + | CDEF_reg_dec of id * ctyp * list instr + | CDEF_type of ctype_def + | CDEF_let of nat * list (id * ctyp) * list instr + | CDEF_spec of id * list ctyp * ctyp + | CDEF_fundef of id * maybe id * list id * list instr + | CDEF_startup of id * list instr + | CDEF_finish of id * list instr + + + diff --git a/src/jib/anf.ml b/src/jib/anf.ml index 52c4584c..6bc447c0 100644 --- a/src/jib/anf.ml +++ b/src/jib/anf.ml @@ -468,7 +468,7 @@ and pp_aval = function | AV_tuple avals -> parens (separate_map (comma ^^ space) pp_aval avals) | AV_ref (id, lvar) -> string "ref" ^^ space ^^ pp_lvar lvar (pp_id id) | AV_cval (cval, typ) -> - pp_annot typ (string (Jib_ir.string_of_cval cval |> Util.cyan |> Util.clear)) + pp_annot typ (string (string_of_cval cval |> Util.cyan |> Util.clear)) | AV_vector (avals, typ) -> pp_annot typ (string "[" ^^ separate_map (comma ^^ space) pp_aval avals ^^ string "]") | AV_list (avals, typ) -> @@ -482,7 +482,7 @@ let ae_lit lit typ = AE_val (AV_lit (lit, typ)) let is_dead_aexp (AE_aux (_, env, _)) = prove __POS__ env nc_false -let (gensym, _) = symbol_generator "anf" +let (gensym, reset_anf_counter) = symbol_generator "ga" let rec split_block l = function | [exp] -> [], exp diff --git a/src/jib/anf.mli b/src/jib/anf.mli index f28cf420..d01fe146 100644 --- a/src/jib/anf.mli +++ b/src/jib/anf.mli @@ -129,6 +129,10 @@ and 'a aval = | AV_record of ('a aval) Bindings.t * 'a | AV_cval of cval * 'a +(** When ANF translation has to introduce new bindings it uses a +counter to ensure uniqueness. This function resets that counter. *) +val reset_anf_counter : unit -> unit + (** {2 Functions for transforming ANF expressions} *) val aval_typ : typ aval -> typ diff --git a/src/jib/c_backend.ml b/src/jib/c_backend.ml index bf22c5d2..a97b06ef 100644 --- a/src/jib/c_backend.ml +++ b/src/jib/c_backend.ml @@ -61,8 +61,6 @@ open Anf module Big_int = Nat_big_num -let c_verbosity = ref 0 - let opt_static = ref false let opt_no_main = ref false let opt_memo_cache = ref false @@ -93,16 +91,16 @@ let optimize_fixed_bits = ref false let (gensym, _) = symbol_generator "cb" let ngensym () = name (gensym ()) -let c_debug str = - if !c_verbosity > 0 then prerr_endline (Lazy.force str) else () - let c_error ?loc:(l=Parse_ast.Unknown) message = raise (Reporting.err_general l ("\nC backend: " ^ message)) -let zencode_id = function - | Id_aux (Id str, l) -> Id_aux (Id (Util.zencode_string str), l) - | Id_aux (Operator str, l) -> Id_aux (Id (Util.zencode_string ("op " ^ str)), l) +let zencode_id id = Util.zencode_string (string_of_id id) +let zencode_uid (id, ctyps) = + match ctyps with + | [] -> Util.zencode_string (string_of_id id) + | _ -> Util.zencode_string (string_of_id id ^ "#" ^ Util.string_of_list "_" string_of_ctyp ctyps) + (**************************************************************************) (* 2. Converting sail types to C types *) (**************************************************************************) @@ -172,8 +170,8 @@ let rec ctyp_of_typ ctx typ = | Typ_app (id, [A_aux (A_typ typ, _)]) when string_of_id id = "register" -> CT_ref (ctyp_of_typ ctx typ) - | Typ_id id | Typ_app (id, _) when Bindings.mem id ctx.records -> CT_struct (id, Bindings.find id ctx.records |> Bindings.bindings) - | Typ_id id | Typ_app (id, _) when Bindings.mem id ctx.variants -> CT_variant (id, Bindings.find id ctx.variants |> Bindings.bindings) + | Typ_id id | Typ_app (id, _) when Bindings.mem id ctx.records -> CT_struct (id, Bindings.find id ctx.records |> UBindings.bindings) + | Typ_id id | Typ_app (id, _) when Bindings.mem id ctx.variants -> CT_variant (id, Bindings.find id ctx.variants |> UBindings.bindings) | Typ_id id when Bindings.mem id ctx.enums -> CT_enum (id, Bindings.find id ctx.enums |> IdSet.elements) | Typ_tup typs -> CT_tup (List.map (ctyp_of_typ ctx) typs) @@ -221,7 +219,7 @@ let is_sbits_typ ctx typ = | CT_sbits _ -> true | _ -> false -let ctor_bindings = List.fold_left (fun map (id, ctyp) -> Bindings.add id ctyp map) Bindings.empty +let ctor_bindings = List.fold_left (fun map (id, ctyp) -> UBindings.add id ctyp map) UBindings.empty (**************************************************************************) (* 3. Optimization of primitives and literals *) @@ -407,8 +405,6 @@ let analyze_primop' ctx id args typ = let v_one = V_lit (VL_int (Big_int.of_int 1), CT_fint 64) in let v_int n = V_lit (VL_int (Big_int.of_int n), CT_fint 64) in - c_debug (lazy ("Analyzing primop " ^ extern ^ "(" ^ Util.string_of_list ", " (fun aval -> Pretty_print_sail.to_string (pp_aval aval)) args ^ ")")); - match extern, args with | "eq_bits", [AV_cval (v1, _); AV_cval (v2, _)] when ctyp_equal (cval_ctyp v1) (cval_ctyp v2) -> begin match cval_ctyp v1 with @@ -541,7 +537,6 @@ let analyze_primop' ctx id args typ = AE_val (AV_cval (V_lit (VL_bool false, CT_bool), typ)) | _, _ -> - c_debug (lazy ("No optimization routine found")); no_change let analyze_primop ctx id args typ = @@ -549,7 +544,6 @@ let analyze_primop ctx id args typ = if !optimize_primops then try analyze_primop' ctx id args typ with | Failure str -> - (c_debug (lazy ("Analyze primop failed for id " ^ string_of_id id ^ " reason: " ^ str))); no_change else no_change @@ -702,7 +696,6 @@ let hoist_id () = let hoist_allocations recursive_functions = function | CDEF_fundef (function_id, _, _, _) as cdef when IdSet.mem function_id recursive_functions -> - c_debug (lazy (Printf.sprintf "skipping recursive function %s" (string_of_id function_id))); [cdef] | CDEF_fundef (function_id, heap_return, args, body) -> @@ -1028,13 +1021,19 @@ let optimize recursive_functions cdefs = (**************************************************************************) let sgen_id id = Util.zencode_string (string_of_id id) +let sgen_uid uid = zencode_uid uid let sgen_name id = string_of_name id let codegen_id id = string (sgen_id id) +let codegen_uid id = string (sgen_uid id) let sgen_function_id id = let str = Util.zencode_string (string_of_id id) in !opt_prefix ^ String.sub str 1 (String.length str - 1) +let sgen_function_uid uid = + let str = zencode_uid uid in + !opt_prefix ^ String.sub str 1 (String.length str - 1) + let codegen_function_id id = string (sgen_function_id id) let rec sgen_ctyp = function @@ -1095,30 +1094,40 @@ let sgen_mask n = else failwith "Tried to create a mask literal for a vector greater than 64 bits." +let sgen_value = function + | VL_bits ([], _) -> "UINT64_C(0)" + | VL_bits (bs, true) -> "UINT64_C(" ^ Sail2_values.show_bitlist bs ^ ")" + | VL_bits (bs, false) -> "UINT64_C(" ^ Sail2_values.show_bitlist (List.rev bs) ^ ")" + | VL_int i -> Big_int.to_string i ^ "l" + | VL_bool true -> "true" + | VL_bool false -> "false" + | VL_null -> "NULL" + | VL_unit -> "UNIT" + | VL_bit Sail2_values.B0 -> "UINT64_C(0)" + | VL_bit Sail2_values.B1 -> "UINT64_C(1)" + | VL_bit Sail2_values.BU -> failwith "Undefined bit found in value" + | VL_real str -> str + | VL_string str -> "\"" ^ str ^ "\"" + let rec sgen_cval = function | V_id (id, ctyp) -> string_of_name id | V_ref (id, _) -> "&" ^ string_of_name id - | V_lit (vl, ctyp) -> string_of_value vl + | V_lit (vl, ctyp) -> sgen_value vl | V_call (op, cvals) -> sgen_call op cvals | V_field (f, field) -> - Printf.sprintf "%s.%s" (sgen_cval f) field + Printf.sprintf "%s.%s" (sgen_cval f) (sgen_uid field) | V_tuple_member (f, _, n) -> Printf.sprintf "%s.ztup%d" (sgen_cval f) n - | V_ctor_kind (f, ctor, [], _) -> - sgen_cval f ^ ".kind" - ^ " != Kind_" ^ Util.zencode_string (string_of_id ctor) | V_ctor_kind (f, ctor, unifiers, _) -> sgen_cval f ^ ".kind" - ^ " != Kind_" ^ Util.zencode_string (string_of_id ctor ^ "_" ^ Util.string_of_list "_" string_of_ctyp unifiers) - | V_ctor_unwrap (ctor, f, [], _) -> - Printf.sprintf "%s.%s" (sgen_cval f) (Util.zencode_string (string_of_id ctor)) + ^ " != Kind_" ^ zencode_uid (ctor, unifiers) | V_struct (fields, _) -> Printf.sprintf "{%s}" - (Util.string_of_list ", " (fun (field, cval) -> string_of_id field ^ " = " ^ sgen_cval cval) fields) + (Util.string_of_list ", " (fun (field, cval) -> zencode_uid field ^ " = " ^ sgen_cval cval) fields) | V_ctor_unwrap (ctor, f, unifiers, _) -> Printf.sprintf "%s.%s" (sgen_cval f) - (Util.zencode_string (string_of_id ctor ^ "_" ^ Util.string_of_list "_" string_of_ctyp unifiers)) + (sgen_uid (ctor, unifiers)) | V_poly (f, _) -> sgen_cval f and sgen_call op cvals = @@ -1299,7 +1308,7 @@ let rec sgen_clexp = function | CL_id (Current_exception _, _) -> "current_exception" | CL_id (Return _, _) -> assert false | CL_id (Name (id, _), _) -> "&" ^ sgen_id id - | CL_field (clexp, field) -> "&((" ^ sgen_clexp clexp ^ ")->" ^ Util.zencode_string field ^ ")" + | CL_field (clexp, field) -> "&((" ^ sgen_clexp clexp ^ ")->" ^ zencode_uid field ^ ")" | CL_tuple (clexp, n) -> "&((" ^ sgen_clexp clexp ^ ")->ztup" ^ string_of_int n ^ ")" | CL_addr clexp -> "(*(" ^ sgen_clexp clexp ^ "))" | CL_void -> assert false @@ -1310,7 +1319,7 @@ let rec sgen_clexp_pure = function | CL_id (Current_exception _, _) -> "current_exception" | CL_id (Return _, _) -> assert false | CL_id (Name (id, _), _) -> sgen_id id - | CL_field (clexp, field) -> sgen_clexp_pure clexp ^ "." ^ Util.zencode_string field + | CL_field (clexp, field) -> sgen_clexp_pure clexp ^ "." ^ zencode_uid field | CL_tuple (clexp, n) -> sgen_clexp_pure clexp ^ ".ztup" ^ string_of_int n | CL_addr clexp -> "(*(" ^ sgen_clexp_pure clexp ^ "))" | CL_void -> assert false @@ -1394,14 +1403,14 @@ let rec codegen_instr fid ctx (I_aux (instr, (_, l))) = | I_funcall (x, extern, f, args) -> let c_args = Util.string_of_list ", " sgen_cval args in let ctyp = clexp_ctyp x in - let is_extern = Env.is_extern f ctx.tc_env "c" || extern in + let is_extern = Env.is_extern (fst f) ctx.tc_env "c" || extern in let fname = - if Env.is_extern f ctx.tc_env "c" then - Env.get_extern f ctx.tc_env "c" + if Env.is_extern (fst f) ctx.tc_env "c" then + Env.get_extern (fst f) ctx.tc_env "c" else if extern then - string_of_id f + string_of_id (fst f) else - sgen_function_id f + sgen_function_uid f in let fname = match fname, ctyp with @@ -1506,9 +1515,9 @@ let rec codegen_instr fid ctx (I_aux (instr, (_, l))) = ^ Util.string_of_list ", " (fun x -> x) inits ^ " };"] @ prev | CT_struct (id, ctors) when is_stack_ctyp ctyp -> let gs = ngensym () in - let fold (inits, prev) (id, ctyp) = + let fold (inits, prev) (uid, ctyp) = let init, prev' = codegen_exn_return ctyp in - Printf.sprintf ".%s = %s" (sgen_id id) init :: inits, prev @ prev' + Printf.sprintf ".%s = %s" (sgen_uid uid) init :: inits, prev @ prev' in let inits, prev = List.fold_left fold ([], []) ctors in sgen_name gs, @@ -1560,36 +1569,34 @@ let codegen_type_def ctx = function | CTD_struct (id, ctors) -> let struct_ctyp = CT_struct (id, ctors) in - c_debug (lazy (Printf.sprintf "Generating struct for %s" (full_string_of_ctyp struct_ctyp))); - (* Generate a set_T function for every struct T *) let codegen_set (id, ctyp) = if is_stack_ctyp ctyp then - string (Printf.sprintf "rop->%s = op.%s;" (sgen_id id) (sgen_id id)) + string (Printf.sprintf "rop->%s = op.%s;" (sgen_uid id) (sgen_uid id)) else - string (Printf.sprintf "COPY(%s)(&rop->%s, op.%s);" (sgen_ctyp_name ctyp) (sgen_id id) (sgen_id id)) + string (Printf.sprintf "COPY(%s)(&rop->%s, op.%s);" (sgen_ctyp_name ctyp) (sgen_uid id) (sgen_uid id)) in let codegen_setter id ctors = string (let n = sgen_id id in Printf.sprintf "static void COPY(%s)(struct %s *rop, const struct %s op)" n n n) ^^ space ^^ surround 2 0 lbrace - (separate_map hardline codegen_set (Bindings.bindings ctors)) + (separate_map hardline codegen_set (UBindings.bindings ctors)) rbrace in (* Generate an init/clear_T function for every struct T *) let codegen_field_init f (id, ctyp) = if not (is_stack_ctyp ctyp) then - [string (Printf.sprintf "%s(%s)(&op->%s);" f (sgen_ctyp_name ctyp) (sgen_id id))] + [string (Printf.sprintf "%s(%s)(&op->%s);" f (sgen_ctyp_name ctyp) (sgen_uid id))] else [] in let codegen_init f id ctors = string (let n = sgen_id id in Printf.sprintf "static void %s(%s)(struct %s *op)" f n n) ^^ space ^^ surround 2 0 lbrace - (separate hardline (Bindings.bindings ctors |> List.map (codegen_field_init f) |> List.concat)) + (separate hardline (UBindings.bindings ctors |> List.map (codegen_field_init f) |> List.concat)) rbrace in let codegen_eq = let codegen_eq_test (id, ctyp) = - string (Printf.sprintf "EQUAL(%s)(op1.%s, op2.%s)" (sgen_ctyp_name ctyp) (sgen_id id) (sgen_id id)) + string (Printf.sprintf "EQUAL(%s)(op1.%s, op2.%s)" (sgen_ctyp_name ctyp) (sgen_uid id) (sgen_uid id)) in string (Printf.sprintf "static bool EQUAL(%s)(struct %s op1, struct %s op2)" (sgen_id id) (sgen_id id) (sgen_id id)) ^^ space @@ -1601,7 +1608,7 @@ let codegen_type_def ctx = function in (* Generate the struct and add the generated functions *) let codegen_ctor (id, ctyp) = - string (sgen_ctyp ctyp) ^^ space ^^ codegen_id id + string (sgen_ctyp ctyp) ^^ space ^^ codegen_uid id in string (Printf.sprintf "// struct %s" (string_of_id id)) ^^ hardline ^^ string "struct" ^^ space ^^ codegen_id id ^^ space @@ -1623,17 +1630,17 @@ let codegen_type_def ctx = function | CTD_variant (id, tus) -> let codegen_tu (ctor_id, ctyp) = - separate space [string "struct"; lbrace; string (sgen_ctyp ctyp); codegen_id ctor_id ^^ semi; rbrace] + separate space [string "struct"; lbrace; string (sgen_ctyp ctyp); codegen_uid ctor_id ^^ semi; rbrace] in (* Create an if, else if, ... block that does something for each constructor *) let rec each_ctor v f = function | [] -> string "{}" | [(ctor_id, ctyp)] -> - string (Printf.sprintf "if (%skind == Kind_%s)" v (sgen_id ctor_id)) ^^ lbrace ^^ hardline + string (Printf.sprintf "if (%skind == Kind_%s)" v (sgen_uid ctor_id)) ^^ lbrace ^^ hardline ^^ jump 0 2 (f ctor_id ctyp) ^^ hardline ^^ rbrace | (ctor_id, ctyp) :: ctors -> - string (Printf.sprintf "if (%skind == Kind_%s) " v (sgen_id ctor_id)) ^^ lbrace ^^ hardline + string (Printf.sprintf "if (%skind == Kind_%s) " v (sgen_uid ctor_id)) ^^ lbrace ^^ hardline ^^ jump 0 2 (f ctor_id ctyp) ^^ hardline ^^ rbrace ^^ string " else " ^^ each_ctor v f ctors in @@ -1643,9 +1650,9 @@ let codegen_type_def ctx = function string (Printf.sprintf "static void CREATE(%s)(struct %s *op)" n n) ^^ hardline ^^ surround 2 0 lbrace - (string (Printf.sprintf "op->kind = Kind_%s;" (sgen_id ctor_id)) ^^ hardline + (string (Printf.sprintf "op->kind = Kind_%s;" (sgen_uid ctor_id)) ^^ hardline ^^ if not (is_stack_ctyp ctyp) then - string (Printf.sprintf "CREATE(%s)(&op->%s);" (sgen_ctyp_name ctyp) (sgen_id ctor_id)) + string (Printf.sprintf "CREATE(%s)(&op->%s);" (sgen_ctyp_name ctyp) (sgen_uid ctor_id)) else empty) rbrace in @@ -1657,7 +1664,7 @@ let codegen_type_def ctx = function if is_stack_ctyp ctyp then string (Printf.sprintf "/* do nothing */") else - string (Printf.sprintf "KILL(%s)(&%s->%s);" (sgen_ctyp_name ctyp) v (sgen_id ctor_id)) + string (Printf.sprintf "KILL(%s)(&%s->%s);" (sgen_ctyp_name ctyp) v (sgen_uid ctor_id)) in let codegen_clear = let n = sgen_id id in @@ -1676,16 +1683,16 @@ let codegen_type_def ctx = function in Printf.sprintf "%s op" (sgen_ctyp ctyp), empty, empty in - string (Printf.sprintf "static void %s(%sstruct %s *rop, %s)" (sgen_function_id ctor_id) (extra_params ()) (sgen_id id) ctor_args) ^^ hardline + string (Printf.sprintf "static void %s(%sstruct %s *rop, %s)" (sgen_function_uid ctor_id) (extra_params ()) (sgen_id id) ctor_args) ^^ hardline ^^ surround 2 0 lbrace (tuple ^^ each_ctor "rop->" (clear_field "rop") tus ^^ hardline - ^^ string ("rop->kind = Kind_" ^ sgen_id ctor_id) ^^ semi ^^ hardline + ^^ string ("rop->kind = Kind_" ^ sgen_uid ctor_id) ^^ semi ^^ hardline ^^ if is_stack_ctyp ctyp then - string (Printf.sprintf "rop->%s = op;" (sgen_id ctor_id)) + string (Printf.sprintf "rop->%s = op;" (sgen_uid ctor_id)) else - string (Printf.sprintf "CREATE(%s)(&rop->%s);" (sgen_ctyp_name ctyp) (sgen_id ctor_id)) ^^ hardline - ^^ string (Printf.sprintf "COPY(%s)(&rop->%s, op);" (sgen_ctyp_name ctyp) (sgen_id ctor_id)) ^^ hardline + string (Printf.sprintf "CREATE(%s)(&rop->%s);" (sgen_ctyp_name ctyp) (sgen_uid ctor_id)) ^^ hardline + ^^ string (Printf.sprintf "COPY(%s)(&rop->%s, op);" (sgen_ctyp_name ctyp) (sgen_uid ctor_id)) ^^ hardline ^^ tuple_cleanup) rbrace in @@ -1693,10 +1700,10 @@ let codegen_type_def ctx = function let n = sgen_id id in let set_field ctor_id ctyp = if is_stack_ctyp ctyp then - string (Printf.sprintf "rop->%s = op.%s;" (sgen_id ctor_id) (sgen_id ctor_id)) + string (Printf.sprintf "rop->%s = op.%s;" (sgen_uid ctor_id) (sgen_uid ctor_id)) else - string (Printf.sprintf "CREATE(%s)(&rop->%s);" (sgen_ctyp_name ctyp) (sgen_id ctor_id)) - ^^ string (Printf.sprintf " COPY(%s)(&rop->%s, op.%s);" (sgen_ctyp_name ctyp) (sgen_id ctor_id) (sgen_id ctor_id)) + string (Printf.sprintf "CREATE(%s)(&rop->%s);" (sgen_ctyp_name ctyp) (sgen_uid ctor_id)) + ^^ string (Printf.sprintf " COPY(%s)(&rop->%s, op.%s);" (sgen_ctyp_name ctyp) (sgen_uid ctor_id) (sgen_uid ctor_id)) in string (Printf.sprintf "static void COPY(%s)(struct %s *rop, struct %s op)" n n n) ^^ hardline ^^ surround 2 0 lbrace @@ -1709,12 +1716,12 @@ let codegen_type_def ctx = function in let codegen_eq = let codegen_eq_test ctor_id ctyp = - string (Printf.sprintf "return EQUAL(%s)(op1.%s, op2.%s);" (sgen_ctyp_name ctyp) (sgen_id ctor_id) (sgen_id ctor_id)) + string (Printf.sprintf "return EQUAL(%s)(op1.%s, op2.%s);" (sgen_ctyp_name ctyp) (sgen_uid ctor_id) (sgen_uid ctor_id)) in let rec codegen_eq_tests = function | [] -> string "return false;" | (ctor_id, ctyp) :: ctors -> - string (Printf.sprintf "if (op1.kind == Kind_%s && op2.kind == Kind_%s) " (sgen_id ctor_id) (sgen_id ctor_id)) ^^ lbrace ^^ hardline + string (Printf.sprintf "if (op1.kind == Kind_%s && op2.kind == Kind_%s) " (sgen_uid ctor_id) (sgen_uid ctor_id)) ^^ lbrace ^^ hardline ^^ jump 0 2 (codegen_eq_test ctor_id ctyp) ^^ hardline ^^ rbrace ^^ string " else " ^^ codegen_eq_tests ctors in @@ -1726,7 +1733,7 @@ let codegen_type_def ctx = function ^^ string "enum" ^^ space ^^ string ("kind_" ^ sgen_id id) ^^ space ^^ separate space [ lbrace; - separate_map (comma ^^ space) (fun id -> string ("Kind_" ^ sgen_id id)) (List.map fst tus); + separate_map (comma ^^ space) (fun id -> string ("Kind_" ^ sgen_uid id)) (List.map fst tus); rbrace ^^ semi ] ^^ twice hardline ^^ string "struct" ^^ space ^^ codegen_id id ^^ space @@ -1784,12 +1791,12 @@ let codegen_tup ctx ctyps = empty else begin - let _, fields = List.fold_left (fun (n, fields) ctyp -> n + 1, Bindings.add (mk_id ("tup" ^ string_of_int n)) ctyp fields) - (0, Bindings.empty) + let _, fields = List.fold_left (fun (n, fields) ctyp -> n + 1, UBindings.add (mk_id ("tup" ^ string_of_int n), []) ctyp fields) + (0, UBindings.empty) ctyps in generated := IdSet.add id !generated; - codegen_type_def ctx (CTD_struct (id, Bindings.bindings fields)) ^^ twice hardline + codegen_type_def ctx (CTD_struct (id, UBindings.bindings fields)) ^^ twice hardline end let codegen_node id ctyp = @@ -1997,15 +2004,12 @@ let codegen_def' ctx = function string (Printf.sprintf "%svoid %s(%s%s *rop, %s);" static (sgen_function_id id) (extra_params ()) (sgen_ctyp ret_ctyp) (Util.string_of_list ", " sgen_ctyp arg_ctyps)) | CDEF_fundef (id, ret_arg, args, instrs) as def -> - (* Extract type information about the function from the environment. *) - let quant, Typ_aux (fn_typ, _) = Env.get_val_spec id ctx.tc_env in - let arg_typs, ret_typ = match fn_typ with - | Typ_fn (arg_typs, ret_typ, _) -> arg_typs, ret_typ - | _ -> assert false + let arg_ctyps, ret_ctyp = match Bindings.find_opt id ctx.valspecs with + | Some vs -> vs + | None -> + c_error ~loc:(id_loc id) ("No valspec found for " ^ string_of_id id) in - let ctx' = { ctx with local_env = add_typquant (id_loc id) quant ctx.local_env } in - let arg_ctyps, ret_ctyp = List.map (ctyp_of_typ ctx') arg_typs, ctyp_of_typ ctx' ret_typ in - + (* Check that the function has the correct arity at this point. *) if List.length arg_ctyps <> List.length args then c_error ~loc:(id_loc id) ("function arguments " @@ -2175,7 +2179,6 @@ let jib_of_ast env ast = let compile_ast env output_chan c_includes ast = try - c_debug (lazy (Util.log_line __MODULE__ __LINE__ "Identifying recursive functions")); let recursive_functions = Spec_analysis.top_sort_defs ast |> get_recursive_functions in let cdefs, ctx = jib_of_ast env ast in diff --git a/src/jib/jib_compile.ml b/src/jib/jib_compile.ml index 1a0411aa..b178f0e2 100644 --- a/src/jib/jib_compile.ml +++ b/src/jib/jib_compile.ml @@ -57,7 +57,6 @@ open Value2 open Anf -let opt_debug_flow_graphs = ref false let opt_memo_cache = ref false let optimize_aarch64_fast_struct = ref false @@ -138,7 +137,7 @@ let iblock1 = function | [instr] -> instr | instrs -> iblock instrs -let ctor_bindings = List.fold_left (fun map (id, ctyp) -> Bindings.add id ctyp map) Bindings.empty +let ctor_bindings = List.fold_left (fun map (uid, ctyp) -> UBindings.add uid ctyp map) UBindings.empty (** The context type contains two type-checking environments. ctx.local_env contains the closest typechecking @@ -148,9 +147,10 @@ let ctor_bindings = List.fold_left (fun map (id, ctyp) -> Bindings.add id ctyp m in ctx.locals, so we know when their type changes due to flow typing. *) type ctx = - { records : (ctyp Bindings.t) Bindings.t; + { records : (ctyp UBindings.t) Bindings.t; enums : IdSet.t Bindings.t; - variants : (ctyp Bindings.t) Bindings.t; + variants : (ctyp UBindings.t) Bindings.t; + valspecs : (ctyp list * ctyp) Bindings.t; tc_env : Env.t; local_env : Env.t; locals : (mut * ctyp) Bindings.t; @@ -168,6 +168,7 @@ let initial_ctx ~convert_typ:convert_typ ~optimize_anf:optimize_anf env = { records = Bindings.empty; enums = Bindings.empty; variants = Bindings.empty; + valspecs = Bindings.empty; tc_env = env; local_env = env; locals = Bindings.empty; @@ -269,7 +270,7 @@ let rec compile_aval l ctx = function let compile_fields (id, aval) = let field_setup, cval, field_cleanup = compile_aval l ctx aval in field_setup, - (id, cval), + ((id, []), cval), field_cleanup in let field_triples = List.map compile_fields (Bindings.bindings fields) in @@ -286,7 +287,7 @@ let rec compile_aval l ctx = function let compile_fields (id, aval) = let field_setup, cval, field_cleanup = compile_aval l ctx aval in field_setup - @ [icopy l (CL_field (CL_id (gs, ctyp), string_of_id id)) cval] + @ [icopy l (CL_field (CL_id (gs, ctyp), (id, []))) cval] @ field_cleanup in [idecl ctyp gs] @@ -302,7 +303,7 @@ let rec compile_aval l ctx = function | _ -> let gs = ngensym () in [idecl vector_ctyp gs; - iextern (CL_id (gs, vector_ctyp)) (mk_id "internal_vector_init") [V_lit (VL_int Big_int.zero, CT_fint 64)]], + iextern (CL_id (gs, vector_ctyp)) (mk_id "internal_vector_init", []) [V_lit (VL_int Big_int.zero, CT_fint 64)]], V_id (gs, vector_ctyp), [iclear vector_ctyp gs] end @@ -333,7 +334,7 @@ let rec compile_aval l ctx = function let gs = ngensym () in [iinit (CT_lbits true) gs (V_lit (first_chunk, CT_fbits (len mod 64, true)))] @ List.map (fun chunk -> ifuncall (CL_id (gs, CT_lbits true)) - (mk_id "append_64") + (mk_id "append_64", []) [V_id (gs, CT_lbits true); V_lit (chunk, CT_fbits (64, true))]) chunks, V_id (gs, CT_lbits true), [iclear (CT_lbits true) gs] @@ -381,12 +382,12 @@ let rec compile_aval l ctx = function let setup, cval, cleanup = compile_aval l ctx aval in setup @ [iextern (CL_id (gs, vector_ctyp)) - (mk_id "internal_vector_update") + (mk_id "internal_vector_update", []) [V_id (gs, vector_ctyp); V_lit (VL_int (Big_int.of_int i), CT_fint 64); cval]] @ cleanup in [idecl vector_ctyp gs; - iextern (CL_id (gs, vector_ctyp)) (mk_id "internal_vector_init") [V_lit (VL_int (Big_int.of_int len), CT_fint 64)]] + iextern (CL_id (gs, vector_ctyp)) (mk_id "internal_vector_init", []) [V_lit (VL_int (Big_int.of_int len), CT_fint 64)]] @ List.concat (List.mapi aval_set (if direction then List.rev avals else avals)), V_id (gs, vector_ctyp), [iclear vector_ctyp gs] @@ -402,7 +403,7 @@ let rec compile_aval l ctx = function let gs = ngensym () in let mk_cons aval = let setup, cval, cleanup = compile_aval l ctx aval in - setup @ [ifuncall (CL_id (gs, CT_list ctyp)) (mk_id ("cons#" ^ string_of_ctyp ctyp)) [cval; V_id (gs, CT_list ctyp)]] @ cleanup + setup @ [ifuncall (CL_id (gs, CT_list ctyp)) (mk_id ("cons#" ^ string_of_ctyp ctyp), []) [cval; V_id (gs, CT_list ctyp)]] @ cleanup in [idecl (CT_list ctyp) gs] @ List.concat (List.map mk_cons (List.rev avals)), @@ -439,8 +440,8 @@ let optimize_call l ctx clexp id args arg_ctyps ret_ctyp = iclear ret_ctyp gs] @ !cleanup in - if not ctx.specialize_calls && Env.is_extern id ctx.tc_env "c" then - let extern = Env.get_extern id ctx.tc_env "c" in + if not ctx.specialize_calls && Env.is_extern (fst id) ctx.tc_env "c" then + let extern = Env.get_extern (fst id) ctx.tc_env "c" in begin match extern, List.map cval_ctyp args, clexp_ctyp clexp with | "slice", [CT_fbits _; CT_lint; _], CT_fbits (n, _) -> let start = ngensym () in @@ -487,7 +488,7 @@ let compile_funcall l ctx id args = List.rev !setup, begin fun clexp -> - iblock1 (optimize_call l ctx clexp id setup_args arg_ctyps ret_ctyp) + iblock1 (optimize_call l ctx clexp (id, []) setup_args arg_ctyps ret_ctyp) end, !cleanup @@ -547,7 +548,7 @@ let rec compile_match ctx (AP_aux (apat_aux, env, l)) cval case_label = | AP_app (ctor, apat, variant_typ) -> begin match ctyp with | CT_variant (_, ctors) -> - let ctor_ctyp = Bindings.find ctor (ctor_bindings ctors) in + let ctor_ctyp = UBindings.find (ctor, []) (ctor_bindings ctors) in let pat_ctyp = apat_ctyp ctx apat in (* These should really be the same, something has gone wrong if they are not. *) if ctyp_equal ctor_ctyp (ctyp_of_typ ctx variant_typ) then @@ -569,7 +570,7 @@ let rec compile_match ctx (AP_aux (apat_aux, env, l)) cval case_label = raise (Reporting.err_general l (Printf.sprintf "Variant constructor %s : %s matching against non-variant type %s : %s" (string_of_id ctor) (string_of_typ variant_typ) - (Jib_ir.string_of_cval cval) + (string_of_cval cval) (string_of_ctyp ctyp))) end @@ -716,14 +717,14 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) = | AE_record_update (aval, fields, typ) -> let ctyp = ctyp_of_typ ctx typ in let ctors = match ctyp with - | CT_struct (_, ctors) -> List.fold_left (fun m (k, v) -> Bindings.add k v m) Bindings.empty ctors + | CT_struct (_, ctors) -> List.fold_left (fun m (k, v) -> UBindings.add k v m) UBindings.empty ctors | _ -> raise (Reporting.err_general l "Cannot perform record update for non-record type") in let gs = ngensym () in let compile_fields (id, aval) = let field_setup, cval, field_cleanup = compile_aval l ctx aval in field_setup - @ [icopy l (CL_field (CL_id (gs, ctyp), string_of_id id)) cval] + @ [icopy l (CL_field (CL_id (gs, ctyp), (id, []))) cval] @ field_cleanup in let setup, cval, cleanup = compile_aval l ctx aval in @@ -769,7 +770,7 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) = let compile_fields (field_id, aval) = let field_setup, cval, field_cleanup = compile_aval l ctx aval in field_setup - @ [icopy l (CL_field (CL_id (name id, ctyp_of_typ ctx typ), string_of_id field_id)) cval] + @ [icopy l (CL_field (CL_id (name id, ctyp_of_typ ctx typ), (field_id, []))) cval] @ field_cleanup in List.concat (List.map compile_fields (Bindings.bindings fields)), @@ -877,7 +878,7 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) = let setup, cval, cleanup = compile_aval l ctx aval in let ctyp = match cval_ctyp cval with | CT_struct (struct_id, fields) -> - begin match Util.assoc_compare_opt Id.compare id fields with + begin match Util.assoc_compare_opt UId.compare (id, []) fields with | Some ctyp -> ctyp | None -> raise (Reporting.err_unreachable l __POS__ @@ -893,12 +894,8 @@ let rec compile_aexp ctx (AE_aux (aexp_aux, env, l)) = else [], ctyp in - let field_str = match unifiers with - | [] -> Util.zencode_string (string_of_id id) - | _ -> Util.zencode_string (string_of_id id ^ "_" ^ Util.string_of_list "_" string_of_ctyp unifiers) - in setup, - (fun clexp -> icopy l clexp (V_field (cval, field_str))), + (fun clexp -> icopy l clexp (V_field (cval, (id, unifiers)))), cleanup | AE_for (loop_var, loop_from, loop_to, loop_step, Ord_aux (ord, _), body) -> @@ -976,9 +973,9 @@ let compile_type_def ctx (TD_aux (type_def, (l, _))) = | TD_record (id, typq, ctors, _) -> let record_ctx = { ctx with local_env = add_typquant l typq ctx.local_env } in let ctors = - List.fold_left (fun ctors (typ, id) -> Bindings.add id (fast_int (ctyp_of_typ record_ctx typ)) ctors) Bindings.empty ctors + List.fold_left (fun ctors (typ, id) -> UBindings.add (id, []) (fast_int (ctyp_of_typ record_ctx typ)) ctors) UBindings.empty ctors in - CTD_struct (id, Bindings.bindings ctors), + CTD_struct (id, UBindings.bindings ctors), { ctx with records = Bindings.add id ctors ctx.records } | TD_variant (id, typq, tus, _) -> @@ -987,8 +984,8 @@ let compile_type_def ctx (TD_aux (type_def, (l, _))) = let ctx = { ctx with local_env = add_typquant (id_loc id) typq ctx.local_env } in ctyp_of_typ ctx typ, id in - let ctus = List.fold_left (fun ctus (ctyp, id) -> Bindings.add id ctyp ctus) Bindings.empty (List.map compile_tu tus) in - CTD_variant (id, Bindings.bindings ctus), + let ctus = List.fold_left (fun ctus (ctyp, id) -> UBindings.add (id, []) ctyp ctus) UBindings.empty (List.map compile_tu tus) in + CTD_variant (id, UBindings.bindings ctus), { ctx with variants = Bindings.add id ctus ctx.variants } (* Will be re-written before here, see bitfield.ml *) @@ -1048,7 +1045,7 @@ let fix_exception_block ?return:(return=None) ctx instrs = @ [igoto end_block_label] @ rewrite_exception (historic @ before) after | before, (I_aux (I_funcall (x, _, f, args), _) as funcall) :: after -> - let effects = match Env.get_val_spec f ctx.tc_env with + let effects = match Env.get_val_spec (fst f) ctx.tc_env with | _, Typ_aux (Typ_fn (_, _, effects), _) -> effects | exception (Type_error _) -> no_effect (* nullary union constructor, so no val spec *) | _ -> assert false (* valspec must have function type *) @@ -1242,24 +1239,11 @@ let compile_funcl ctx id pat guard exp = let instrs = fix_early_return (CL_id (return, ret_ctyp)) instrs in let instrs = fix_exception ~return:(Some ret_ctyp) ctx instrs in - if !opt_debug_flow_graphs then - begin - let instrs = Jib_optimize.(instrs |> optimize_unit |> flatten_instrs) in - let root, _, cfg = Jib_ssa.control_flow_graph instrs in - let idom = Jib_ssa.immediate_dominators cfg root in - let _, cfg = Jib_ssa.ssa instrs in - let out_chan = open_out (Util.zencode_string (string_of_id id) ^ ".gv") in - Jib_ssa.make_dot out_chan cfg; - close_out out_chan; - let out_chan = open_out (Util.zencode_string (string_of_id id) ^ ".dom.gv") in - Jib_ssa.make_dominators_dot out_chan idom cfg; - close_out out_chan; - end; - [CDEF_fundef (id, None, List.map fst compiled_args, instrs)], orig_ctx (** Compile a Sail toplevel definition into an IR definition **) let rec compile_def n total ctx def = + reset_anf_counter (); reset_gensym_counter (); match def with | DEF_fundef (FD_aux (FD_function (_, _, _, [FCL_aux (FCL_Funcl (id, _), _)]), _)) @@ -1314,7 +1298,8 @@ and compile_def' n total ctx = function in let ctx' = { ctx with local_env = add_typquant (id_loc id) quant ctx.local_env } in let arg_ctyps, ret_ctyp = List.map (ctyp_of_typ ctx') arg_typs, ctyp_of_typ ctx' ret_typ in - [CDEF_spec (id, arg_ctyps, ret_ctyp)], ctx + [CDEF_spec (id, arg_ctyps, ret_ctyp)], + { ctx with valspecs = Bindings.add id (arg_ctyps, ret_ctyp) ctx.valspecs } | DEF_fundef (FD_aux (FD_function (_, _, _, [FCL_aux (FCL_Funcl (id, Pat_aux (Pat_exp (pat, exp), _)), _)]), _)) -> Util.progress "Compiling " (string_of_id id) n total; @@ -1389,7 +1374,7 @@ and compile_def' n total ctx = function raise (Reporting.err_general Parse_ast.Unknown ("Could not compile:\n" ^ Pretty_print_sail.to_string (Pretty_print_sail.doc_def def))) let rec specialize_variants ctx prior = - let unifications = ref (Bindings.empty) in + let unifications = ref (UBindings.empty) in let fix_variant_ctyp var_id new_ctors = function | CT_variant (id, ctors) when Id.compare id var_id = 0 -> CT_variant (id, new_ctors) @@ -1403,12 +1388,14 @@ let rec specialize_variants ctx prior = (* specialize_constructor is called on all instructions when we find a constructor in a union type that is polymorphic. It's job is to record all uses of that constructor so we can monomorphise it. *) - let specialize_constructor ctx ctor_id ctyp = function - | I_aux (I_funcall (clexp, extern, id, [cval]), ((_, l) as aux)) as instr when Id.compare id ctor_id = 0 -> + let specialize_constructor ctx (ctor_id, existing_unifiers) ctyp = + assert (existing_unifiers = []); + function + | I_aux (I_funcall (clexp, extern, id, [cval]), ((_, l) as aux)) as instr when UId.compare id (ctor_id, []) = 0 -> (* Work out how each call to a constructor in instantiated and add that to unifications *) let unifiers = List.map ctyp_suprema (ctyp_unify ctyp (cval_ctyp cval)) in - let mono_id = append_id ctor_id ("_" ^ Util.string_of_list "_" (fun ctyp -> string_of_ctyp ctyp) unifiers) in - unifications := Bindings.add mono_id (ctyp_suprema (cval_ctyp cval)) !unifications; + let mono_id = (ctor_id, unifiers) in + unifications := UBindings.add mono_id (ctyp_suprema (cval_ctyp cval)) !unifications; (* We need to cast each cval to it's ctyp_suprema in order to put it in the most general constructor *) let setup, cval, cleanup = @@ -1432,13 +1419,12 @@ let rec specialize_variants ctx prior = mk_funcall (I_aux (I_funcall (clexp, extern, mono_id, [cval]), aux)) - | I_aux (I_funcall (clexp, extern, id, cvals), ((_, l) as aux)) as instr when Id.compare id ctor_id = 0 -> + | I_aux (I_funcall (clexp, extern, id, cvals), ((_, l) as aux)) as instr when UId.compare id (ctor_id, []) = 0 -> Reporting.unreachable l __POS__ "Multiple argument constructor found" - (* We have to be careful this is the only place where an F_ctor_kind can appear before calling specialize variants *) + (* We have to be careful this is the only place where an V_ctor_kind can appear before calling specialize variants *) | I_aux (I_jump (V_ctor_kind (_, id, unifiers, pat_ctyp), _), _) as instr when Id.compare id ctor_id = 0 -> - let mono_id = append_id ctor_id ("_" ^ Util.string_of_list "_" (fun ctyp -> string_of_ctyp ctyp) unifiers) in - unifications := Bindings.add mono_id (ctyp_suprema pat_ctyp) !unifications; + unifications := UBindings.add (ctor_id, unifiers) (ctyp_suprema pat_ctyp) !unifications; instr | instr -> instr @@ -1446,12 +1432,14 @@ let rec specialize_variants ctx prior = (* specialize_field performs the same job as specialize_constructor, but for struct fields rather than union constructors. *) - let specialize_field ctx field_id ctyp = function - | I_aux (I_copy (CL_field (clexp, field_str), cval), aux) when string_of_id field_id = field_str -> + let specialize_field ctx (field_id, existing_unifiers) ctyp = + assert (existing_unifiers = []); + function + | I_aux (I_copy (CL_field (clexp, field), cval), aux) when UId.compare (field_id, []) field = 0 -> let unifiers = List.map ctyp_suprema (ctyp_unify ctyp (cval_ctyp cval)) in - let mono_id = append_id field_id ("_" ^ Util.string_of_list "_" (fun ctyp -> string_of_ctyp ctyp) unifiers) in - unifications := Bindings.add mono_id (ctyp_suprema (cval_ctyp cval)) !unifications; - I_aux (I_copy (CL_field (clexp, string_of_id mono_id), cval), aux) + let mono_id = (field_id, unifiers) in + unifications := UBindings.add mono_id (ctyp_suprema (cval_ctyp cval)) !unifications; + I_aux (I_copy (CL_field (clexp, mono_id), cval), aux) | instr -> instr in @@ -1467,12 +1455,12 @@ let rec specialize_variants ctx prior = in let monomorphic_ctors = List.filter (fun (_, ctyp) -> not (is_polymorphic ctyp)) ctors in - let specialized_ctors = Bindings.bindings !unifications in + let specialized_ctors = UBindings.bindings !unifications in let new_ctors = monomorphic_ctors @ specialized_ctors in let ctx = { ctx with variants = Bindings.add var_id - (List.fold_left (fun m (id, ctyp) -> Bindings.add id ctyp m) !unifications monomorphic_ctors) + (List.fold_left (fun m (uid, ctyp) -> UBindings.add uid ctyp m) !unifications monomorphic_ctors) ctx.variants } in @@ -1491,12 +1479,12 @@ let rec specialize_variants ctx prior = in let monomorphic_fields = List.filter (fun (_, ctyp) -> not (is_polymorphic ctyp)) fields in - let specialized_fields = Bindings.bindings !unifications in + let specialized_fields = UBindings.bindings !unifications in let new_fields = monomorphic_fields @ specialized_fields in let ctx = { ctx with records = Bindings.add struct_id - (List.fold_left (fun m (id, ctyp) -> Bindings.add id ctyp m) !unifications monomorphic_fields) + (List.fold_left (fun m (uid, ctyp) -> UBindings.add uid ctyp m) !unifications monomorphic_fields) ctx.records } in diff --git a/src/jib/jib_compile.mli b/src/jib/jib_compile.mli index d4e67daa..273e9e03 100644 --- a/src/jib/jib_compile.mli +++ b/src/jib/jib_compile.mli @@ -56,10 +56,6 @@ open Ast_util open Jib open Type_check -(** Output a dataflow graph for each generated function in Graphviz - (dot) format. *) -val opt_debug_flow_graphs : bool ref - (** This forces all integer struct fields to be represented as int64_t. Specifically intended for the various TLB structs in the ARM v8.5 spec. *) @@ -73,9 +69,10 @@ val optimize_aarch64_fast_struct : bool ref well as a function that optimizes ANF expressions (which can just be the identity function) *) type ctx = - { records : (ctyp Bindings.t) Bindings.t; + { records : (ctyp Jib_util.UBindings.t) Bindings.t; enums : IdSet.t Bindings.t; - variants : (ctyp Bindings.t) Bindings.t; + variants : (ctyp Jib_util.UBindings.t) Bindings.t; + valspecs : (ctyp list * ctyp) Bindings.t; tc_env : Env.t; local_env : Env.t; locals : (mut * ctyp) Bindings.t; diff --git a/src/jib/jib_ir.ml b/src/jib/jib_ir.ml index 1449b070..df60db1c 100644 --- a/src/jib/jib_ir.ml +++ b/src/jib/jib_ir.ml @@ -70,49 +70,9 @@ let string_of_name = | Current_exception n -> "current_exception" ^ ssa_num n -let string_of_value = function - | VL_bits ([], _) -> "empty" - | VL_bits (bs, true) -> Sail2_values.show_bitlist bs - | VL_bits (bs, false) -> Sail2_values.show_bitlist (List.rev bs) - | VL_int i -> Big_int.to_string i - | VL_bool true -> "true" - | VL_bool false -> "false" - | VL_null -> "NULL" - | VL_unit -> "()" - | VL_bit Sail2_values.B0 -> "bitzero" - | VL_bit Sail2_values.B1 -> "bitone" - | VL_bit Sail2_values.BU -> "bitundef" - | VL_real str -> str - | VL_string str -> "\"" ^ str ^ "\"" - -let rec string_of_cval = function - | V_id (id, ctyp) -> string_of_name id - | V_ref (id, _) -> "&" ^ string_of_name id - | V_lit (vl, ctyp) -> string_of_value vl - | V_call (op, cvals) -> - Printf.sprintf "%s(%s)" (string_of_op op) (Util.string_of_list ", " string_of_cval cvals) - | V_field (f, field) -> - Printf.sprintf "%s.%s" (string_of_cval f) field - | V_tuple_member (f, _, n) -> - Printf.sprintf "%s.ztup%d" (string_of_cval f) n - | V_ctor_kind (f, ctor, [], _) -> - string_of_cval f ^ " is " ^ Util.zencode_string (string_of_id ctor) - | V_ctor_kind (f, ctor, unifiers, _) -> - string_of_cval f ^ " is " ^ Util.zencode_string (string_of_id ctor ^ "_" ^ Util.string_of_list "_" string_of_ctyp unifiers) - | V_ctor_unwrap (ctor, f, [], _) -> - Printf.sprintf "%s as %s" (string_of_cval f) (string_of_id ctor) - | V_ctor_unwrap (ctor, f, unifiers, _) -> - Printf.sprintf "%s as %s" - (string_of_cval f) - (Util.zencode_string (string_of_id ctor ^ "_" ^ Util.string_of_list "_" string_of_ctyp unifiers)) - | V_struct (fields, _) -> - Printf.sprintf "{%s}" - (Util.string_of_list ", " (fun (field, cval) -> zencode_id field ^ " = " ^ string_of_cval cval) fields) - | V_poly (f, _) -> string_of_cval f - let rec string_of_clexp = function | CL_id (id, ctyp) -> string_of_name id - | CL_field (clexp, field) -> string_of_clexp clexp ^ "." ^ field + | CL_field (clexp, field) -> string_of_clexp clexp ^ "." ^ string_of_uid field | CL_addr clexp -> string_of_clexp clexp ^ "*" | CL_tuple (clexp, n) -> string_of_clexp clexp ^ "." ^ string_of_int n | CL_void -> "void" @@ -159,9 +119,9 @@ module Ir_formatter = struct | I_copy (clexp, cval) -> add_instr n buf indent (string_of_clexp clexp ^ " = " ^ C.value cval) | I_funcall (clexp, false, id, args) -> - add_instr n buf indent (string_of_clexp clexp ^ " = " ^ zencode_id id ^ "(" ^ Util.string_of_list ", " C.value args ^ ")") + add_instr n buf indent (string_of_clexp clexp ^ " = " ^ string_of_uid id ^ "(" ^ Util.string_of_list ", " C.value args ^ ")") | I_funcall (clexp, true, id, args) -> - add_instr n buf indent (string_of_clexp clexp ^ " = $" ^ zencode_id id ^ "(" ^ Util.string_of_list ", " C.value args ^ ")") + add_instr n buf indent (string_of_clexp clexp ^ " = $" ^ string_of_uid id ^ "(" ^ Util.string_of_list ", " C.value args ^ ")") | I_return cval -> add_instr n buf indent (C.keyword "return" ^ " " ^ C.value cval) | I_comment str -> @@ -185,6 +145,9 @@ module Ir_formatter = struct let id_ctyp (id, ctyp) = sprintf "%s: %s" (zencode_id id) (C.typ ctyp) + let uid_ctyp (uid, ctyp) = + sprintf "%s: %s" (string_of_uid uid) (C.typ ctyp) + let output_def buf = function | CDEF_reg_dec (id, ctyp, _) -> Buffer.add_string buf (sprintf "%s %s : %s" (C.keyword "register") (zencode_id id) (C.typ ctyp)) @@ -203,9 +166,9 @@ module Ir_formatter = struct | CDEF_type (CTD_enum (id, ids)) -> Buffer.add_string buf (sprintf "%s %s {\n %s\n}" (C.keyword "enum") (zencode_id id) (Util.string_of_list ",\n " zencode_id ids)) | CDEF_type (CTD_struct (id, ids)) -> - Buffer.add_string buf (sprintf "%s %s {\n %s\n}" (C.keyword "struct") (zencode_id id) (Util.string_of_list ",\n " id_ctyp ids)) + Buffer.add_string buf (sprintf "%s %s {\n %s\n}" (C.keyword "struct") (zencode_id id) (Util.string_of_list ",\n " uid_ctyp ids)) | CDEF_type (CTD_variant (id, ids)) -> - Buffer.add_string buf (sprintf "%s %s {\n %s\n}" (C.keyword "union") (zencode_id id) (Util.string_of_list ",\n " id_ctyp ids)) + Buffer.add_string buf (sprintf "%s %s {\n %s\n}" (C.keyword "union") (zencode_id id) (Util.string_of_list ",\n " uid_ctyp ids)) | CDEF_let (_, bindings, instrs) -> let instrs = C.modify_instrs instrs in let label_map = C.make_label_map instrs in @@ -220,8 +183,7 @@ module Ir_formatter = struct output_def buf def; Buffer.add_string buf "\n\n"; output_defs buf defs - | [] -> - Buffer.add_char buf '\n' + | [] -> () end end @@ -249,6 +211,7 @@ module Flat_ir_config : Ir_formatter.Config = struct let modify_instrs instrs = let open Jib_optimize in + reset_flat_counter (); instrs |> flatten_instrs |> remove_clear diff --git a/src/jib/jib_optimize.ml b/src/jib/jib_optimize.ml index e7cb70da..323f3cd0 100644 --- a/src/jib/jib_optimize.ml +++ b/src/jib/jib_optimize.ml @@ -50,6 +50,7 @@ open Ast_util open Jib +open Jib_compile open Jib_util let optimize_unit instrs = @@ -81,6 +82,9 @@ let optimize_unit instrs = filter_instrs non_pointless_copy (map_instr_list unit_instr instrs) let flat_counter = ref 0 + +let reset_flat_counter () = flat_counter := 0 + let flat_id orig_id = let id = mk_id (string_of_name ~zencode:false orig_id ^ "_l#" ^ string_of_int !flat_counter) in incr flat_counter; @@ -295,8 +299,8 @@ let inline cdefs should_inline instrs = in let rec inline_instr = function - | I_aux (I_funcall (clexp, false, function_id, args), aux) as instr when should_inline function_id -> - begin match find_function function_id cdefs with + | I_aux (I_funcall (clexp, false, function_id, args), aux) as instr when should_inline (fst function_id) -> + begin match find_function (fst function_id) cdefs with | Some (None, ids, body) -> incr inlines; incr label_count; @@ -386,3 +390,155 @@ let rec remove_clear = function | I_aux (I_clear _, _) :: instrs -> remove_clear instrs | instr :: instrs -> instr :: remove_clear instrs | [] -> [] + +let remove_tuples cdefs ctx = + let already_removed = ref CTSet.empty in + let rec all_tuples = function + | CT_tup ctyps as ctyp -> + CTSet.add ctyp (List.fold_left CTSet.union CTSet.empty (List.map all_tuples ctyps)) + | CT_struct (_, id_ctyps) | CT_variant (_, id_ctyps) -> + List.fold_left (fun cts (_, ctyp) -> CTSet.union (all_tuples ctyp) cts) CTSet.empty id_ctyps + | CT_list ctyp | CT_vector (_, ctyp) | CT_ref ctyp -> + all_tuples ctyp + | CT_lint | CT_fint _ | CT_lbits _ | CT_sbits _ | CT_fbits _ | CT_constant _ + | CT_unit | CT_bool | CT_real | CT_bit | CT_poly | CT_string | CT_enum _ -> + CTSet.empty + in + let rec tuple_depth = function + | CT_tup ctyps as ctyp -> + 1 + List.fold_left (fun d ctyp -> max d (tuple_depth ctyp)) 0 ctyps + | CT_struct (_, id_ctyps) | CT_variant (_, id_ctyps) -> + List.fold_left (fun d (_, ctyp) -> max (tuple_depth ctyp) d) 0 id_ctyps + | CT_list ctyp | CT_vector (_, ctyp) | CT_ref ctyp -> + tuple_depth ctyp + | CT_lint | CT_fint _ | CT_lbits _ | CT_sbits _ | CT_fbits _ | CT_constant _ + | CT_unit | CT_bool | CT_real | CT_bit | CT_poly | CT_string | CT_enum _ -> + 0 + in + let rec fix_tuples = function + | CT_tup ctyps -> + let ctyps = List.map fix_tuples ctyps in + let name = "tuple#" ^ Util.string_of_list "_" string_of_ctyp ctyps in + CT_struct (mk_id name, List.mapi (fun n ctyp -> (mk_id (name ^ string_of_int n), []), ctyp) ctyps) + | CT_struct (id, id_ctyps) -> + CT_struct (id, List.map (fun (id, ctyp) -> id, fix_tuples ctyp) id_ctyps) + | CT_variant (id, id_ctyps) -> + CT_variant (id, List.map (fun (id, ctyp) -> id, fix_tuples ctyp) id_ctyps) + | CT_list ctyp -> CT_list (fix_tuples ctyp) + | CT_vector (d, ctyp) -> CT_vector (d, fix_tuples ctyp) + | CT_ref ctyp -> CT_ref (fix_tuples ctyp) + | (CT_lint | CT_fint _ | CT_lbits _ | CT_sbits _ | CT_fbits _ | CT_constant _ + | CT_unit | CT_bool | CT_real | CT_bit | CT_poly | CT_string | CT_enum _) as ctyp -> + ctyp + in + let rec fix_cval = function + | V_id (id, ctyp) -> V_id (id, ctyp) + | V_ref (id, ctyp) -> V_ref (id, ctyp) + | V_lit (vl, ctyp) -> V_lit (vl, ctyp) + | V_ctor_kind (cval, id, unifiers, ctyp) -> + V_ctor_kind (fix_cval cval, id, unifiers, ctyp) + | V_ctor_unwrap (id, cval, unifiers, ctyp) -> + V_ctor_unwrap (id, fix_cval cval, unifiers, ctyp) + | V_tuple_member (cval, _, n) -> + let ctyp = fix_tuples (cval_ctyp cval) in + let cval = fix_cval cval in + let field = match ctyp with + | CT_struct (id, _) -> + mk_id (string_of_id id ^ string_of_int n) + | _ -> assert false + in + V_field (cval, (field, [])) + | V_call (op, cvals) -> + V_call (op, List.map (fix_cval) cvals) + | V_field (cval, field) -> + V_field (fix_cval cval, field) + | V_struct (fields, ctyp) -> V_struct (List.map (fun (id, cval) -> id, fix_cval cval) fields, ctyp) + | V_poly (cval, ctyp) -> V_poly (fix_cval cval, ctyp) + in + let rec fix_clexp = function + | CL_id (id, ctyp) -> CL_id (id, ctyp) + | CL_addr clexp -> CL_addr (fix_clexp clexp) + | CL_tuple (clexp, n) -> + let ctyp = fix_tuples (clexp_ctyp clexp) in + let clexp = fix_clexp clexp in + let field = match ctyp with + | CT_struct (id, _) -> + mk_id (string_of_id id ^ string_of_int n) + | _ -> assert false + in + CL_field (clexp, (field, [])) + | CL_field (clexp, field) -> CL_field (fix_clexp clexp, field) + | CL_void -> CL_void + | CL_rmw (read, write, ctyp) -> CL_rmw (read, write, ctyp) + in + let rec fix_instr_aux = function + | I_funcall (clexp, extern, id, args) -> + I_funcall (fix_clexp clexp, extern, id, List.map fix_cval args) + | I_copy (clexp, cval) -> I_copy (fix_clexp clexp, fix_cval cval) + | I_init (ctyp, id, cval) -> I_init (ctyp, id, fix_cval cval) + | I_reinit (ctyp, id, cval) -> I_reinit (ctyp, id, fix_cval cval) + | I_jump (cval, label) -> I_jump (fix_cval cval, label) + | I_throw cval -> I_throw (fix_cval cval) + | I_return cval -> I_return (fix_cval cval) + | I_if (cval, then_instrs, else_instrs, ctyp) -> + I_if (fix_cval cval, List.map fix_instr then_instrs, List.map fix_instr else_instrs, ctyp) + | I_block instrs -> I_block (List.map fix_instr instrs) + | I_try_block instrs -> I_try_block (List.map fix_instr instrs) + | (I_goto _ | I_label _ | I_decl _ | I_clear _ | I_end _ | I_comment _ + | I_reset _ | I_undefined _ | I_match_failure | I_raw _) as instr -> instr + and fix_instr (I_aux (instr, aux)) = I_aux (fix_instr_aux instr, aux) + in + let fix_conversions = function + | I_aux (I_copy (clexp, cval), ((_, l) as aux)) as instr -> + begin match clexp_ctyp clexp, cval_ctyp cval with + | CT_tup lhs_ctyps, CT_tup rhs_ctyps when List.length lhs_ctyps = List.length rhs_ctyps -> + let elems = List.length lhs_ctyps in + if List.for_all2 ctyp_equal lhs_ctyps rhs_ctyps then + [instr] + else + List.mapi (fun n _ -> icopy l (CL_tuple (clexp, n)) (V_tuple_member (cval, elems, n))) lhs_ctyps + | _ -> [instr] + end + | instr -> [instr] + in + let fix_ctx ctx = + { ctx with + records = Bindings.map (UBindings.map fix_tuples) ctx.records; + variants = Bindings.map (UBindings.map fix_tuples) ctx.variants; + valspecs = Bindings.map (fun (ctyps, ctyp) -> List.map fix_tuples ctyps, fix_tuples ctyp) ctx.valspecs; + locals = Bindings.map (fun (mut, ctyp) -> mut, fix_tuples ctyp) ctx.locals + } + in + let to_struct = function + | CT_tup ctyps -> + let ctyps = List.map fix_tuples ctyps in + let name = "tuple#" ^ Util.string_of_list "_" string_of_ctyp ctyps in + CDEF_type (CTD_struct (mk_id name, List.mapi (fun n ctyp -> (mk_id (name ^ string_of_int n), []), ctyp) ctyps)) + | _ -> assert false + in + let rec go acc = function + | cdef :: cdefs -> + let tuples = CTSet.fold (fun ctyp -> CTSet.union (all_tuples ctyp)) (cdef_ctyps cdef) CTSet.empty in + let tuples = CTSet.diff tuples !already_removed in + (* In the case where we have ((x, y), z) and (x, y) we need to + generate (x, y) first, so we sort by the depth of nesting in + the tuples (note we build acc in reverse order) *) + let sorted_tuples = + CTSet.elements tuples + |> List.map (fun ctyp -> tuple_depth ctyp, ctyp) + |> List.sort (fun (d1, _) (d2, _) -> compare d2 d1) + |> List.map snd + in + let structs = List.map to_struct sorted_tuples in + already_removed := CTSet.union tuples !already_removed; + let cdef = + cdef + |> cdef_concatmap_instr fix_conversions + |> cdef_map_instr fix_instr + |> cdef_map_ctyp fix_tuples + in + go (cdef :: structs @ acc) cdefs + | [] -> List.rev acc + in + go [] cdefs, + fix_ctx ctx diff --git a/src/jib/jib_optimize.mli b/src/jib/jib_optimize.mli index a69b45b7..7dae53a9 100644 --- a/src/jib/jib_optimize.mli +++ b/src/jib/jib_optimize.mli @@ -60,6 +60,7 @@ val optimize_unit : instr list -> instr list instructions, prodcing a flat list of instructions. *) val flatten_instrs : instr list -> instr list val flatten_cdef : cdef -> cdef +val reset_flat_counter : unit -> unit val unique_per_function_ids : cdef list -> cdef list @@ -75,3 +76,5 @@ val remove_unused_labels : instr list -> instr list val remove_dead_after_goto : instr list -> instr list val remove_dead_code : instr list -> instr list + +val remove_tuples : cdef list -> Jib_compile.ctx -> cdef list * Jib_compile.ctx diff --git a/src/jib/jib_smt.ml b/src/jib/jib_smt.ml index 7a827ece..e9ec8260 100644 --- a/src/jib/jib_smt.ml +++ b/src/jib/jib_smt.ml @@ -62,7 +62,9 @@ module IntMap = Map.Make(struct type t = int let compare = compare end) let zencode_upper_id id = Util.zencode_upper_string (string_of_id id) let zencode_id id = Util.zencode_string (string_of_id id) let zencode_name id = string_of_name ~deref_current_exception:false ~zencode:true id - +let zencode_uid (id, ctyps) = + Util.zencode_string (string_of_id id ^ "#" ^ Util.string_of_list "_" string_of_ctyp ctyps) + let opt_ignore_overflow = ref false let opt_auto = ref false @@ -158,9 +160,9 @@ let rec smt_ctyp ctx = function | CT_enum (id, elems) -> mk_enum (zencode_upper_id id) (List.map zencode_id elems) | CT_struct (id, fields) -> - mk_record (zencode_upper_id id) (List.map (fun (id, ctyp) -> (zencode_id id, smt_ctyp ctx ctyp)) fields) + mk_record (zencode_upper_id id) (List.map (fun (uid, ctyp) -> (zencode_uid uid, smt_ctyp ctx ctyp)) fields) | CT_variant (id, ctors) -> - mk_variant (zencode_upper_id id) (List.map (fun (id, ctyp) -> (zencode_id id, smt_ctyp ctx ctyp)) ctors) + mk_variant (zencode_upper_id id) (List.map (fun (uid, ctyp) -> (zencode_uid uid, smt_ctyp ctx ctyp)) ctors) | CT_tup ctyps -> ctx.tuple_sizes := IntSet.add (List.length ctyps) !(ctx.tuple_sizes); Tuple (List.map (smt_ctyp ctx) ctyps) @@ -293,14 +295,14 @@ let rec smt_cval ctx cval = | V_field (union, field) -> begin match cval_ctyp union with | CT_struct (struct_id, _) -> - Fn (zencode_upper_id struct_id ^ "_" ^ field, [smt_cval ctx union]) + Fn (zencode_upper_id struct_id ^ "_" ^ zencode_uid field, [smt_cval ctx union]) | _ -> failwith "Field for non-struct type" end | V_struct (fields, ctyp) -> begin match ctyp with | CT_struct (struct_id, field_ctyps) -> let set_field (field, cval) = - match Util.assoc_compare_opt Id.compare field field_ctyps with + match Util.assoc_compare_opt UId.compare field field_ctyps with | None -> failwith "Field type not found" | Some ctyp when ctyp_equal (cval_ctyp cval) ctyp -> smt_cval ctx cval @@ -324,7 +326,7 @@ let rec smt_cval ctx cval = end | _ -> assert false end - | cval -> failwith ("Unrecognised cval " ^ Jib_ir.string_of_cval cval) + | cval -> failwith ("Unrecognised cval " ^ string_of_cval cval) let add_event ctx ev smt = let stack = event_stack ctx ev in @@ -1184,12 +1186,12 @@ let smt_ctype_def ctx = function | CTD_struct (id, fields) -> [declare_datatypes (mk_record (zencode_upper_id id) - (List.map (fun (field, ctyp) -> zencode_upper_id id ^ "_" ^ zencode_id field, smt_ctyp ctx ctyp) fields))] + (List.map (fun (field, ctyp) -> zencode_upper_id id ^ "_" ^ zencode_uid field, smt_ctyp ctx ctyp) fields))] | CTD_variant (id, ctors) -> [declare_datatypes (mk_variant (zencode_upper_id id) - (List.map (fun (ctor, ctyp) -> zencode_id ctor, smt_ctyp ctx ctyp) ctors))] + (List.map (fun (ctor, ctyp) -> zencode_uid ctor, smt_ctyp ctx ctyp) ctors))] let rec generate_ctype_defs ctx = function | CDEF_type ctd :: cdefs -> smt_ctype_def ctx ctd :: generate_ctype_defs ctx cdefs @@ -1274,8 +1276,8 @@ let rec ctyp_of_typ ctx typ = | Typ_app (id, [A_aux (A_typ typ, _)]) when string_of_id id = "register" -> CT_ref (ctyp_of_typ ctx typ) - | Typ_id id | Typ_app (id, _) when Bindings.mem id ctx.records -> CT_struct (id, Bindings.find id ctx.records |> Bindings.bindings) - | Typ_id id | Typ_app (id, _) when Bindings.mem id ctx.variants -> CT_variant (id, Bindings.find id ctx.variants |> Bindings.bindings) + | Typ_id id | Typ_app (id, _) when Bindings.mem id ctx.records -> CT_struct (id, Bindings.find id ctx.records |> UBindings.bindings) + | Typ_id id | Typ_app (id, _) when Bindings.mem id ctx.variants -> CT_variant (id, Bindings.find id ctx.variants |> UBindings.bindings) | Typ_id id when Bindings.mem id ctx.enums -> CT_enum (id, Bindings.find id ctx.enums |> IdSet.elements) | Typ_tup typs -> CT_tup (List.map (ctyp_of_typ ctx) typs) @@ -1524,10 +1526,10 @@ let rmw_modify smt = function begin match ctyp with | CT_struct (struct_id, fields) -> let set_field (field', _) = - if Util.zencode_string field = zencode_id field' then + if UId.compare field field' = 0 then smt else - Fn (zencode_upper_id struct_id ^ "_" ^ zencode_id field', [Var (rmw_read clexp)]) + Fn (zencode_upper_id struct_id ^ "_" ^ zencode_uid field', [Var (rmw_read clexp)]) in Fn (zencode_upper_id struct_id, List.map set_field fields) | _ -> @@ -1557,8 +1559,8 @@ let smt_instr ctx = let open Type_check in function | I_aux (I_funcall (CL_id (id, ret_ctyp), extern, function_id, args), (_, l)) -> - if Env.is_extern function_id ctx.tc_env "c" && not extern then - let name = Env.get_extern function_id ctx.tc_env "c" in + if Env.is_extern (fst function_id) ctx.tc_env "c" && not extern then + let name = Env.get_extern (fst function_id) ctx.tc_env "c" in if name = "sqrt_real" then begin match args with | [v] -> builtin_sqrt_real ctx (zencode_name id) v @@ -1609,9 +1611,9 @@ let smt_instr ctx = else let value = smt_builtin ctx name args ret_ctyp in [define_const ctx id ret_ctyp value] - else if extern && string_of_id function_id = "internal_vector_init" then + else if extern && string_of_id (fst function_id) = "internal_vector_init" then [declare_const ctx id ret_ctyp] - else if extern && string_of_id function_id = "internal_vector_update" then + else if extern && string_of_id (fst function_id) = "internal_vector_update" then begin match args with | [vec; i; x] -> let sz = int_size ctx (cval_ctyp i) in @@ -1620,7 +1622,7 @@ let smt_instr ctx = | _ -> Reporting.unreachable l __POS__ "Bad arguments for internal_vector_update" end - else if string_of_id function_id = "sail_assert" then + else if string_of_id (fst function_id) = "sail_assert" then begin match args with | [assertion; _] -> let smt = smt_cval ctx assertion in @@ -1629,7 +1631,7 @@ let smt_instr ctx = | _ -> Reporting.unreachable l __POS__ "Bad arguments for assertion" end - else if string_of_id function_id = "sail_assume" then + else if string_of_id (fst function_id) = "sail_assume" then begin match args with | [assumption] -> let smt = smt_cval ctx assumption in @@ -1640,9 +1642,9 @@ let smt_instr ctx = end else if not extern then let smt_args = List.map (smt_cval ctx) args in - [define_const ctx id ret_ctyp (Ctor (zencode_id function_id, smt_args))] + [define_const ctx id ret_ctyp (Ctor (zencode_id (fst function_id), smt_args))] else - failwith ("Unrecognised function " ^ string_of_id function_id) + failwith ("Unrecognised function " ^ string_of_id (fst function_id)) | I_aux (I_copy (CL_addr (CL_id (_, _)), _), (_, l)) -> Reporting.unreachable l __POS__ "Register reference write should be re-written by now" @@ -1850,7 +1852,7 @@ let smt_header ctx cdefs = let expand_reg_deref env register_map = function | I_aux (I_funcall (clexp, false, function_id, [reg_ref]), (_, l)) as instr -> let open Type_check in - begin match (if Env.is_extern function_id env "smt" then Some (Env.get_extern function_id env "smt") else None) with + begin match (if Env.is_extern (fst function_id) env "smt" then Some (Env.get_extern (fst function_id) env "smt") else None) with | Some "reg_deref" -> begin match cval_ctyp reg_ref with | CT_ref reg_ctyp -> diff --git a/src/jib/jib_smt_fuzz.ml b/src/jib/jib_smt_fuzz.ml index 28ec40b9..846d0178 100644 --- a/src/jib/jib_smt_fuzz.ml +++ b/src/jib/jib_smt_fuzz.ml @@ -187,7 +187,7 @@ let fuzz_cdef ctx all_cdefs = function let jib = let gs = ngensym () in - [ifuncall (CL_id (gs, ret_ctyp)) id (List.map snd values)] + [ifuncall (CL_id (gs, ret_ctyp)) (id, []) (List.map snd values)] @ gen_assertion ret_ctyp result gs @ [iend ()] in diff --git a/src/jib/jib_ssa.ml b/src/jib/jib_ssa.ml index 840dea97..ba4337b0 100644 --- a/src/jib/jib_ssa.ml +++ b/src/jib/jib_ssa.ml @@ -722,7 +722,7 @@ let string_of_ssainstr = function | Phi (id, ctyp, args) -> string_of_name id ^ " : " ^ string_of_ctyp ctyp ^ " = φ(" ^ Util.string_of_list ", " string_of_name args ^ ")" | Pi cvals -> - "π(" ^ Util.string_of_list ", " (fun v -> String.escaped (Jib_ir.string_of_cval v)) cvals ^ ")" + "π(" ^ Util.string_of_list ", " (fun v -> String.escaped (string_of_cval v)) cvals ^ ")" let string_of_phis = function | [] -> "" diff --git a/src/jib/jib_util.ml b/src/jib/jib_util.ml index a1dac297..caef7ecb 100644 --- a/src/jib/jib_util.ml +++ b/src/jib/jib_util.ml @@ -257,20 +257,6 @@ let rec instr_rename from_id to_id (I_aux (instr, aux)) = (* 1. Instruction pretty printer *) (**************************************************************************) -let string_of_value = function - | VL_bits ([], _) -> "UINT64_C(0)" - | VL_bits (bs, true) -> "UINT64_C(" ^ Sail2_values.show_bitlist bs ^ ")" - | VL_bits (bs, false) -> "UINT64_C(" ^ Sail2_values.show_bitlist (List.rev bs) ^ ")" - | VL_int i -> Big_int.to_string i ^ "l" - | VL_bool true -> "true" - | VL_bool false -> "false" - | VL_null -> "NULL" - | VL_unit -> "UNIT" - | VL_bit Sail2_values.B0 -> "UINT64_C(0)" - | VL_bit Sail2_values.B1 -> "UINT64_C(1)" - | VL_bit Sail2_values.BU -> failwith "Undefined bit found in value" - | VL_real str -> str - | VL_string str -> "\"" ^ str ^ "\"" let string_of_name ?deref_current_exception:(dce=true) ?zencode:(zencode=true) = let ssa_num n = if n = -1 then "" else ("/" ^ string_of_int n) in @@ -341,6 +327,11 @@ let rec string_of_ctyp = function | CT_ref ctyp -> "&(" ^ string_of_ctyp ctyp ^ ")" | CT_poly -> "*" +and string_of_uid (id, ctyps) = + match ctyps with + | [] -> Util.zencode_string (string_of_id id) + | _ -> Util.zencode_string (string_of_id id ^ "#" ^ Util.string_of_list "_" string_of_ctyp ctyps) + (** This function is like string_of_ctyp, but recursively prints all constructors in variants and structs. Used for debug output. *) and full_string_of_ctyp = function @@ -348,12 +339,12 @@ and full_string_of_ctyp = function | CT_struct (id, ctors) -> "struct " ^ string_of_id id ^ "{" - ^ Util.string_of_list ", " (fun (id, ctyp) -> string_of_id id ^ " : " ^ full_string_of_ctyp ctyp) ctors + ^ Util.string_of_list ", " (fun ((id, _), ctyp) -> string_of_id id ^ " : " ^ full_string_of_ctyp ctyp) ctors ^ "}" | CT_variant (id, ctors) -> "union " ^ string_of_id id ^ "{" - ^ Util.string_of_list ", " (fun (id, ctyp) -> string_of_id id ^ " : " ^ full_string_of_ctyp ctyp) ctors + ^ Util.string_of_list ", " (fun ((id, _), ctyp) -> string_of_id id ^ " : " ^ full_string_of_ctyp ctyp) ctors ^ "}" | CT_vector (true, ctyp) -> "vector(dec, " ^ full_string_of_ctyp ctyp ^ ")" | CT_vector (false, ctyp) -> "vector(inc, " ^ full_string_of_ctyp ctyp ^ ")" @@ -361,6 +352,46 @@ and full_string_of_ctyp = function | CT_ref ctyp -> "ref(" ^ full_string_of_ctyp ctyp ^ ")" | ctyp -> string_of_ctyp ctyp +let string_of_value = function + | VL_bits ([], _) -> "empty" + | VL_bits (bs, true) -> Sail2_values.show_bitlist bs + | VL_bits (bs, false) -> Sail2_values.show_bitlist (List.rev bs) + | VL_int i -> Big_int.to_string i + | VL_bool true -> "true" + | VL_bool false -> "false" + | VL_null -> "NULL" + | VL_unit -> "()" + | VL_bit Sail2_values.B0 -> "bitzero" + | VL_bit Sail2_values.B1 -> "bitone" + | VL_bit Sail2_values.BU -> "bitundef" + | VL_real str -> str + | VL_string str -> "\"" ^ str ^ "\"" + +let rec string_of_cval = function + | V_id (id, ctyp) -> string_of_name id + | V_ref (id, _) -> "&" ^ string_of_name id + | V_lit (vl, ctyp) -> string_of_value vl + | V_call (op, cvals) -> + Printf.sprintf "%s(%s)" (string_of_op op) (Util.string_of_list ", " string_of_cval cvals) + | V_field (f, field) -> + Printf.sprintf "%s.%s" (string_of_cval f) (string_of_uid field) + | V_tuple_member (f, _, n) -> + Printf.sprintf "%s.ztup%d" (string_of_cval f) n + | V_ctor_kind (f, ctor, [], _) -> + string_of_cval f ^ " is " ^ Util.zencode_string (string_of_id ctor) + | V_ctor_kind (f, ctor, unifiers, _) -> + string_of_cval f ^ " is " ^ Util.zencode_string (string_of_id ctor ^ "_" ^ Util.string_of_list "_" string_of_ctyp unifiers) + | V_ctor_unwrap (ctor, f, [], _) -> + Printf.sprintf "%s as %s" (string_of_cval f) (string_of_id ctor) + | V_ctor_unwrap (ctor, f, unifiers, _) -> + Printf.sprintf "%s as %s" + (string_of_cval f) + (Util.zencode_string (string_of_id ctor ^ "_" ^ Util.string_of_list "_" string_of_ctyp unifiers)) + | V_struct (fields, _) -> + Printf.sprintf "{%s}" + (Util.string_of_list ", " (fun (field, cval) -> string_of_uid field ^ " = " ^ string_of_cval cval) fields) + | V_poly (f, _) -> string_of_cval f + let rec map_ctyp f = function | (CT_lint | CT_fint _ | CT_constant _ | CT_lbits _ | CT_fbits _ | CT_sbits _ | CT_bit | CT_unit | CT_bool | CT_real | CT_string | CT_poly | CT_enum _) as ctyp -> f ctyp @@ -368,8 +399,10 @@ let rec map_ctyp f = function | CT_ref ctyp -> f (CT_ref (map_ctyp f ctyp)) | CT_vector (direction, ctyp) -> f (CT_vector (direction, map_ctyp f ctyp)) | CT_list ctyp -> f (CT_list (map_ctyp f ctyp)) - | CT_struct (id, ctors) -> f (CT_struct (id, List.map (fun (id, ctyp) -> id, map_ctyp f ctyp) ctors)) - | CT_variant (id, ctors) -> f (CT_variant (id, List.map (fun (id, ctyp) -> id, map_ctyp f ctyp) ctors)) + | CT_struct (id, ctors) -> + f (CT_struct (id, List.map (fun ((id, ctyps), ctyp) -> (id, List.map (map_ctyp f) ctyps), map_ctyp f ctyp) ctors)) + | CT_variant (id, ctors) -> + f (CT_variant (id, List.map (fun ((id, ctyps), ctyp) -> (id, List.map (map_ctyp f) ctyps), map_ctyp f ctyp) ctors)) let rec ctyp_equal ctyp1 ctyp2 = match ctyp1, ctyp2 with @@ -469,6 +502,23 @@ end module CTSet = Set.Make(CT) module CTMap = Map.Make(CT) +module UId = struct + type t = uid + let lex_ord c1 c2 = if c1 = 0 then c2 else c1 + let rec compare_ctyps ctyps1 ctyps2 = + match ctyps1, ctyps2 with + | (ctyp1 :: ctyps1), (ctyp2 :: ctyps2) -> + lex_ord (CT.compare ctyp1 ctyp2) (compare_ctyps ctyps1 ctyps2) + | [], [] -> 0 + | [], _ -> 1 + | _, [] -> -1 + let compare (id1, ctyps1) (id2, ctyps2) = + let lex_ord c1 c2 = if c1 = 0 then c2 else c1 in + lex_ord (Id.compare id1 id2) (compare_ctyps ctyps1 ctyps2) +end + +module UBindings = Map.Make(UId) + let rec ctyp_unify ctyp1 ctyp2 = match ctyp1, ctyp2 with | CT_tup ctyps1, CT_tup ctyps2 when List.length ctyps1 = List.length ctyps2 -> @@ -609,7 +659,7 @@ let instr_typed_writes (I_aux (aux, _)) = let rec map_clexp_ctyp f = function | CL_id (id, ctyp) -> CL_id (id, f ctyp) | CL_rmw (read, write, ctyp) -> CL_rmw (read, write, f ctyp) - | CL_field (clexp, field) -> CL_field (map_clexp_ctyp f clexp, field) + | CL_field (clexp, (id, ctyps)) -> CL_field (map_clexp_ctyp f clexp, (id, List.map f ctyps)) | CL_tuple (clexp, n) -> CL_tuple (map_clexp_ctyp f clexp, n) | CL_addr clexp -> CL_addr (map_clexp_ctyp f clexp) | CL_void -> CL_void @@ -626,9 +676,10 @@ let rec map_cval_ctyp f = function V_tuple_member (map_cval_ctyp f cval, i, j) | V_call (op, cvals) -> V_call (op, List.map (map_cval_ctyp f) cvals) - | V_field (cval, field) -> - V_field (map_cval_ctyp f cval, field) - | V_struct (fields, ctyp) -> V_struct (List.map (fun (id, cval) -> id, map_cval_ctyp f cval) fields, f ctyp) + | V_field (cval, (id, ctyps)) -> + V_field (map_cval_ctyp f cval, (id, List.map f ctyps)) + | V_struct (fields, ctyp) -> + V_struct (List.map (fun ((id, ctyps), cval) -> (id, List.map f ctyps), map_cval_ctyp f cval) fields, f ctyp) | V_poly (cval, ctyp) -> V_poly (map_cval_ctyp f cval, f ctyp) let rec map_instr_ctyp f (I_aux (instr, aux)) = @@ -638,8 +689,8 @@ let rec map_instr_ctyp f (I_aux (instr, aux)) = | I_if (cval, then_instrs, else_instrs, ctyp) -> I_if (map_cval_ctyp f cval, List.map (map_instr_ctyp f) then_instrs, List.map (map_instr_ctyp f) else_instrs, f ctyp) | I_jump (cval, label) -> I_jump (map_cval_ctyp f cval, label) - | I_funcall (clexp, extern, id, cvals) -> - I_funcall (map_clexp_ctyp f clexp, extern, id, List.map (map_cval_ctyp f) cvals) + | I_funcall (clexp, extern, (id, ctyps), cvals) -> + I_funcall (map_clexp_ctyp f clexp, extern, (id, List.map f ctyps), List.map (map_cval_ctyp f) cvals) | I_copy (clexp, cval) -> I_copy (map_clexp_ctyp f clexp, map_cval_ctyp f cval) | I_clear (ctyp, id) -> I_clear (f ctyp, id) | I_return cval -> I_return (map_cval_ctyp f cval) @@ -669,6 +720,21 @@ let rec map_instr f (I_aux (instr, aux)) = in f (I_aux (instr, aux)) +(** Map over each instruction within an instruction, bottom-up *) +let rec concatmap_instr f (I_aux (instr, aux)) = + let instr = match instr with + | I_decl _ | I_init _ | I_reset _ | I_reinit _ + | I_funcall _ | I_copy _ | I_clear _ | I_jump _ | I_throw _ | I_return _ + | I_comment _ | I_label _ | I_goto _ | I_raw _ | I_match_failure | I_undefined _ | I_end _ -> instr + | I_if (cval, instrs1, instrs2, ctyp) -> + I_if (cval, List.concat (List.map (concatmap_instr f) instrs1), List.concat (List.map (concatmap_instr f) instrs2), ctyp) + | I_block instrs -> + I_block (List.concat (List.map (concatmap_instr f) instrs)) + | I_try_block instrs -> + I_try_block (List.concat (List.map (concatmap_instr f) instrs)) + in + f (I_aux (instr, aux)) + (** Iterate over each instruction within an instruction, bottom-up *) let rec iter_instr f (I_aux (instr, aux)) = match instr with @@ -691,10 +757,25 @@ let cdef_map_instr f = function | CDEF_spec (id, ctyps, ctyp) -> CDEF_spec (id, ctyps, ctyp) | CDEF_type tdef -> CDEF_type tdef +(** Map over each instruction in a cdef using concatmap_instr *) +let cdef_concatmap_instr f = function + | CDEF_reg_dec (id, ctyp, instrs) -> + CDEF_reg_dec (id, ctyp, List.concat (List.map (concatmap_instr f) instrs)) + | CDEF_let (n, bindings, instrs) -> + CDEF_let (n, bindings, List.concat (List.map (concatmap_instr f) instrs)) + | CDEF_fundef (id, heap_return, args, instrs) -> + CDEF_fundef (id, heap_return, args, List.concat (List.map (concatmap_instr f) instrs)) + | CDEF_startup (id, instrs) -> + CDEF_startup (id, List.concat (List.map (concatmap_instr f) instrs)) + | CDEF_finish (id, instrs) -> + CDEF_finish (id, List.concat (List.map (concatmap_instr f) instrs)) + | CDEF_spec (id, ctyps, ctyp) -> CDEF_spec (id, ctyps, ctyp) + | CDEF_type tdef -> CDEF_type tdef + let ctype_def_map_ctyp f = function | CTD_enum (id, ids) -> CTD_enum (id, ids) - | CTD_struct (id, ctors) -> CTD_struct (id, List.map (fun (field, ctyp) -> (field, f ctyp)) ctors) - | CTD_variant (id, ctors) -> CTD_variant (id, List.map (fun (field, ctyp) -> (field, f ctyp)) ctors) + | CTD_struct (id, ctors) -> CTD_struct (id, List.map (fun ((id, ctyps), ctyp) -> ((id, List.map f ctyps), f ctyp)) ctors) + | CTD_variant (id, ctors) -> CTD_variant (id, List.map (fun ((id, ctyps), ctyp) -> ((id, List.map f ctyps), f ctyp)) ctors) (** Map over each ctyp in a cdef using map_instr_ctyp *) let cdef_map_ctyp f = function @@ -834,8 +915,8 @@ and cval_ctyp = function begin match cval_ctyp cval with | CT_struct (id, ctors) -> begin - try snd (List.find (fun (id, ctyp) -> Util.zencode_string (string_of_id id) = field) ctors) with - | Not_found -> failwith ("Struct type " ^ string_of_id id ^ " does not have a constructor " ^ field) + try snd (List.find (fun (uid, ctyp) -> UId.compare uid field = 0) ctors) with + | Not_found -> failwith ("Struct type " ^ string_of_id id ^ " does not have a constructor " ^ string_of_uid field) end | ctyp -> Reporting.unreachable Parse_ast.Unknown __POS__ ("Inavlid type for V_field " ^ full_string_of_ctyp ctyp) end @@ -849,8 +930,8 @@ let rec clexp_ctyp = function begin match clexp_ctyp clexp with | CT_struct (id, ctors) -> begin - try snd (List.find (fun (id, ctyp) -> string_of_id id = field) ctors) with - | Not_found -> failwith ("Struct type " ^ string_of_id id ^ " does not have a constructor " ^ field) + try snd (List.find (fun (uid, ctyp) -> UId.compare uid field = 0) ctors) with + | Not_found -> failwith ("Struct type " ^ string_of_id id ^ " does not have a constructor " ^ string_of_uid field) end | ctyp -> failwith ("Bad ctyp for CL_field " ^ string_of_ctyp ctyp) end @@ -880,8 +961,9 @@ let rec instr_ctyps (I_aux (instr, aux)) = CTSet.union (instrs_ctyps instrs1) (instrs_ctyps instrs2) |> CTSet.add (cval_ctyp cval) |> CTSet.add ctyp - | I_funcall (clexp, _, _, cvals) -> + | I_funcall (clexp, _, (_, ctyps), cvals) -> List.fold_left (fun m ctyp -> CTSet.add ctyp m) CTSet.empty (List.map cval_ctyp cvals) + |> CTSet.union (CTSet.of_list ctyps) |> CTSet.add (clexp_ctyp clexp) | I_copy (clexp, cval) -> CTSet.add (clexp_ctyp clexp) (CTSet.singleton (cval_ctyp cval)) diff --git a/src/sail.ml b/src/sail.ml index da711e8d..e792e652 100644 --- a/src/sail.ml +++ b/src/sail.ml @@ -349,9 +349,6 @@ let options = Arg.align ([ ( "-ddump_rewrite_ast", Arg.String (fun l -> opt_ddump_rewrite_ast := Some (l, 0); Specialize.opt_ddump_spec_ast := Some (l, 0)), "<prefix> (debug) dump the ast after each rewriting step to <prefix>_<i>.lem"); - ( "-ddump_flow_graphs", - Arg.Set Jib_compile.opt_debug_flow_graphs, - " (debug) dump flow analysis for Sail functions when compiling to C"); ( "-ddump_smt_graphs", Arg.Set Jib_smt.opt_debug_graphs, " (debug) dump flow analysis for properties when generating SMT"); |
