diff options
Diffstat (limited to 'src/jib')
| -rw-r--r-- | src/jib/jib_compile.ml | 25 | ||||
| -rw-r--r-- | src/jib/jib_compile.mli | 6 | ||||
| -rw-r--r-- | src/jib/jib_optimize.ml | 17 | ||||
| -rw-r--r-- | src/jib/jib_smt.ml | 711 | ||||
| -rw-r--r-- | src/jib/jib_ssa.ml | 107 | ||||
| -rw-r--r-- | src/jib/jib_ssa.mli | 9 | ||||
| -rw-r--r-- | src/jib/jib_util.ml | 46 |
7 files changed, 873 insertions, 48 deletions
diff --git a/src/jib/jib_compile.ml b/src/jib/jib_compile.ml index 4a72ffff..13e7334a 100644 --- a/src/jib/jib_compile.ml +++ b/src/jib/jib_compile.ml @@ -151,7 +151,9 @@ type ctx = letbinds : int list; no_raw : bool; convert_typ : ctx -> typ -> ctyp; - optimize_anf : ctx -> typ aexp -> typ aexp + optimize_anf : ctx -> typ aexp -> typ aexp; + specialize_calls : bool; + ignore_64 : bool } let initial_ctx ~convert_typ:convert_typ ~optimize_anf:optimize_anf env = @@ -164,7 +166,9 @@ let initial_ctx ~convert_typ:convert_typ ~optimize_anf:optimize_anf env = letbinds = []; no_raw = false; convert_typ = convert_typ; - optimize_anf = optimize_anf + optimize_anf = optimize_anf; + specialize_calls = false; + ignore_64 = false } let ctyp_of_typ ctx typ = ctx.convert_typ ctx typ @@ -197,6 +201,9 @@ let rec compile_aval l ctx = function | AV_lit (L_aux (L_string str, _), typ) -> [], (F_lit (V_string (String.escaped str)), ctyp_of_typ ctx typ), [] + | AV_lit (L_aux (L_num n, _), typ) when ctx.ignore_64 -> + [], (F_lit (V_int n), ctyp_of_typ ctx typ), [] + | AV_lit (L_aux (L_num n, _), typ) when Big_int.less_equal (min_int 64) n && Big_int.less_equal n (max_int 64) -> let gs = ngensym () in [iinit CT_lint gs (F_lit (V_int n), CT_fint 64)], @@ -258,7 +265,7 @@ let rec compile_aval l ctx = function raise (Reporting.err_general l "Encountered empty vector literal") (* Convert a small bitvector to a uint64_t literal. *) - | AV_vector (avals, typ) when is_bitvector avals && List.length avals <= 64 -> + | AV_vector (avals, typ) when is_bitvector avals && (List.length avals <= 64 || ctx.ignore_64) -> begin let bitstring = F_lit (V_bits (List.map value_of_aval_bit avals)) in let len = List.length avals in @@ -383,7 +390,7 @@ let compile_funcall l ctx id args typ = let have_ctyp = cval_ctyp cval in if is_polymorphic ctyp then (F_poly (fst cval), have_ctyp) - else if ctyp_equal ctyp have_ctyp then + else if ctx.specialize_calls || ctyp_equal ctyp have_ctyp then cval else let gs = ngensym () in @@ -398,7 +405,7 @@ let compile_funcall l ctx id args typ = List.rev !setup, begin fun clexp -> - if ctyp_equal (clexp_ctyp clexp) ret_ctyp then + if ctx.specialize_calls || ctyp_equal (clexp_ctyp clexp) ret_ctyp then ifuncall clexp id setup_args else let gs = ngensym () in @@ -450,7 +457,7 @@ let rec compile_match ctx (AP_aux (apat_aux, env, l)) cval case_label = | AP_tup apats, (frag, ctyp) -> begin - let get_tup n ctyp = (F_field (frag, "ztup" ^ string_of_int n), ctyp) in + let get_tup n ctyp = (F_tuple_member (frag, List.length apats, n), ctyp) in let fold (instrs, cleanup, n, ctx) apat ctyp = let instrs', cleanup', ctx = compile_match ctx apat (get_tup n ctyp) case_label in instrs @ instrs', cleanup' @ cleanup, n + 1, ctx @@ -1039,12 +1046,16 @@ let fix_early_return ret instrs = let end_function_label = label "end_function_" in let is_return_recur (I_aux (instr, _)) = match instr with - | I_return _ | I_undefined _ | I_if _ | I_block _ -> true + | I_return _ | I_undefined _ | I_if _ | I_block _ | I_try_block _ -> true | _ -> false in let rec rewrite_return historic instrs = match instr_split_at is_return_recur instrs with | instrs, [] -> instrs + | before, I_aux (I_try_block instrs, _) :: after -> + before + @ [itry_block (rewrite_return (historic @ before) instrs)] + @ rewrite_return (historic @ before) after | before, I_aux (I_block instrs, _) :: after -> before @ [iblock (rewrite_return (historic @ before) instrs)] diff --git a/src/jib/jib_compile.mli b/src/jib/jib_compile.mli index a0cacc3c..12cd4a63 100644 --- a/src/jib/jib_compile.mli +++ b/src/jib/jib_compile.mli @@ -64,7 +64,7 @@ val opt_debug_flow_graphs : bool ref val opt_debug_function : string ref val ngensym : unit -> name - + (** {2 Jib context} *) (** Context for compiling Sail to Jib. We need to pass a (global) @@ -82,7 +82,9 @@ type ctx = letbinds : int list; no_raw : bool; convert_typ : ctx -> typ -> ctyp; - optimize_anf : ctx -> typ aexp -> typ aexp + optimize_anf : ctx -> typ aexp -> typ aexp; + specialize_calls : bool; + ignore_64 : bool } val initial_ctx : diff --git a/src/jib/jib_optimize.ml b/src/jib/jib_optimize.ml index 73b175a1..f9829dfd 100644 --- a/src/jib/jib_optimize.ml +++ b/src/jib/jib_optimize.ml @@ -160,6 +160,7 @@ let rec frag_subst id subst = function | F_unary (op, frag) -> F_unary (op, frag_subst id subst frag) | F_call (op, frags) -> F_call (op, List.map (frag_subst id subst) frags) | F_field (frag, field) -> F_field (frag_subst id subst frag, field) + | F_tuple_member (frag, len, n) -> F_tuple_member (frag_subst id subst frag, len, n) | F_raw str -> F_raw str | F_ctor_kind (frag, ctor, unifiers, ctyp) -> F_ctor_kind (frag_subst id subst frag, ctor, unifiers, ctyp) | F_ctor_unwrap (ctor, unifiers, frag) -> F_ctor_unwrap (ctor, unifiers, frag_subst id subst frag) @@ -212,8 +213,10 @@ let rec instrs_subst id subst = let rec clexp_subst id subst = function | CL_id (id', ctyp) when Name.compare id id' = 0 -> - assert (ctyp_equal ctyp (clexp_ctyp subst)); - subst + if ctyp_equal ctyp (clexp_ctyp subst) then + subst + else + subst | CL_id (id', ctyp) -> CL_id (id', ctyp) | CL_field (clexp, field) -> CL_field (clexp_subst id subst clexp, field) | CL_addr clexp -> CL_addr (clexp_subst id subst clexp) @@ -245,6 +248,15 @@ let inline cdefs should_inline instrs = | instr -> instr in + let fix_labels = + let fix_label l = "inline" ^ string_of_int !inlines ^ "_" ^ l in + function + | I_aux (I_goto label, aux) -> I_aux (I_goto (fix_label label), aux) + | I_aux (I_label label, aux) -> I_aux (I_label (fix_label label), aux) + | I_aux (I_jump (cval, label), aux) -> I_aux (I_jump (cval, fix_label label), aux) + | instr -> instr + in + let rec inline_instr = function | I_aux (I_funcall (clexp, false, function_id, args), aux) as instr when should_inline function_id -> begin match find_function function_id cdefs with @@ -252,6 +264,7 @@ let inline cdefs should_inline instrs = incr inlines; let inline_label = label "end_inline_" in let body = List.fold_right2 instrs_subst (List.map name ids) (List.map fst args) body in + let body = List.map (map_instr fix_labels) body in let body = List.map (map_instr (replace_end inline_label)) body in let body = List.map (map_instr (replace_return clexp)) body in I_aux (I_block (body @ [ilabel inline_label]), aux) diff --git a/src/jib/jib_smt.ml b/src/jib/jib_smt.ml new file mode 100644 index 00000000..01240754 --- /dev/null +++ b/src/jib/jib_smt.ml @@ -0,0 +1,711 @@ +open Anf +open Ast +open Ast_util +open Jib +open Jib_util +open Smtlib + +let zencode_upper_id id = Util.zencode_upper_string (string_of_id id) +let zencode_id id = Util.zencode_string (string_of_id id) +let zencode_name id = string_of_name ~deref_current_exception:false ~zencode:true id + +let lbits_index = ref 8 + +let lbits_size () = Util.power 2 !lbits_index + +let lint_size = ref 64 + +let smt_unit = mk_enum "Unit" ["Unit"] +let smt_lbits = mk_record "Bits" [("size", Bitvec !lbits_index); ("bits", Bitvec (lbits_size ()))] + +let rec required_width n = + if Big_int.equal n Big_int.zero then + 1 + else + 1 + required_width (Big_int.shift_right n 1) + +let rec smt_ctyp = function + | CT_constant n -> Bitvec (required_width n) + | CT_fint n -> Bitvec n + | CT_lint -> Bitvec !lint_size + | CT_unit -> smt_unit + | CT_bit -> Bitvec 1 + | CT_fbits (n, _) -> Bitvec n + | CT_sbits (n, _) -> smt_lbits + | CT_lbits _ -> smt_lbits + | CT_bool -> Bool + | CT_enum (id, elems) -> + mk_enum (zencode_upper_id id) (List.map zencode_id elems) + | CT_struct (id, fields) -> + mk_record (zencode_upper_id id) (List.map (fun (id, ctyp) -> (zencode_id id, smt_ctyp ctyp)) fields) + | CT_variant (id, ctors) -> + mk_variant (zencode_upper_id id) (List.map (fun (id, ctyp) -> (zencode_id id, smt_ctyp ctyp)) ctors) + | CT_tup ctyps -> Tuple (List.map smt_ctyp ctyps) + | CT_vector (_, ctyp) -> Array (Bitvec 8, smt_ctyp ctyp) + | CT_string -> Bitvec 64 + | ctyp -> failwith ("Unhandled ctyp: " ^ string_of_ctyp ctyp) + +let bvint sz x = + if sz mod 4 = 0 then + let hex = Printf.sprintf "%X" x in + let padding = String.make (sz / 4 - String.length hex) '0' in + Hex (padding ^ hex) + else + failwith "Bad len" + +let smt_value vl ctyp = + let open Value2 in + match vl, ctyp with + | V_bits bs, _ -> + begin match Sail2_values.hexstring_of_bits bs with + | Some s -> Hex (Xstring.implode s) + | None -> Bin (Xstring.implode (List.map Sail2_values.bitU_char bs)) + end + | V_bool b, _ -> Bool_lit b + | V_int n, CT_constant m -> bvint (required_width n) (Big_int.to_int n) + | V_int n, CT_fint sz -> bvint sz (Big_int.to_int n) + | V_bit Sail2_values.B0, CT_bit -> Bin "0" + | V_bit Sail2_values.B1, CT_bit -> Bin "1" + | V_unit, _ -> Var "unit" + + | vl, _ -> failwith ("Bad literal " ^ string_of_value vl) + +let zencode_ctor ctor_id unifiers = + match unifiers with + | [] -> + zencode_id ctor_id + | _ -> + Util.zencode_string (string_of_id ctor_id ^ "_" ^ Util.string_of_list "_" string_of_ctyp unifiers) + +let rec smt_cval env cval = + match cval with + | F_lit vl, ctyp -> smt_value vl ctyp + | frag, _ -> smt_fragment env frag + +and smt_fragment env frag = + match frag with + | F_id (Name (id, _) as ssa_id) -> + begin match Type_check.Env.lookup_id id env with + | Enum _ -> Var (zencode_id id) + | _ -> Var (zencode_name ssa_id) + end + | F_id ssa_id -> Var (zencode_name ssa_id) + | F_op (frag1, "!=", frag2) -> + Fn ("not", [Fn ("=", [smt_fragment env frag1; smt_fragment env frag2])]) + | F_unary ("!", frag) -> + Fn ("not", [smt_cval env (frag, CT_bool)]) + | F_ctor_kind (union, ctor_id, unifiers, _) -> + Fn ("not", [Tester (zencode_ctor ctor_id unifiers, smt_fragment env union)]) + | F_ctor_unwrap (ctor_id, unifiers, union) -> + Fn ("un" ^ zencode_ctor ctor_id unifiers, [smt_fragment env union]) + | F_field (union, field) -> + Fn ("un" ^ field, [smt_fragment env union]) + | F_tuple_member (frag, len, n) -> + Fn (Printf.sprintf "tup_%d_%d" len n, [smt_fragment env frag]) + | frag -> failwith ("Unrecognised fragment " ^ string_of_fragment ~zencode:false frag) + +let builtin_zeros env cval = function + | CT_fbits (n, _) -> bvzero n + | CT_lbits _ -> Fn ("Bits", [extract (!lbits_index - 1) 0 (smt_cval env cval); bvzero (lbits_size ())]) + | _ -> failwith "Cannot compile zeros" + +let builtin_zero_extend env vbits vlen ret_ctyp = + match cval_ctyp vbits, ret_ctyp with + | CT_fbits (n, _), CT_fbits (m, _) when n = m -> + smt_cval env vbits + | CT_fbits (n, _), CT_fbits (m, _) -> + let bv = smt_cval env vbits in + Fn ("concat", [bvzero (m - n); bv]) + | _ -> failwith "Cannot compile zero_extend" + +let builtin_sign_extend env vbits vlen ret_ctyp = + match cval_ctyp vbits, ret_ctyp with + | CT_fbits (n, _), CT_fbits (m, _) when n = m -> + smt_cval env vbits + | CT_fbits (n, _), CT_fbits (m, _) -> + let bv = smt_cval env vbits in + let top_bit_one = Fn ("=", [Extract (n - 1, n - 1, bv); Bin "1"]) in + Ite (top_bit_one, Fn ("concat", [bvones (m - n); bv]), Fn ("concat", [bvzero (m - n); bv])) + | _ -> failwith "Cannot compile zero_extend" + +let int_size = function + | CT_constant n -> required_width n + | CT_fint sz -> sz + | CT_lint -> lbits_size () + | _ -> failwith "Argument to int_size must be an integer" + +(* [bvzeint esz cval] (BitVector Zero Extend INTeger), takes a cval + which must be an integer type (either CT_fint, or CT_lint), and + produces a bitvector which is either zero extended or truncated to + exactly esz bits. *) +let bvzeint env esz cval = + let sz = int_size (cval_ctyp cval) in + match fst cval with + | F_lit (V_int n) -> + bvint esz (Big_int.to_int n) + | _ -> + let smt = smt_cval env cval in + if esz = sz then + smt + else if esz > sz then + Fn ("concat", [bvzero (esz - sz); smt]) + else + Extract (esz - 1, 0, smt) + +let builtin_shift shiftop env vbits vshift ret_ctyp = + match cval_ctyp vbits with + | CT_fbits (n, _) -> + let bv = smt_cval env vbits in + let len = bvzeint env n vshift in + Fn (shiftop, [bv; len]) + + | CT_lbits _ -> + let bv = smt_cval env vbits in + let shift = bvzeint env (lbits_size ()) vshift in + Fn ("Bits", [Fn ("len", [bv]); Fn (shiftop, [Fn ("contents", [bv]); shift])]) + + | _ -> failwith ("Cannot compile shift: " ^ shiftop) + +let builtin_or_bits env v1 v2 ret_ctyp = + match cval_ctyp v1, cval_ctyp v2 with + | CT_fbits (n, _), CT_fbits (m, _) -> + assert (n = m); + let smt1 = smt_cval env v1 in + let smt2 = smt_cval env v2 in + bvor smt1 smt2 + + | CT_lbits _, CT_lbits _ -> + let smt1 = smt_cval env v1 in + let smt2 = smt_cval env v2 in + Fn ("Bits", [Fn ("len", [smt1]); bvor (Fn ("contents", [smt1])) (Fn ("contents", [smt2]))]) + + | _ -> failwith "Cannot compile or_bits" + +let builtin_append env v1 v2 ret_ctyp = + match cval_ctyp v1, cval_ctyp v2, ret_ctyp with + | CT_fbits (n, _), CT_fbits (m, _), CT_fbits (o, _) -> + assert (n + m = o); + let smt1 = smt_cval env v1 in + let smt2 = smt_cval env v2 in + Fn ("concat", [smt1; smt2]) + + | CT_fbits (n, _), CT_lbits _, CT_lbits _ -> + let smt1 = smt_cval env v1 in + let smt2 = smt_cval env v2 in + let x = Fn ("concat", [bvzero (lbits_size () - n); smt1]) in + let shift = Fn ("concat", [bvzero (lbits_size () - !lbits_index); Fn ("len", [smt2])]) in + Fn ("Bits", [bvadd (bvint !lbits_index n) (Fn ("len", [smt2])); bvor (bvshl x shift) (Fn ("contents", [smt2]))]) + + | CT_lbits _, CT_lbits _, CT_lbits _ -> + let smt1 = smt_cval env v1 in + let smt2 = smt_cval env v2 in + let x = Fn ("contents", [smt1]) in + let shift = Fn ("concat", [bvzero (lbits_size () - !lbits_index); Fn ("len", [smt2])]) in + Fn ("Bits", [bvadd (Fn ("len", [smt1])) (Fn ("len", [smt2])); bvor (bvshl x shift) (Fn ("contents", [smt2]))]) + + | _ -> failwith "Cannot compile append" + +let builtin_length env v ret_ctyp = + match cval_ctyp v, ret_ctyp with + | CT_lbits _, CT_fint m -> + let sz = !lbits_index in + let len = Fn ("len", [smt_cval env v]) in + if m = sz then + len + else if m > sz then + Fn ("concat", [bvzero (m - sz); len]) + else + Extract (m - 1, 0, len) + + | _, _ -> failwith "Cannot compile length" + +let builtin_vector_subrange env vec i j ret_ctyp = + match cval_ctyp vec, cval_ctyp i, cval_ctyp j with + | CT_fbits (n, _), CT_constant i, CT_constant j -> + Extract (Big_int.to_int i, Big_int.to_int j, smt_cval env vec) + + | _ -> failwith "Cannot compile vector subrange" + +let builtin_vector_access env vec i ret_ctyp = + match cval_ctyp vec, cval_ctyp i, ret_ctyp with + | CT_fbits (n, _), CT_constant i, CT_bit -> + Extract (Big_int.to_int i, Big_int.to_int i, smt_cval env vec) + + | _ -> failwith "Cannot compile vector subrange" + +let builtin_unsigned env v ret_ctyp = + match cval_ctyp v, ret_ctyp with + | CT_fbits (n, _), CT_fint m -> + let smt = smt_cval env v in + Fn ("concat", [bvzero (m - n); smt]) + + | _ -> failwith "Cannot compile unsigned" + +let builtin_eq_int env v1 v2 = + match cval_ctyp v1, cval_ctyp v2 with + | CT_fint m, CT_constant c -> + Fn ("=", [smt_cval env v1; bvint m (Big_int.to_int c)]) + + | _ -> failwith "Cannot compile eq_int" + +let builtin_add_bits env v1 v2 ret_ctyp = + match cval_ctyp v1, cval_ctyp v2, ret_ctyp with + | CT_fbits (n, _), CT_fbits (m, _), CT_fbits (o, _) -> + assert (n = m && m = o); + Fn ("bvadd", [smt_cval env v1; smt_cval env v2]) + + | _ -> failwith "Cannot compile add_bits" + +let smt_primop env name args ret_ctyp = + match name, args, ret_ctyp with + | "eq_bits", [v1; v2], _ -> + let smt1 = smt_cval env v1 in + let smt2 = smt_cval env v2 in + Fn ("=", [smt1; smt2]) + | "eq_bit", [v1; v2], _ -> + let smt1 = smt_cval env v1 in + let smt2 = smt_cval env v2 in + Fn ("=", [smt1; smt2]) + + | "not", [v], _ -> Fn ("not", [smt_cval env v]) + + | "zeros", [v1], _ -> builtin_zeros env v1 ret_ctyp + | "zero_extend", [v1; v2], _ -> builtin_zero_extend env v1 v2 ret_ctyp + | "sign_extend", [v1; v2], _ -> builtin_sign_extend env v1 v2 ret_ctyp + | "shiftl", [v1; v2], _ -> builtin_shift "bvshl" env v1 v2 ret_ctyp + | "shiftr", [v1; v2], _ -> builtin_shift "bvlshr" env v1 v2 ret_ctyp + | "or_bits", [v1; v2], _ -> builtin_or_bits env v1 v2 ret_ctyp + | "add_bits", [v1; v2], _ -> builtin_add_bits env v1 v2 ret_ctyp + | "append", [v1; v2], _ -> builtin_append env v1 v2 ret_ctyp + | "length", [v], ret_ctyp -> builtin_length env v ret_ctyp + | "vector_access", [v1; v2], ret_ctyp -> builtin_vector_access env v1 v2 ret_ctyp + | "vector_subrange", [v1; v2; v3], ret_ctyp -> builtin_vector_subrange env v1 v2 v3 ret_ctyp + | "sail_unsigned", [v], ret_ctyp -> builtin_unsigned env v ret_ctyp + | "eq_int", [v1; v2], _ -> builtin_eq_int env v1 v2 + + | _ -> failwith ("Bad primop " ^ name ^ " " ^ Util.string_of_list ", " string_of_ctyp (List.map snd args) ^ " -> " ^ string_of_ctyp ret_ctyp) + +let rec smt_conversion from_ctyp to_ctyp x = + match from_ctyp, to_ctyp with + | _, _ when ctyp_equal from_ctyp to_ctyp -> x + | _, _ -> failwith "Bad conversion" + +let define_const id ctyp exp = Define_const (zencode_name id, smt_ctyp ctyp, exp) +let declare_const id ctyp = Declare_const (zencode_name id, smt_ctyp ctyp) + +let smt_ctype_def = function + | CTD_enum (id, elems) -> + [declare_datatypes (mk_enum (zencode_upper_id id) (List.map zencode_id elems))] + + | CTD_struct (id, fields) -> + [declare_datatypes + (mk_record (zencode_upper_id id) + (List.map (fun (field, ctyp) -> zencode_id field, smt_ctyp ctyp) fields))] + + | CTD_variant (id, ctors) -> + [declare_datatypes + (mk_variant (zencode_upper_id id) + (List.map (fun (ctor, ctyp) -> zencode_id ctor, smt_ctyp ctyp) ctors))] + +let rec generate_ctype_defs = function + | CDEF_type ctd :: cdefs -> smt_ctype_def ctd :: generate_ctype_defs cdefs + | _ :: cdefs -> generate_ctype_defs cdefs + | [] -> [] + +let rec generate_reg_decs inits = function + | CDEF_reg_dec (id, ctyp, _) :: cdefs when not (NameMap.mem (Name (id, 0)) inits)-> + Declare_const (zencode_name (Name (id, 0)), smt_ctyp ctyp) + :: generate_reg_decs inits cdefs + | _ :: cdefs -> generate_reg_decs inits cdefs + | [] -> [] + +(**************************************************************************) +(* 2. Converting sail types to Jib types for SMT *) +(**************************************************************************) + +let max_int n = Big_int.pred (Big_int.pow_int_positive 2 (n - 1)) +let min_int n = Big_int.negate (Big_int.pow_int_positive 2 (n - 1)) + +(** Convert a sail type into a C-type. This function can be quite + slow, because it uses ctx.local_env and SMT to analyse the Sail + types and attempts to fit them into the smallest possible C + types, provided ctx.optimize_smt is true (default) **) +let rec ctyp_of_typ ctx typ = + let open Ast in + let open Type_check in + let open Jib_compile in + let Typ_aux (typ_aux, l) as typ = Env.expand_synonyms ctx.tc_env typ in + match typ_aux with + | Typ_id id when string_of_id id = "bit" -> CT_bit + | Typ_id id when string_of_id id = "bool" -> CT_bool + | Typ_id id when string_of_id id = "int" -> CT_lint + | Typ_id id when string_of_id id = "nat" -> CT_lint + | Typ_id id when string_of_id id = "unit" -> CT_unit + | Typ_id id when string_of_id id = "string" -> CT_string + | Typ_id id when string_of_id id = "real" -> CT_real + + | Typ_app (id, _) when string_of_id id = "atom_bool" -> CT_bool + + | Typ_app (id, args) when string_of_id id = "itself" -> + ctyp_of_typ ctx (Typ_aux (Typ_app (mk_id "atom", args), l)) + | Typ_app (id, _) when string_of_id id = "range" || string_of_id id = "atom" || string_of_id id = "implicit" -> + begin match destruct_range Env.empty typ with + | None -> assert false (* Checked if range type in guard *) + | Some (kids, constr, n, m) -> + let ctx = { ctx with local_env = add_existential Parse_ast.Unknown (List.map (mk_kopt K_int) kids) constr ctx.local_env } in + match nexp_simp n, nexp_simp m with + | Nexp_aux (Nexp_constant n, _), Nexp_aux (Nexp_constant m, _) + when n = m -> + CT_constant n + | Nexp_aux (Nexp_constant n, _), Nexp_aux (Nexp_constant m, _) + when Big_int.less_equal (min_int 64) n && Big_int.less_equal m (max_int 64) -> + CT_fint 64 + | n, m -> + if prove __POS__ ctx.local_env (nc_lteq (nconstant (min_int 64)) n) && prove __POS__ ctx.local_env (nc_lteq m (nconstant (max_int 64))) then + CT_fint 64 + else + CT_lint + end + + | Typ_app (id, [A_aux (A_typ typ, _)]) when string_of_id id = "list" -> + CT_list (ctyp_of_typ ctx typ) + + (* When converting a sail bitvector type into C, we have three options in order of efficiency: + - If the length is obviously static and smaller than 64, use the fixed bits type (aka uint64_t), fbits. + - If the length is less than 64, then use a small bits type, sbits. + - If the length may be larger than 64, use a large bits type lbits. *) + | Typ_app (id, [A_aux (A_nexp n, _); + A_aux (A_order ord, _); + A_aux (A_typ (Typ_aux (Typ_id vtyp_id, _)), _)]) + when string_of_id id = "vector" && string_of_id vtyp_id = "bit" -> + let direction = match ord with Ord_aux (Ord_dec, _) -> true | Ord_aux (Ord_inc, _) -> false | _ -> assert false in + begin match nexp_simp n with + | Nexp_aux (Nexp_constant n, _) -> CT_fbits (Big_int.to_int n, direction) + | _ -> CT_lbits direction + end + + | Typ_app (id, [A_aux (A_nexp n, _); + A_aux (A_order ord, _); + A_aux (A_typ typ, _)]) + when string_of_id id = "vector" -> + let direction = match ord with Ord_aux (Ord_dec, _) -> true | Ord_aux (Ord_inc, _) -> false | _ -> assert false in + CT_vector (direction, ctyp_of_typ ctx typ) + + | Typ_app (id, [A_aux (A_typ typ, _)]) when string_of_id id = "register" -> + CT_ref (ctyp_of_typ ctx typ) + + | Typ_id id | Typ_app (id, _) when Bindings.mem id ctx.records -> CT_struct (id, Bindings.find id ctx.records |> Bindings.bindings) + | Typ_id id | Typ_app (id, _) when Bindings.mem id ctx.variants -> CT_variant (id, Bindings.find id ctx.variants |> Bindings.bindings) + | Typ_id id when Bindings.mem id ctx.enums -> CT_enum (id, Bindings.find id ctx.enums |> IdSet.elements) + + | Typ_tup typs -> CT_tup (List.map (ctyp_of_typ ctx) typs) + + | Typ_exist _ -> + (* Use Type_check.destruct_exist when optimising with SMT, to + ensure that we don't cause any type variable clashes in + local_env, and that we can optimize the existential based upon + it's constraints. *) + begin match destruct_exist (Env.expand_synonyms ctx.local_env typ) with + | Some (kids, nc, typ) -> + let env = add_existential l kids nc ctx.local_env in + ctyp_of_typ { ctx with local_env = env } typ + | None -> raise (Reporting.err_unreachable l __POS__ "Existential cannot be destructured!") + end + + | Typ_var kid -> CT_poly + + | _ -> raise (Reporting.err_unreachable l __POS__ ("No C type for type " ^ string_of_typ typ)) + +(**************************************************************************) +(* 3. Optimization of primitives and literals *) +(**************************************************************************) + +let hex_char = + let open Sail2_values in + function + | '0' -> [B0; B0; B0; B0] + | '1' -> [B0; B0; B0; B1] + | '2' -> [B0; B0; B1; B0] + | '3' -> [B0; B0; B1; B1] + | '4' -> [B0; B1; B0; B0] + | '5' -> [B0; B1; B0; B1] + | '6' -> [B0; B1; B1; B0] + | '7' -> [B0; B1; B1; B1] + | '8' -> [B1; B0; B0; B0] + | '9' -> [B1; B0; B0; B1] + | 'A' | 'a' -> [B1; B0; B1; B0] + | 'B' | 'b' -> [B1; B0; B1; B1] + | 'C' | 'c' -> [B1; B1; B0; B0] + | 'D' | 'd' -> [B1; B1; B0; B1] + | 'E' | 'e' -> [B1; B1; B1; B0] + | 'F' | 'f' -> [B1; B1; B1; B1] + | _ -> failwith "Invalid hex character" + +let literal_to_fragment (L_aux (l_aux, _) as lit) = + match l_aux with + | L_num n -> Some (F_lit (V_int n), CT_constant n) + | L_hex str when String.length str <= 16 -> + let content = Util.string_to_list str |> List.map hex_char |> List.concat in + Some (F_lit (V_bits content), CT_fbits (String.length str * 4, true)) + | L_unit -> Some (F_lit V_unit, CT_unit) + | L_true -> Some (F_lit (V_bool true), CT_bool) + | L_false -> Some (F_lit (V_bool false), CT_bool) + | _ -> None + +let c_literals ctx = + let rec c_literal env l = function + | AV_lit (lit, typ) as v -> + begin match literal_to_fragment lit with + | Some (frag, ctyp) -> AV_C_fragment (frag, typ, ctyp) + | None -> v + end + | AV_tuple avals -> AV_tuple (List.map (c_literal env l) avals) + | v -> v + in + map_aval c_literal + +(**************************************************************************) +(* 3. Generating SMT *) +(**************************************************************************) + +(* When generating SMT when we encounter joins between two or more + blocks such as in the example below, we have to generate a muxer + that chooses the correct value of v_n or v_m to assign to v_o. We + use the pi nodes that contain the global path condition for each + block to generate an if-then-else for each phi function. The order + of the arguments to each phi function is based on the graph node + index for the predecessor nodes. + + +---------------+ +---------------+ + | pi(cond_1) | | pi(cond_2) | + | ... | | ... | + | Basic block 1 | | Basic block 2 | + +---------------+ +---------------+ + \ / + \ / + +---------------------+ + | v/o = phi(v/n, v/m) | + | ... | + +---------------------+ + + would generate: + + (define-const v/o (ite cond_1 v/n v/m_)) +*) +let smt_ssanode env cfg preds = + let open Jib_ssa in + function + | Pi _ -> [] + | Phi (id, ctyp, ids) -> + let get_pi n = + match get_vertex cfg n with + | Some ((ssanodes, _), _, _) -> + List.concat (List.map (function Pi guards -> guards | _ -> []) ssanodes) + | None -> failwith "Predecessor node does not exist" + in + let pis = List.map get_pi (IntSet.elements preds) in + let mux = + List.fold_right2 (fun pi id chain -> + let pathcond = + match pi with + | [cval] -> smt_cval env cval + | _ -> Fn ("and", List.map (smt_cval env) pi) + in + match chain with + | Some smt -> + Some (Ite (pathcond, Var (zencode_name id), smt)) + | None -> + Some (Var (zencode_name id))) + pis ids None + in + match mux with + | None -> [] + | Some mux -> + [Define_const (zencode_name id, smt_ctyp ctyp, mux)] + +(* For any complex l-expression we need to turn it into a + read-modify-write in the SMT solver. The SSA transform turns CL_id + nodes into CL_rmw (read, write, ctyp) nodes when CL_id is wrapped + in any other l-expression. The read and write must have the same + name but different SSA numbers. +*) +let rec rmw_write = function + | CL_rmw (_, write, ctyp) -> write, ctyp + | CL_id _ -> assert false + | CL_tuple (clexp, _) -> rmw_write clexp + | clexp -> + failwith (Pretty_print_sail.to_string (pp_clexp clexp)) + +let rmw_read = function + | CL_rmw (read, _, _) -> zencode_name read + | _ -> assert false + +let rmw_modify smt = function + | CL_tuple (clexp, n) -> + let ctyp = clexp_ctyp clexp in + begin match ctyp with + | CT_tup ctyps -> + let len = List.length ctyps in + let set_tup i = + if i == n then + smt + else + Fn (Printf.sprintf "tup_%d_%d" len i, [Var (rmw_read clexp)]) + in + Fn ("tup4", List.init len set_tup) + | _ -> + failwith "Tuple modify does not have tuple type" + end + | _ -> assert false + +(* For a basic block (contained in a control-flow node / cfnode), we + turn the instructions into a sequence of define-const and + declare-const expressions. Because we are working with a SSA graph, + each variable is guaranteed to only be declared once. +*) +let smt_instr env = + let open Type_check in + function + | I_aux (I_funcall (CL_id (id, ret_ctyp), _, function_id, args), _) -> + if Env.is_extern function_id env "c" then + let name = Env.get_extern function_id env "c" in + let value = smt_primop env name args ret_ctyp in + [define_const id ret_ctyp value] + else + let smt_args = List.map (smt_cval env) args in + [define_const id ret_ctyp (Fn (zencode_id function_id, smt_args))] + + | I_aux (I_init (ctyp, id, cval), _) | I_aux (I_copy (CL_id (id, ctyp), cval), _) -> + [define_const id ctyp + (smt_conversion (cval_ctyp cval) ctyp (smt_cval env cval))] + + | I_aux (I_copy (clexp, cval), _) -> + let smt = smt_cval env cval in + let write, ctyp = rmw_write clexp in + [define_const write ctyp (rmw_modify smt clexp)] + + | I_aux (I_decl (ctyp, id), _) -> + [declare_const id ctyp] + + | I_aux (I_end id, _) -> + [Assert (Var (zencode_name id))] + + | I_aux (I_clear _, _) -> [] + + | I_aux (I_match_failure, _) -> [] + + | instr -> + failwith ("Cannot translate: " ^ Pretty_print_sail.to_string (pp_instr instr)) + +let smt_cfnode all_cdefs env = + let open Jib_ssa in + function + | CF_start inits -> + let smt_reg_decs = generate_reg_decs inits all_cdefs in + let smt_start (id, ctyp) = + match id with + | Have_exception _ -> define_const id ctyp (Bool_lit false) + | _ -> declare_const id ctyp + in + smt_reg_decs @ List.map smt_start (NameMap.bindings inits) + | CF_block instrs -> + List.concat (List.map (smt_instr env) instrs) + (* We can ignore any non basic-block/start control-flow nodes *) + | _ -> [] + +let rec find_function id = function + | CDEF_fundef (id', heap_return, args, body) :: _ when Id.compare id id' = 0 -> + Some (heap_return, args, body) + | _ :: cdefs -> + find_function id cdefs + | [] -> None + +let smt_cdef out_chan env all_cdefs = function + | CDEF_spec (function_id, arg_ctyps, ret_ctyp) + when string_of_id function_id = "check_sat" -> + begin match find_function function_id all_cdefs with + | Some (None, args, instrs) -> + let open Jib_ssa in + let smt_args = + List.map2 (fun id ctyp -> declare_const (Name (id, 0)) ctyp) args arg_ctyps + in + output_smt_defs out_chan smt_args; + + let instrs = + let open Jib_optimize in + instrs + (* |> optimize_unit *) + |> inline all_cdefs (fun _ -> true) + |> flatten_instrs + in + + let str = Pretty_print_sail.to_string PPrint.(separate_map hardline Jib_util.pp_instr instrs) in + prerr_endline str; + + let start, cfg = ssa instrs in + let chan = open_out "smt_ssa.gv" in + make_dot chan cfg; + close_out chan; + + let visit_order = topsort cfg in + + List.iter (fun n -> + begin match get_vertex cfg n with + | None -> () + | Some ((ssanodes, cfnode), preds, succs) -> + let muxers = + ssanodes |> List.map (smt_ssanode env cfg preds) |> List.concat + in + let basic_block = smt_cfnode all_cdefs env cfnode in + output_smt_defs out_chan muxers; + output_smt_defs out_chan basic_block; + end + ) visit_order + + | _ -> failwith "Bad function body" + end + | _ -> () + +let rec smt_cdefs out_chan env ast = + function + | cdef :: cdefs -> + smt_cdef out_chan env ast cdef; + smt_cdefs out_chan env ast cdefs + | [] -> () + +let generate_smt out_chan env ast = + try + let open Jib_compile in + let ctx = + initial_ctx + ~convert_typ:ctyp_of_typ + ~optimize_anf:(fun ctx aexp -> c_literals ctx aexp) + env + in + let t = Profile.start () in + let cdefs, ctx = compile_ast { ctx with specialize_calls = true; ignore_64 = true } ast in + Profile.finish "Compiling to Jib IR" t; + + (* output_string out_chan "(set-option :produce-models true)\n"; *) + output_string out_chan "(set-logic QF_AUFBVDT)\n"; + output_smt_defs out_chan + [declare_datatypes (mk_enum "Unit" ["unit"]); + Declare_tuple 2; + Declare_tuple 3; + Declare_tuple 4; + Declare_tuple 5; + declare_datatypes (mk_record "Bits" [("len", Bitvec !lbits_index); + ("contents", Bitvec (lbits_size ()))]) + + ]; + + let smt_type_defs = List.concat (generate_ctype_defs cdefs) in + output_string out_chan "\n; Sail type definitions\n"; + output_smt_defs out_chan smt_type_defs; + + output_string out_chan "\n; Sail function\n"; + smt_cdefs out_chan env cdefs cdefs; + output_string out_chan "(check-sat)\n" + with + | Type_check.Type_error (_, l, err) -> + raise (Reporting.err_typ l (Type_error.string_of_type_error err)); diff --git a/src/jib/jib_ssa.ml b/src/jib/jib_ssa.ml index a086f0b9..e79c88a1 100644 --- a/src/jib/jib_ssa.ml +++ b/src/jib/jib_ssa.ml @@ -76,7 +76,7 @@ let iter_graph f graph = | Some (x, y, z) -> f x y z | None -> () done - + (** Add a vertex to a graph, returning the node index *) let add_vertex data graph = let n = graph.next in @@ -124,6 +124,52 @@ let reachable roots graph = in IntSet.iter reachable' roots; !visited +let topsort graph = + let marked = ref IntSet.empty in + let temp_marked = ref IntSet.empty in + let list = ref [] in + + let rec visit node = + if IntSet.mem node !temp_marked then + failwith "Not a DAG" + else if IntSet.mem node !marked then + () + else + begin match get_vertex graph node with + | None -> failwith "Node does not exist in topsort" + | Some (_, _, succs) -> + temp_marked := IntSet.add node !temp_marked; + IntSet.iter visit succs; + marked := IntSet.add node !marked; + temp_marked := IntSet.remove node !temp_marked; + list := node :: !list + end + in + + let find_unmarked () = + let unmarked = ref (-1) in + let i = ref 0 in + while !unmarked = -1 && !i < Array.length graph.nodes do + begin match get_vertex graph !i with + | None -> () + | Some _ -> + if not (IntSet.mem !i !marked) then + unmarked := !i + end; + incr i + done; + !unmarked + in + + let rec topsort' () = + let unmarked = find_unmarked () in + if unmarked = -1 then + () + else + (visit unmarked; topsort' ()) + in + topsort' (); !list + let prune visited graph = for i = 0 to graph.next - 1 do match graph.nodes.(i) with @@ -143,7 +189,7 @@ type cf_node = | CF_label of string | CF_block of instr list | CF_guard of cval - | CF_start + | CF_start of ctyp NameMap.t let cval_not (f, ctyp) = (F_unary ("!", f), ctyp) @@ -212,7 +258,7 @@ let control_flow_graph instrs = | [] -> preds in - let start = add_vertex ([], CF_start) graph in + let start = add_vertex ([], CF_start NameMap.empty) graph in let finish = cfg [start] instrs in let visited = reachable (IntSet.singleton start) graph in @@ -426,23 +472,28 @@ let rename_variables graph root children = let counts = ref NameMap.empty in let stacks = ref NameMap.empty in + let phi_zeros = ref NameMap.empty in + + let ssa_name i = function + | Name (id, _) -> Name (id, i) + | Have_exception _ -> Have_exception i + | Current_exception _ -> Current_exception i + | Return _ -> Return i + in + let get_count id = match NameMap.find_opt id !counts with Some n -> n | None -> 0 in let top_stack id = - match NameMap.find_opt id !stacks with Some (x :: _) -> x | (Some [] | None) -> 0 + match NameMap.find_opt id !stacks with Some (x :: _) -> x | Some [] -> 0 | None -> 0 + in + let top_stack_phi id ctyp = + match NameMap.find_opt id !stacks with Some (x :: _) -> x | Some [] -> 0 | None -> (phi_zeros := NameMap.add (ssa_name 0 id) ctyp !phi_zeros; 0) in let push_stack id n = stacks := NameMap.add id (n :: match NameMap.find_opt id !stacks with Some s -> s | None -> []) !stacks in - let ssa_name i = function - | Name (id, _) -> Name (id, i) - | Have_exception _ -> Have_exception i - | Current_exception _ -> Current_exception i - | Return _ -> Return i - in - let rec fold_frag = function | F_id id -> let i = top_stack id in @@ -455,21 +506,29 @@ let rename_variables graph root children = | F_unary (op, f) -> F_unary (op, fold_frag f) | F_call (id, fs) -> F_call (id, List.map fold_frag fs) | F_field (f, field) -> F_field (fold_frag f, field) + | F_tuple_member (f, len, n) -> F_tuple_member (fold_frag f, len, n) | F_raw str -> F_raw str | F_ctor_kind (f, ctor, unifiers, ctyp) -> F_ctor_kind (fold_frag f, ctor, unifiers, ctyp) | F_ctor_unwrap (ctor, unifiers, f) -> F_ctor_unwrap (ctor, unifiers, fold_frag f) | F_poly f -> F_poly (fold_frag f) in - let rec fold_clexp = function + let rec fold_clexp rmw = function + | CL_id (id, ctyp) when rmw -> + let i = top_stack id in + let j = get_count id + 1 in + counts := NameMap.add id j !counts; + push_stack id j; + CL_rmw (ssa_name i id, ssa_name j id, ctyp) | CL_id (id, ctyp) -> let i = get_count id + 1 in counts := NameMap.add id i !counts; push_stack id i; CL_id (ssa_name i id, ctyp) - | CL_field (clexp, field) -> CL_field (fold_clexp clexp, field) - | CL_addr clexp -> CL_addr (fold_clexp clexp) - | CL_tuple (clexp, n) -> CL_tuple (fold_clexp clexp, n) + | CL_rmw _ -> assert false + | CL_field (clexp, field) -> CL_field (fold_clexp true clexp, field) + | CL_addr clexp -> CL_addr (fold_clexp true clexp) + | CL_tuple (clexp, n) -> CL_tuple (fold_clexp true clexp, n) | CL_void -> CL_void in @@ -479,10 +538,10 @@ let rename_variables graph root children = let aux = match aux with | I_funcall (clexp, extern, id, args) -> let args = List.map fold_cval args in - I_funcall (fold_clexp clexp, extern, id, args) + I_funcall (fold_clexp false clexp, extern, id, args) | I_copy (clexp, cval) -> let cval = fold_cval cval in - I_copy (fold_clexp clexp, cval) + I_copy (fold_clexp false clexp, cval) | I_decl (ctyp, id) -> let i = get_count id + 1 in counts := NameMap.add id i !counts; @@ -505,7 +564,7 @@ let rename_variables graph root children = in let ssa_cfnode = function - | CF_start -> CF_start + | CF_start inits -> CF_start inits | CF_block instrs -> CF_block (List.map ssa_instr instrs) | CF_label label -> CF_label label | CF_guard cval -> CF_guard (fold_cval cval) @@ -524,7 +583,7 @@ let rename_variables graph root children = | Phi (id, ctyp, ids) -> let fix_arg k a = if k = j then - let i = top_stack a in + let i = top_stack_phi a ctyp in ssa_name i a else a in @@ -556,7 +615,11 @@ let rename_variables graph root children = IntSet.iter (fun child -> rename child) (children.(n)); stacks := old_stacks in - rename root + rename root; + match graph.nodes.(root) with + | Some ((ssa, CF_start _), preds, succs) -> + graph.nodes.(root) <- Some ((ssa, CF_start !phi_zeros), preds, succs) + | _ -> failwith "root node is not CF_start" let place_pi_functions graph start idom children = let get_guard = function @@ -631,11 +694,11 @@ let string_of_phis = function let string_of_node = function | (phis, CF_label label) -> string_of_phis phis ^ label | (phis, CF_block instrs) -> string_of_phis phis ^ Util.string_of_list "\\l" (fun instr -> String.escaped (Pretty_print_sail.to_string (pp_instr ~short:true instr))) instrs - | (phis, CF_start) -> string_of_phis phis ^ "START" + | (phis, CF_start inits) -> string_of_phis phis ^ "START" | (phis, CF_guard cval) -> string_of_phis phis ^ (String.escaped (Pretty_print_sail.to_string (pp_cval cval))) let vertex_color = function - | (_, CF_start) -> "peachpuff" + | (_, CF_start _) -> "peachpuff" | (_, CF_block _) -> "white" | (_, CF_label _) -> "springgreen" | (_, CF_guard _) -> "yellow" diff --git a/src/jib/jib_ssa.mli b/src/jib/jib_ssa.mli index b146861c..88bb46c0 100644 --- a/src/jib/jib_ssa.mli +++ b/src/jib/jib_ssa.mli @@ -49,6 +49,7 @@ (**************************************************************************) open Array +open Jib_util (** A mutable array based graph type, with nodes indexed by integers. *) type 'a array_graph @@ -58,11 +59,11 @@ type 'a array_graph val make : initial_size:int -> unit -> 'a array_graph module IntSet : Set.S with type elt = int - + val get_vertex : 'a array_graph -> int -> ('a * IntSet.t * IntSet.t) option val iter_graph : ('a -> IntSet.t -> IntSet.t -> unit) -> 'a array_graph -> unit - + (** Add a vertex to a graph, returning the index of the inserted vertex. If the number of vertices exceeds the size of the underlying array, then it is dynamically resized. *) @@ -72,11 +73,13 @@ val add_vertex : 'a -> 'a array_graph -> int if either of the vertices do not exist. *) val add_edge : int -> int -> 'a array_graph -> unit +val topsort : 'a array_graph -> int list + type cf_node = | CF_label of string | CF_block of Jib.instr list | CF_guard of Jib.cval - | CF_start + | CF_start of Jib.ctyp NameMap.t val control_flow_graph : Jib.instr list -> int * int list * ('a list * cf_node) array_graph diff --git a/src/jib/jib_util.ml b/src/jib/jib_util.ml index 904e0209..2eabdc57 100644 --- a/src/jib/jib_util.ml +++ b/src/jib/jib_util.ml @@ -164,6 +164,7 @@ let rec frag_rename from_id to_id = function | F_op (f1, op, f2) -> F_op (frag_rename from_id to_id f1, op, frag_rename from_id to_id f2) | F_unary (op, f) -> F_unary (op, frag_rename from_id to_id f) | F_field (f, field) -> F_field (frag_rename from_id to_id f, field) + | F_tuple_member (f, len, n) -> F_tuple_member (frag_rename from_id to_id f, len, n) | F_raw raw -> F_raw raw | F_ctor_kind (f, ctor, unifiers, ctyp) -> F_ctor_kind (frag_rename from_id to_id f, ctor, unifiers, ctyp) | F_ctor_unwrap (ctor, unifiers, f) -> F_ctor_unwrap (ctor, unifiers, frag_rename from_id to_id f) @@ -258,7 +259,7 @@ let string_of_value = function | V_bit Sail2_values.BU -> failwith "Undefined bit found in value" | V_string str -> "\"" ^ str ^ "\"" -let string_of_name ?zencode:(zencode=true) = +let string_of_name ?deref_current_exception:(dce=true) ?zencode:(zencode=true) = let ssa_num n = if n < 0 then "" else ("/" ^ string_of_int n) in function | Name (id, n) -> @@ -267,8 +268,10 @@ let string_of_name ?zencode:(zencode=true) = "have_exception" ^ ssa_num n | Return n -> "return" ^ ssa_num n - | Current_exception n -> + | Current_exception n when dce -> "(*current_exception)" ^ ssa_num n + | Current_exception n -> + "current_exception" ^ ssa_num n let rec string_of_fragment ?zencode:(zencode=true) = function | F_id id -> string_of_name ~zencode:zencode id @@ -278,6 +281,8 @@ let rec string_of_fragment ?zencode:(zencode=true) = function Printf.sprintf "%s(%s)" str (Util.string_of_list ", " (string_of_fragment ~zencode:zencode) frags) | F_field (f, field) -> Printf.sprintf "%s.%s" (string_of_fragment' ~zencode:zencode f) field + | F_tuple_member (f, _, n) -> + Printf.sprintf "%s.ztup%d" (string_of_fragment' ~zencode:zencode f) n | F_op (f1, op, f2) -> Printf.sprintf "%s %s %s" (string_of_fragment' ~zencode:zencode f1) op (string_of_fragment' ~zencode:zencode f2) | F_unary (op, f) -> @@ -314,6 +319,7 @@ and string_of_ctyp = function | CT_sbits (n, true) -> "sbits(" ^ string_of_int n ^ ", dec)" | CT_sbits (n, false) -> "sbits(" ^ string_of_int n ^ ", inc)" | CT_fint n -> "int(" ^ string_of_int n ^ ")" + | CT_constant n -> "constant(" ^ Big_int.to_string n ^ ")" | CT_bit -> "bit" | CT_unit -> "unit" | CT_bool -> "bool" @@ -348,7 +354,7 @@ and full_string_of_ctyp = function | ctyp -> string_of_ctyp ctyp let rec map_ctyp f = function - | (CT_lint | CT_fint _ | CT_lbits _ | CT_fbits _ | CT_sbits _ + | (CT_lint | CT_fint _ | CT_constant _ | CT_lbits _ | CT_fbits _ | CT_sbits _ | CT_bit | CT_unit | CT_bool | CT_real | CT_string | CT_poly | CT_enum _) as ctyp -> f ctyp | CT_tup ctyps -> f (CT_tup (List.map (map_ctyp f) ctyps)) | CT_ref ctyp -> f (CT_ref (map_ctyp f ctyp)) @@ -365,6 +371,7 @@ let rec ctyp_equal ctyp1 ctyp2 = | CT_fbits (m1, d1), CT_fbits (m2, d2) -> m1 = m2 && d1 = d2 | CT_bit, CT_bit -> true | CT_fint n, CT_fint m -> n = m + | CT_constant n, CT_constant m -> Big_int.equal n m | CT_unit, CT_unit -> true | CT_bool, CT_bool -> true | CT_struct (id1, _), CT_struct (id2, _) -> Id.compare id1 id2 = 0 @@ -391,6 +398,10 @@ let rec ctyp_compare ctyp1 ctyp2 = | CT_fint _, _ -> 1 | _, CT_fint _ -> -1 + | CT_constant n, CT_constant m -> Big_int.compare n m + | CT_constant _, _ -> 1 + | _, CT_constant _ -> -1 + | CT_fbits (n, ord1), CT_fbits (m, ord2) -> lex_ord (compare n m) (compare ord1 ord2) | CT_fbits _, _ -> 1 | _, CT_fbits _ -> -1 @@ -465,7 +476,7 @@ let rec ctyp_unify ctyp1 ctyp2 = List.concat (List.map2 ctyp_unify (List.map snd fields1) (List.map snd fields2)) else raise (Invalid_argument "ctyp_unify") - + | CT_ref ctyp1, CT_ref ctyp2 -> ctyp_unify ctyp1 ctyp2 | CT_poly, _ -> [ctyp2] @@ -479,6 +490,7 @@ let rec ctyp_suprema = function | CT_fbits (_, d) -> CT_lbits d | CT_sbits (_, d) -> CT_lbits d | CT_fint _ -> CT_lint + | CT_constant _ -> CT_lint | CT_unit -> CT_unit | CT_bool -> CT_bool | CT_real -> CT_real @@ -503,7 +515,7 @@ let rec ctyp_ids = function IdSet.add id (List.fold_left (fun ids (_, ctyp) -> IdSet.union (ctyp_ids ctyp) ids) IdSet.empty ctors) | CT_tup ctyps -> List.fold_left (fun ids ctyp -> IdSet.union (ctyp_ids ctyp) ids) IdSet.empty ctyps | CT_vector (_, ctyp) | CT_list ctyp | CT_ref ctyp -> ctyp_ids ctyp - | CT_lint | CT_fint _ | CT_lbits _ | CT_fbits _ | CT_sbits _ | CT_unit + | CT_lint | CT_fint _ | CT_constant _ | CT_lbits _ | CT_fbits _ | CT_sbits _ | CT_unit | CT_bool | CT_real | CT_bit | CT_string | CT_poly -> IdSet.empty let rec unpoly = function @@ -515,7 +527,7 @@ let rec unpoly = function | f -> f let rec is_polymorphic = function - | CT_lint | CT_fint _ | CT_lbits _ | CT_fbits _ | CT_sbits _ | CT_bit | CT_unit | CT_bool | CT_real | CT_string -> false + | CT_lint | CT_fint _ | CT_constant _ | CT_lbits _ | CT_fbits _ | CT_sbits _ | CT_bit | CT_unit | CT_bool | CT_real | CT_string -> false | CT_tup ctyps -> List.exists is_polymorphic ctyps | CT_enum _ -> false | CT_struct (_, ctors) | CT_variant (_, ctors) -> List.exists (fun (_, ctyp) -> is_polymorphic ctyp) ctors @@ -529,7 +541,7 @@ let pp_name id = string (string_of_name ~zencode:false id) let pp_ctyp ctyp = - string (full_string_of_ctyp ctyp |> Util.yellow |> Util.clear) + string (string_of_ctyp ctyp |> Util.yellow |> Util.clear) let pp_keyword str = string ((str |> Util.red |> Util.clear) ^ " ") @@ -539,6 +551,8 @@ let pp_cval (frag, ctyp) = let rec pp_clexp = function | CL_id (id, ctyp) -> pp_name id ^^ string " : " ^^ pp_ctyp ctyp + | CL_rmw (read, write, ctyp) -> + string "rmw" ^^ parens (pp_name read ^^ comma ^^ space ^^ pp_name write) ^^ string " : " ^^ pp_ctyp ctyp | CL_field (clexp, field) -> parens (pp_clexp clexp) ^^ string "." ^^ string field | CL_tuple (clexp, n) -> parens (pp_clexp clexp) ^^ string "." ^^ string (string_of_int n) | CL_addr clexp -> string "*" ^^ pp_clexp clexp @@ -643,7 +657,7 @@ let pp_cdef = function let rec fragment_deps = function | F_id id | F_ref id -> NameSet.singleton id | F_lit _ -> NameSet.empty - | F_field (frag, _) | F_unary (_, frag) | F_poly frag -> fragment_deps frag + | F_field (frag, _) | F_unary (_, frag) | F_poly frag | F_tuple_member (frag, _, _) -> fragment_deps frag | F_call (_, frags) -> List.fold_left NameSet.union NameSet.empty (List.map fragment_deps frags) | F_op (frag1, _, frag2) -> NameSet.union (fragment_deps frag1) (fragment_deps frag2) | F_ctor_kind (frag, _, _, _) -> fragment_deps frag @@ -653,11 +667,12 @@ let rec fragment_deps = function let cval_deps = function (frag, _) -> fragment_deps frag let rec clexp_deps = function - | CL_id (id, _) -> NameSet.singleton id + | CL_id (id, _) -> NameSet.empty, NameSet.singleton id + | CL_rmw (read, write, _) -> NameSet.singleton read, NameSet.singleton write | CL_field (clexp, _) -> clexp_deps clexp | CL_tuple (clexp, _) -> clexp_deps clexp | CL_addr clexp -> clexp_deps clexp - | CL_void -> NameSet.empty + | CL_void -> NameSet.empty, NameSet.empty (* Return the direct, read/write dependencies of a single instruction *) let instr_deps = function @@ -666,8 +681,12 @@ let instr_deps = function | I_init (ctyp, id, cval) | I_reinit (ctyp, id, cval) -> cval_deps cval, NameSet.singleton id | I_if (cval, _, _, _) -> cval_deps cval, NameSet.empty | I_jump (cval, label) -> cval_deps cval, NameSet.empty - | I_funcall (clexp, _, _, cvals) -> List.fold_left NameSet.union NameSet.empty (List.map cval_deps cvals), clexp_deps clexp - | I_copy (clexp, cval) -> cval_deps cval, clexp_deps clexp + | I_funcall (clexp, _, _, cvals) -> + let reads, writes = clexp_deps clexp in + List.fold_left NameSet.union reads (List.map cval_deps cvals), writes + | I_copy (clexp, cval) -> + let reads, writes = clexp_deps clexp in + NameSet.union reads (cval_deps cval), writes | I_clear (_, id) -> NameSet.singleton id, NameSet.empty | I_throw cval | I_return cval -> cval_deps cval, NameSet.empty | I_block _ | I_try_block _ -> NameSet.empty, NameSet.empty @@ -690,6 +709,7 @@ module NameCTMap = Map.Make(NameCT) let rec clexp_typed_writes = function | CL_id (id, ctyp) -> NameCTSet.singleton (id, ctyp) + | CL_rmw (_, id, ctyp) -> NameCTSet.singleton (id, ctyp) | CL_field (clexp, _) -> clexp_typed_writes clexp | CL_tuple (clexp, _) -> clexp_typed_writes clexp | CL_addr clexp -> clexp_typed_writes clexp @@ -704,6 +724,7 @@ let instr_typed_writes (I_aux (aux, _)) = let rec map_clexp_ctyp f = function | CL_id (id, ctyp) -> CL_id (id, f ctyp) + | CL_rmw (read, write, ctyp) -> CL_rmw (read, write, f ctyp) | CL_field (clexp, field) -> CL_field (map_clexp_ctyp f clexp, field) | CL_tuple (clexp, n) -> CL_tuple (map_clexp_ctyp f clexp, n) | CL_addr clexp -> CL_addr (map_clexp_ctyp f clexp) @@ -837,6 +858,7 @@ let cval_ctyp = function (_, ctyp) -> ctyp let rec clexp_ctyp = function | CL_id (_, ctyp) -> ctyp + | CL_rmw (_, _, ctyp) -> ctyp | CL_field (clexp, field) -> begin match clexp_ctyp clexp with | CT_struct (id, ctors) -> |
