summaryrefslogtreecommitdiff
path: root/src/jib
diff options
context:
space:
mode:
Diffstat (limited to 'src/jib')
-rw-r--r--src/jib/jib_compile.ml25
-rw-r--r--src/jib/jib_compile.mli6
-rw-r--r--src/jib/jib_optimize.ml17
-rw-r--r--src/jib/jib_smt.ml711
-rw-r--r--src/jib/jib_ssa.ml107
-rw-r--r--src/jib/jib_ssa.mli9
-rw-r--r--src/jib/jib_util.ml46
7 files changed, 873 insertions, 48 deletions
diff --git a/src/jib/jib_compile.ml b/src/jib/jib_compile.ml
index 4a72ffff..13e7334a 100644
--- a/src/jib/jib_compile.ml
+++ b/src/jib/jib_compile.ml
@@ -151,7 +151,9 @@ type ctx =
letbinds : int list;
no_raw : bool;
convert_typ : ctx -> typ -> ctyp;
- optimize_anf : ctx -> typ aexp -> typ aexp
+ optimize_anf : ctx -> typ aexp -> typ aexp;
+ specialize_calls : bool;
+ ignore_64 : bool
}
let initial_ctx ~convert_typ:convert_typ ~optimize_anf:optimize_anf env =
@@ -164,7 +166,9 @@ let initial_ctx ~convert_typ:convert_typ ~optimize_anf:optimize_anf env =
letbinds = [];
no_raw = false;
convert_typ = convert_typ;
- optimize_anf = optimize_anf
+ optimize_anf = optimize_anf;
+ specialize_calls = false;
+ ignore_64 = false
}
let ctyp_of_typ ctx typ = ctx.convert_typ ctx typ
@@ -197,6 +201,9 @@ let rec compile_aval l ctx = function
| AV_lit (L_aux (L_string str, _), typ) ->
[], (F_lit (V_string (String.escaped str)), ctyp_of_typ ctx typ), []
+ | AV_lit (L_aux (L_num n, _), typ) when ctx.ignore_64 ->
+ [], (F_lit (V_int n), ctyp_of_typ ctx typ), []
+
| AV_lit (L_aux (L_num n, _), typ) when Big_int.less_equal (min_int 64) n && Big_int.less_equal n (max_int 64) ->
let gs = ngensym () in
[iinit CT_lint gs (F_lit (V_int n), CT_fint 64)],
@@ -258,7 +265,7 @@ let rec compile_aval l ctx = function
raise (Reporting.err_general l "Encountered empty vector literal")
(* Convert a small bitvector to a uint64_t literal. *)
- | AV_vector (avals, typ) when is_bitvector avals && List.length avals <= 64 ->
+ | AV_vector (avals, typ) when is_bitvector avals && (List.length avals <= 64 || ctx.ignore_64) ->
begin
let bitstring = F_lit (V_bits (List.map value_of_aval_bit avals)) in
let len = List.length avals in
@@ -383,7 +390,7 @@ let compile_funcall l ctx id args typ =
let have_ctyp = cval_ctyp cval in
if is_polymorphic ctyp then
(F_poly (fst cval), have_ctyp)
- else if ctyp_equal ctyp have_ctyp then
+ else if ctx.specialize_calls || ctyp_equal ctyp have_ctyp then
cval
else
let gs = ngensym () in
@@ -398,7 +405,7 @@ let compile_funcall l ctx id args typ =
List.rev !setup,
begin fun clexp ->
- if ctyp_equal (clexp_ctyp clexp) ret_ctyp then
+ if ctx.specialize_calls || ctyp_equal (clexp_ctyp clexp) ret_ctyp then
ifuncall clexp id setup_args
else
let gs = ngensym () in
@@ -450,7 +457,7 @@ let rec compile_match ctx (AP_aux (apat_aux, env, l)) cval case_label =
| AP_tup apats, (frag, ctyp) ->
begin
- let get_tup n ctyp = (F_field (frag, "ztup" ^ string_of_int n), ctyp) in
+ let get_tup n ctyp = (F_tuple_member (frag, List.length apats, n), ctyp) in
let fold (instrs, cleanup, n, ctx) apat ctyp =
let instrs', cleanup', ctx = compile_match ctx apat (get_tup n ctyp) case_label in
instrs @ instrs', cleanup' @ cleanup, n + 1, ctx
@@ -1039,12 +1046,16 @@ let fix_early_return ret instrs =
let end_function_label = label "end_function_" in
let is_return_recur (I_aux (instr, _)) =
match instr with
- | I_return _ | I_undefined _ | I_if _ | I_block _ -> true
+ | I_return _ | I_undefined _ | I_if _ | I_block _ | I_try_block _ -> true
| _ -> false
in
let rec rewrite_return historic instrs =
match instr_split_at is_return_recur instrs with
| instrs, [] -> instrs
+ | before, I_aux (I_try_block instrs, _) :: after ->
+ before
+ @ [itry_block (rewrite_return (historic @ before) instrs)]
+ @ rewrite_return (historic @ before) after
| before, I_aux (I_block instrs, _) :: after ->
before
@ [iblock (rewrite_return (historic @ before) instrs)]
diff --git a/src/jib/jib_compile.mli b/src/jib/jib_compile.mli
index a0cacc3c..12cd4a63 100644
--- a/src/jib/jib_compile.mli
+++ b/src/jib/jib_compile.mli
@@ -64,7 +64,7 @@ val opt_debug_flow_graphs : bool ref
val opt_debug_function : string ref
val ngensym : unit -> name
-
+
(** {2 Jib context} *)
(** Context for compiling Sail to Jib. We need to pass a (global)
@@ -82,7 +82,9 @@ type ctx =
letbinds : int list;
no_raw : bool;
convert_typ : ctx -> typ -> ctyp;
- optimize_anf : ctx -> typ aexp -> typ aexp
+ optimize_anf : ctx -> typ aexp -> typ aexp;
+ specialize_calls : bool;
+ ignore_64 : bool
}
val initial_ctx :
diff --git a/src/jib/jib_optimize.ml b/src/jib/jib_optimize.ml
index 73b175a1..f9829dfd 100644
--- a/src/jib/jib_optimize.ml
+++ b/src/jib/jib_optimize.ml
@@ -160,6 +160,7 @@ let rec frag_subst id subst = function
| F_unary (op, frag) -> F_unary (op, frag_subst id subst frag)
| F_call (op, frags) -> F_call (op, List.map (frag_subst id subst) frags)
| F_field (frag, field) -> F_field (frag_subst id subst frag, field)
+ | F_tuple_member (frag, len, n) -> F_tuple_member (frag_subst id subst frag, len, n)
| F_raw str -> F_raw str
| F_ctor_kind (frag, ctor, unifiers, ctyp) -> F_ctor_kind (frag_subst id subst frag, ctor, unifiers, ctyp)
| F_ctor_unwrap (ctor, unifiers, frag) -> F_ctor_unwrap (ctor, unifiers, frag_subst id subst frag)
@@ -212,8 +213,10 @@ let rec instrs_subst id subst =
let rec clexp_subst id subst = function
| CL_id (id', ctyp) when Name.compare id id' = 0 ->
- assert (ctyp_equal ctyp (clexp_ctyp subst));
- subst
+ if ctyp_equal ctyp (clexp_ctyp subst) then
+ subst
+ else
+ subst
| CL_id (id', ctyp) -> CL_id (id', ctyp)
| CL_field (clexp, field) -> CL_field (clexp_subst id subst clexp, field)
| CL_addr clexp -> CL_addr (clexp_subst id subst clexp)
@@ -245,6 +248,15 @@ let inline cdefs should_inline instrs =
| instr -> instr
in
+ let fix_labels =
+ let fix_label l = "inline" ^ string_of_int !inlines ^ "_" ^ l in
+ function
+ | I_aux (I_goto label, aux) -> I_aux (I_goto (fix_label label), aux)
+ | I_aux (I_label label, aux) -> I_aux (I_label (fix_label label), aux)
+ | I_aux (I_jump (cval, label), aux) -> I_aux (I_jump (cval, fix_label label), aux)
+ | instr -> instr
+ in
+
let rec inline_instr = function
| I_aux (I_funcall (clexp, false, function_id, args), aux) as instr when should_inline function_id ->
begin match find_function function_id cdefs with
@@ -252,6 +264,7 @@ let inline cdefs should_inline instrs =
incr inlines;
let inline_label = label "end_inline_" in
let body = List.fold_right2 instrs_subst (List.map name ids) (List.map fst args) body in
+ let body = List.map (map_instr fix_labels) body in
let body = List.map (map_instr (replace_end inline_label)) body in
let body = List.map (map_instr (replace_return clexp)) body in
I_aux (I_block (body @ [ilabel inline_label]), aux)
diff --git a/src/jib/jib_smt.ml b/src/jib/jib_smt.ml
new file mode 100644
index 00000000..01240754
--- /dev/null
+++ b/src/jib/jib_smt.ml
@@ -0,0 +1,711 @@
+open Anf
+open Ast
+open Ast_util
+open Jib
+open Jib_util
+open Smtlib
+
+let zencode_upper_id id = Util.zencode_upper_string (string_of_id id)
+let zencode_id id = Util.zencode_string (string_of_id id)
+let zencode_name id = string_of_name ~deref_current_exception:false ~zencode:true id
+
+let lbits_index = ref 8
+
+let lbits_size () = Util.power 2 !lbits_index
+
+let lint_size = ref 64
+
+let smt_unit = mk_enum "Unit" ["Unit"]
+let smt_lbits = mk_record "Bits" [("size", Bitvec !lbits_index); ("bits", Bitvec (lbits_size ()))]
+
+let rec required_width n =
+ if Big_int.equal n Big_int.zero then
+ 1
+ else
+ 1 + required_width (Big_int.shift_right n 1)
+
+let rec smt_ctyp = function
+ | CT_constant n -> Bitvec (required_width n)
+ | CT_fint n -> Bitvec n
+ | CT_lint -> Bitvec !lint_size
+ | CT_unit -> smt_unit
+ | CT_bit -> Bitvec 1
+ | CT_fbits (n, _) -> Bitvec n
+ | CT_sbits (n, _) -> smt_lbits
+ | CT_lbits _ -> smt_lbits
+ | CT_bool -> Bool
+ | CT_enum (id, elems) ->
+ mk_enum (zencode_upper_id id) (List.map zencode_id elems)
+ | CT_struct (id, fields) ->
+ mk_record (zencode_upper_id id) (List.map (fun (id, ctyp) -> (zencode_id id, smt_ctyp ctyp)) fields)
+ | CT_variant (id, ctors) ->
+ mk_variant (zencode_upper_id id) (List.map (fun (id, ctyp) -> (zencode_id id, smt_ctyp ctyp)) ctors)
+ | CT_tup ctyps -> Tuple (List.map smt_ctyp ctyps)
+ | CT_vector (_, ctyp) -> Array (Bitvec 8, smt_ctyp ctyp)
+ | CT_string -> Bitvec 64
+ | ctyp -> failwith ("Unhandled ctyp: " ^ string_of_ctyp ctyp)
+
+let bvint sz x =
+ if sz mod 4 = 0 then
+ let hex = Printf.sprintf "%X" x in
+ let padding = String.make (sz / 4 - String.length hex) '0' in
+ Hex (padding ^ hex)
+ else
+ failwith "Bad len"
+
+let smt_value vl ctyp =
+ let open Value2 in
+ match vl, ctyp with
+ | V_bits bs, _ ->
+ begin match Sail2_values.hexstring_of_bits bs with
+ | Some s -> Hex (Xstring.implode s)
+ | None -> Bin (Xstring.implode (List.map Sail2_values.bitU_char bs))
+ end
+ | V_bool b, _ -> Bool_lit b
+ | V_int n, CT_constant m -> bvint (required_width n) (Big_int.to_int n)
+ | V_int n, CT_fint sz -> bvint sz (Big_int.to_int n)
+ | V_bit Sail2_values.B0, CT_bit -> Bin "0"
+ | V_bit Sail2_values.B1, CT_bit -> Bin "1"
+ | V_unit, _ -> Var "unit"
+
+ | vl, _ -> failwith ("Bad literal " ^ string_of_value vl)
+
+let zencode_ctor ctor_id unifiers =
+ match unifiers with
+ | [] ->
+ zencode_id ctor_id
+ | _ ->
+ Util.zencode_string (string_of_id ctor_id ^ "_" ^ Util.string_of_list "_" string_of_ctyp unifiers)
+
+let rec smt_cval env cval =
+ match cval with
+ | F_lit vl, ctyp -> smt_value vl ctyp
+ | frag, _ -> smt_fragment env frag
+
+and smt_fragment env frag =
+ match frag with
+ | F_id (Name (id, _) as ssa_id) ->
+ begin match Type_check.Env.lookup_id id env with
+ | Enum _ -> Var (zencode_id id)
+ | _ -> Var (zencode_name ssa_id)
+ end
+ | F_id ssa_id -> Var (zencode_name ssa_id)
+ | F_op (frag1, "!=", frag2) ->
+ Fn ("not", [Fn ("=", [smt_fragment env frag1; smt_fragment env frag2])])
+ | F_unary ("!", frag) ->
+ Fn ("not", [smt_cval env (frag, CT_bool)])
+ | F_ctor_kind (union, ctor_id, unifiers, _) ->
+ Fn ("not", [Tester (zencode_ctor ctor_id unifiers, smt_fragment env union)])
+ | F_ctor_unwrap (ctor_id, unifiers, union) ->
+ Fn ("un" ^ zencode_ctor ctor_id unifiers, [smt_fragment env union])
+ | F_field (union, field) ->
+ Fn ("un" ^ field, [smt_fragment env union])
+ | F_tuple_member (frag, len, n) ->
+ Fn (Printf.sprintf "tup_%d_%d" len n, [smt_fragment env frag])
+ | frag -> failwith ("Unrecognised fragment " ^ string_of_fragment ~zencode:false frag)
+
+let builtin_zeros env cval = function
+ | CT_fbits (n, _) -> bvzero n
+ | CT_lbits _ -> Fn ("Bits", [extract (!lbits_index - 1) 0 (smt_cval env cval); bvzero (lbits_size ())])
+ | _ -> failwith "Cannot compile zeros"
+
+let builtin_zero_extend env vbits vlen ret_ctyp =
+ match cval_ctyp vbits, ret_ctyp with
+ | CT_fbits (n, _), CT_fbits (m, _) when n = m ->
+ smt_cval env vbits
+ | CT_fbits (n, _), CT_fbits (m, _) ->
+ let bv = smt_cval env vbits in
+ Fn ("concat", [bvzero (m - n); bv])
+ | _ -> failwith "Cannot compile zero_extend"
+
+let builtin_sign_extend env vbits vlen ret_ctyp =
+ match cval_ctyp vbits, ret_ctyp with
+ | CT_fbits (n, _), CT_fbits (m, _) when n = m ->
+ smt_cval env vbits
+ | CT_fbits (n, _), CT_fbits (m, _) ->
+ let bv = smt_cval env vbits in
+ let top_bit_one = Fn ("=", [Extract (n - 1, n - 1, bv); Bin "1"]) in
+ Ite (top_bit_one, Fn ("concat", [bvones (m - n); bv]), Fn ("concat", [bvzero (m - n); bv]))
+ | _ -> failwith "Cannot compile zero_extend"
+
+let int_size = function
+ | CT_constant n -> required_width n
+ | CT_fint sz -> sz
+ | CT_lint -> lbits_size ()
+ | _ -> failwith "Argument to int_size must be an integer"
+
+(* [bvzeint esz cval] (BitVector Zero Extend INTeger), takes a cval
+ which must be an integer type (either CT_fint, or CT_lint), and
+ produces a bitvector which is either zero extended or truncated to
+ exactly esz bits. *)
+let bvzeint env esz cval =
+ let sz = int_size (cval_ctyp cval) in
+ match fst cval with
+ | F_lit (V_int n) ->
+ bvint esz (Big_int.to_int n)
+ | _ ->
+ let smt = smt_cval env cval in
+ if esz = sz then
+ smt
+ else if esz > sz then
+ Fn ("concat", [bvzero (esz - sz); smt])
+ else
+ Extract (esz - 1, 0, smt)
+
+let builtin_shift shiftop env vbits vshift ret_ctyp =
+ match cval_ctyp vbits with
+ | CT_fbits (n, _) ->
+ let bv = smt_cval env vbits in
+ let len = bvzeint env n vshift in
+ Fn (shiftop, [bv; len])
+
+ | CT_lbits _ ->
+ let bv = smt_cval env vbits in
+ let shift = bvzeint env (lbits_size ()) vshift in
+ Fn ("Bits", [Fn ("len", [bv]); Fn (shiftop, [Fn ("contents", [bv]); shift])])
+
+ | _ -> failwith ("Cannot compile shift: " ^ shiftop)
+
+let builtin_or_bits env v1 v2 ret_ctyp =
+ match cval_ctyp v1, cval_ctyp v2 with
+ | CT_fbits (n, _), CT_fbits (m, _) ->
+ assert (n = m);
+ let smt1 = smt_cval env v1 in
+ let smt2 = smt_cval env v2 in
+ bvor smt1 smt2
+
+ | CT_lbits _, CT_lbits _ ->
+ let smt1 = smt_cval env v1 in
+ let smt2 = smt_cval env v2 in
+ Fn ("Bits", [Fn ("len", [smt1]); bvor (Fn ("contents", [smt1])) (Fn ("contents", [smt2]))])
+
+ | _ -> failwith "Cannot compile or_bits"
+
+let builtin_append env v1 v2 ret_ctyp =
+ match cval_ctyp v1, cval_ctyp v2, ret_ctyp with
+ | CT_fbits (n, _), CT_fbits (m, _), CT_fbits (o, _) ->
+ assert (n + m = o);
+ let smt1 = smt_cval env v1 in
+ let smt2 = smt_cval env v2 in
+ Fn ("concat", [smt1; smt2])
+
+ | CT_fbits (n, _), CT_lbits _, CT_lbits _ ->
+ let smt1 = smt_cval env v1 in
+ let smt2 = smt_cval env v2 in
+ let x = Fn ("concat", [bvzero (lbits_size () - n); smt1]) in
+ let shift = Fn ("concat", [bvzero (lbits_size () - !lbits_index); Fn ("len", [smt2])]) in
+ Fn ("Bits", [bvadd (bvint !lbits_index n) (Fn ("len", [smt2])); bvor (bvshl x shift) (Fn ("contents", [smt2]))])
+
+ | CT_lbits _, CT_lbits _, CT_lbits _ ->
+ let smt1 = smt_cval env v1 in
+ let smt2 = smt_cval env v2 in
+ let x = Fn ("contents", [smt1]) in
+ let shift = Fn ("concat", [bvzero (lbits_size () - !lbits_index); Fn ("len", [smt2])]) in
+ Fn ("Bits", [bvadd (Fn ("len", [smt1])) (Fn ("len", [smt2])); bvor (bvshl x shift) (Fn ("contents", [smt2]))])
+
+ | _ -> failwith "Cannot compile append"
+
+let builtin_length env v ret_ctyp =
+ match cval_ctyp v, ret_ctyp with
+ | CT_lbits _, CT_fint m ->
+ let sz = !lbits_index in
+ let len = Fn ("len", [smt_cval env v]) in
+ if m = sz then
+ len
+ else if m > sz then
+ Fn ("concat", [bvzero (m - sz); len])
+ else
+ Extract (m - 1, 0, len)
+
+ | _, _ -> failwith "Cannot compile length"
+
+let builtin_vector_subrange env vec i j ret_ctyp =
+ match cval_ctyp vec, cval_ctyp i, cval_ctyp j with
+ | CT_fbits (n, _), CT_constant i, CT_constant j ->
+ Extract (Big_int.to_int i, Big_int.to_int j, smt_cval env vec)
+
+ | _ -> failwith "Cannot compile vector subrange"
+
+let builtin_vector_access env vec i ret_ctyp =
+ match cval_ctyp vec, cval_ctyp i, ret_ctyp with
+ | CT_fbits (n, _), CT_constant i, CT_bit ->
+ Extract (Big_int.to_int i, Big_int.to_int i, smt_cval env vec)
+
+ | _ -> failwith "Cannot compile vector subrange"
+
+let builtin_unsigned env v ret_ctyp =
+ match cval_ctyp v, ret_ctyp with
+ | CT_fbits (n, _), CT_fint m ->
+ let smt = smt_cval env v in
+ Fn ("concat", [bvzero (m - n); smt])
+
+ | _ -> failwith "Cannot compile unsigned"
+
+let builtin_eq_int env v1 v2 =
+ match cval_ctyp v1, cval_ctyp v2 with
+ | CT_fint m, CT_constant c ->
+ Fn ("=", [smt_cval env v1; bvint m (Big_int.to_int c)])
+
+ | _ -> failwith "Cannot compile eq_int"
+
+let builtin_add_bits env v1 v2 ret_ctyp =
+ match cval_ctyp v1, cval_ctyp v2, ret_ctyp with
+ | CT_fbits (n, _), CT_fbits (m, _), CT_fbits (o, _) ->
+ assert (n = m && m = o);
+ Fn ("bvadd", [smt_cval env v1; smt_cval env v2])
+
+ | _ -> failwith "Cannot compile add_bits"
+
+let smt_primop env name args ret_ctyp =
+ match name, args, ret_ctyp with
+ | "eq_bits", [v1; v2], _ ->
+ let smt1 = smt_cval env v1 in
+ let smt2 = smt_cval env v2 in
+ Fn ("=", [smt1; smt2])
+ | "eq_bit", [v1; v2], _ ->
+ let smt1 = smt_cval env v1 in
+ let smt2 = smt_cval env v2 in
+ Fn ("=", [smt1; smt2])
+
+ | "not", [v], _ -> Fn ("not", [smt_cval env v])
+
+ | "zeros", [v1], _ -> builtin_zeros env v1 ret_ctyp
+ | "zero_extend", [v1; v2], _ -> builtin_zero_extend env v1 v2 ret_ctyp
+ | "sign_extend", [v1; v2], _ -> builtin_sign_extend env v1 v2 ret_ctyp
+ | "shiftl", [v1; v2], _ -> builtin_shift "bvshl" env v1 v2 ret_ctyp
+ | "shiftr", [v1; v2], _ -> builtin_shift "bvlshr" env v1 v2 ret_ctyp
+ | "or_bits", [v1; v2], _ -> builtin_or_bits env v1 v2 ret_ctyp
+ | "add_bits", [v1; v2], _ -> builtin_add_bits env v1 v2 ret_ctyp
+ | "append", [v1; v2], _ -> builtin_append env v1 v2 ret_ctyp
+ | "length", [v], ret_ctyp -> builtin_length env v ret_ctyp
+ | "vector_access", [v1; v2], ret_ctyp -> builtin_vector_access env v1 v2 ret_ctyp
+ | "vector_subrange", [v1; v2; v3], ret_ctyp -> builtin_vector_subrange env v1 v2 v3 ret_ctyp
+ | "sail_unsigned", [v], ret_ctyp -> builtin_unsigned env v ret_ctyp
+ | "eq_int", [v1; v2], _ -> builtin_eq_int env v1 v2
+
+ | _ -> failwith ("Bad primop " ^ name ^ " " ^ Util.string_of_list ", " string_of_ctyp (List.map snd args) ^ " -> " ^ string_of_ctyp ret_ctyp)
+
+let rec smt_conversion from_ctyp to_ctyp x =
+ match from_ctyp, to_ctyp with
+ | _, _ when ctyp_equal from_ctyp to_ctyp -> x
+ | _, _ -> failwith "Bad conversion"
+
+let define_const id ctyp exp = Define_const (zencode_name id, smt_ctyp ctyp, exp)
+let declare_const id ctyp = Declare_const (zencode_name id, smt_ctyp ctyp)
+
+let smt_ctype_def = function
+ | CTD_enum (id, elems) ->
+ [declare_datatypes (mk_enum (zencode_upper_id id) (List.map zencode_id elems))]
+
+ | CTD_struct (id, fields) ->
+ [declare_datatypes
+ (mk_record (zencode_upper_id id)
+ (List.map (fun (field, ctyp) -> zencode_id field, smt_ctyp ctyp) fields))]
+
+ | CTD_variant (id, ctors) ->
+ [declare_datatypes
+ (mk_variant (zencode_upper_id id)
+ (List.map (fun (ctor, ctyp) -> zencode_id ctor, smt_ctyp ctyp) ctors))]
+
+let rec generate_ctype_defs = function
+ | CDEF_type ctd :: cdefs -> smt_ctype_def ctd :: generate_ctype_defs cdefs
+ | _ :: cdefs -> generate_ctype_defs cdefs
+ | [] -> []
+
+let rec generate_reg_decs inits = function
+ | CDEF_reg_dec (id, ctyp, _) :: cdefs when not (NameMap.mem (Name (id, 0)) inits)->
+ Declare_const (zencode_name (Name (id, 0)), smt_ctyp ctyp)
+ :: generate_reg_decs inits cdefs
+ | _ :: cdefs -> generate_reg_decs inits cdefs
+ | [] -> []
+
+(**************************************************************************)
+(* 2. Converting sail types to Jib types for SMT *)
+(**************************************************************************)
+
+let max_int n = Big_int.pred (Big_int.pow_int_positive 2 (n - 1))
+let min_int n = Big_int.negate (Big_int.pow_int_positive 2 (n - 1))
+
+(** Convert a sail type into a C-type. This function can be quite
+ slow, because it uses ctx.local_env and SMT to analyse the Sail
+ types and attempts to fit them into the smallest possible C
+ types, provided ctx.optimize_smt is true (default) **)
+let rec ctyp_of_typ ctx typ =
+ let open Ast in
+ let open Type_check in
+ let open Jib_compile in
+ let Typ_aux (typ_aux, l) as typ = Env.expand_synonyms ctx.tc_env typ in
+ match typ_aux with
+ | Typ_id id when string_of_id id = "bit" -> CT_bit
+ | Typ_id id when string_of_id id = "bool" -> CT_bool
+ | Typ_id id when string_of_id id = "int" -> CT_lint
+ | Typ_id id when string_of_id id = "nat" -> CT_lint
+ | Typ_id id when string_of_id id = "unit" -> CT_unit
+ | Typ_id id when string_of_id id = "string" -> CT_string
+ | Typ_id id when string_of_id id = "real" -> CT_real
+
+ | Typ_app (id, _) when string_of_id id = "atom_bool" -> CT_bool
+
+ | Typ_app (id, args) when string_of_id id = "itself" ->
+ ctyp_of_typ ctx (Typ_aux (Typ_app (mk_id "atom", args), l))
+ | Typ_app (id, _) when string_of_id id = "range" || string_of_id id = "atom" || string_of_id id = "implicit" ->
+ begin match destruct_range Env.empty typ with
+ | None -> assert false (* Checked if range type in guard *)
+ | Some (kids, constr, n, m) ->
+ let ctx = { ctx with local_env = add_existential Parse_ast.Unknown (List.map (mk_kopt K_int) kids) constr ctx.local_env } in
+ match nexp_simp n, nexp_simp m with
+ | Nexp_aux (Nexp_constant n, _), Nexp_aux (Nexp_constant m, _)
+ when n = m ->
+ CT_constant n
+ | Nexp_aux (Nexp_constant n, _), Nexp_aux (Nexp_constant m, _)
+ when Big_int.less_equal (min_int 64) n && Big_int.less_equal m (max_int 64) ->
+ CT_fint 64
+ | n, m ->
+ if prove __POS__ ctx.local_env (nc_lteq (nconstant (min_int 64)) n) && prove __POS__ ctx.local_env (nc_lteq m (nconstant (max_int 64))) then
+ CT_fint 64
+ else
+ CT_lint
+ end
+
+ | Typ_app (id, [A_aux (A_typ typ, _)]) when string_of_id id = "list" ->
+ CT_list (ctyp_of_typ ctx typ)
+
+ (* When converting a sail bitvector type into C, we have three options in order of efficiency:
+ - If the length is obviously static and smaller than 64, use the fixed bits type (aka uint64_t), fbits.
+ - If the length is less than 64, then use a small bits type, sbits.
+ - If the length may be larger than 64, use a large bits type lbits. *)
+ | Typ_app (id, [A_aux (A_nexp n, _);
+ A_aux (A_order ord, _);
+ A_aux (A_typ (Typ_aux (Typ_id vtyp_id, _)), _)])
+ when string_of_id id = "vector" && string_of_id vtyp_id = "bit" ->
+ let direction = match ord with Ord_aux (Ord_dec, _) -> true | Ord_aux (Ord_inc, _) -> false | _ -> assert false in
+ begin match nexp_simp n with
+ | Nexp_aux (Nexp_constant n, _) -> CT_fbits (Big_int.to_int n, direction)
+ | _ -> CT_lbits direction
+ end
+
+ | Typ_app (id, [A_aux (A_nexp n, _);
+ A_aux (A_order ord, _);
+ A_aux (A_typ typ, _)])
+ when string_of_id id = "vector" ->
+ let direction = match ord with Ord_aux (Ord_dec, _) -> true | Ord_aux (Ord_inc, _) -> false | _ -> assert false in
+ CT_vector (direction, ctyp_of_typ ctx typ)
+
+ | Typ_app (id, [A_aux (A_typ typ, _)]) when string_of_id id = "register" ->
+ CT_ref (ctyp_of_typ ctx typ)
+
+ | Typ_id id | Typ_app (id, _) when Bindings.mem id ctx.records -> CT_struct (id, Bindings.find id ctx.records |> Bindings.bindings)
+ | Typ_id id | Typ_app (id, _) when Bindings.mem id ctx.variants -> CT_variant (id, Bindings.find id ctx.variants |> Bindings.bindings)
+ | Typ_id id when Bindings.mem id ctx.enums -> CT_enum (id, Bindings.find id ctx.enums |> IdSet.elements)
+
+ | Typ_tup typs -> CT_tup (List.map (ctyp_of_typ ctx) typs)
+
+ | Typ_exist _ ->
+ (* Use Type_check.destruct_exist when optimising with SMT, to
+ ensure that we don't cause any type variable clashes in
+ local_env, and that we can optimize the existential based upon
+ it's constraints. *)
+ begin match destruct_exist (Env.expand_synonyms ctx.local_env typ) with
+ | Some (kids, nc, typ) ->
+ let env = add_existential l kids nc ctx.local_env in
+ ctyp_of_typ { ctx with local_env = env } typ
+ | None -> raise (Reporting.err_unreachable l __POS__ "Existential cannot be destructured!")
+ end
+
+ | Typ_var kid -> CT_poly
+
+ | _ -> raise (Reporting.err_unreachable l __POS__ ("No C type for type " ^ string_of_typ typ))
+
+(**************************************************************************)
+(* 3. Optimization of primitives and literals *)
+(**************************************************************************)
+
+let hex_char =
+ let open Sail2_values in
+ function
+ | '0' -> [B0; B0; B0; B0]
+ | '1' -> [B0; B0; B0; B1]
+ | '2' -> [B0; B0; B1; B0]
+ | '3' -> [B0; B0; B1; B1]
+ | '4' -> [B0; B1; B0; B0]
+ | '5' -> [B0; B1; B0; B1]
+ | '6' -> [B0; B1; B1; B0]
+ | '7' -> [B0; B1; B1; B1]
+ | '8' -> [B1; B0; B0; B0]
+ | '9' -> [B1; B0; B0; B1]
+ | 'A' | 'a' -> [B1; B0; B1; B0]
+ | 'B' | 'b' -> [B1; B0; B1; B1]
+ | 'C' | 'c' -> [B1; B1; B0; B0]
+ | 'D' | 'd' -> [B1; B1; B0; B1]
+ | 'E' | 'e' -> [B1; B1; B1; B0]
+ | 'F' | 'f' -> [B1; B1; B1; B1]
+ | _ -> failwith "Invalid hex character"
+
+let literal_to_fragment (L_aux (l_aux, _) as lit) =
+ match l_aux with
+ | L_num n -> Some (F_lit (V_int n), CT_constant n)
+ | L_hex str when String.length str <= 16 ->
+ let content = Util.string_to_list str |> List.map hex_char |> List.concat in
+ Some (F_lit (V_bits content), CT_fbits (String.length str * 4, true))
+ | L_unit -> Some (F_lit V_unit, CT_unit)
+ | L_true -> Some (F_lit (V_bool true), CT_bool)
+ | L_false -> Some (F_lit (V_bool false), CT_bool)
+ | _ -> None
+
+let c_literals ctx =
+ let rec c_literal env l = function
+ | AV_lit (lit, typ) as v ->
+ begin match literal_to_fragment lit with
+ | Some (frag, ctyp) -> AV_C_fragment (frag, typ, ctyp)
+ | None -> v
+ end
+ | AV_tuple avals -> AV_tuple (List.map (c_literal env l) avals)
+ | v -> v
+ in
+ map_aval c_literal
+
+(**************************************************************************)
+(* 3. Generating SMT *)
+(**************************************************************************)
+
+(* When generating SMT when we encounter joins between two or more
+ blocks such as in the example below, we have to generate a muxer
+ that chooses the correct value of v_n or v_m to assign to v_o. We
+ use the pi nodes that contain the global path condition for each
+ block to generate an if-then-else for each phi function. The order
+ of the arguments to each phi function is based on the graph node
+ index for the predecessor nodes.
+
+ +---------------+ +---------------+
+ | pi(cond_1) | | pi(cond_2) |
+ | ... | | ... |
+ | Basic block 1 | | Basic block 2 |
+ +---------------+ +---------------+
+ \ /
+ \ /
+ +---------------------+
+ | v/o = phi(v/n, v/m) |
+ | ... |
+ +---------------------+
+
+ would generate:
+
+ (define-const v/o (ite cond_1 v/n v/m_))
+*)
+let smt_ssanode env cfg preds =
+ let open Jib_ssa in
+ function
+ | Pi _ -> []
+ | Phi (id, ctyp, ids) ->
+ let get_pi n =
+ match get_vertex cfg n with
+ | Some ((ssanodes, _), _, _) ->
+ List.concat (List.map (function Pi guards -> guards | _ -> []) ssanodes)
+ | None -> failwith "Predecessor node does not exist"
+ in
+ let pis = List.map get_pi (IntSet.elements preds) in
+ let mux =
+ List.fold_right2 (fun pi id chain ->
+ let pathcond =
+ match pi with
+ | [cval] -> smt_cval env cval
+ | _ -> Fn ("and", List.map (smt_cval env) pi)
+ in
+ match chain with
+ | Some smt ->
+ Some (Ite (pathcond, Var (zencode_name id), smt))
+ | None ->
+ Some (Var (zencode_name id)))
+ pis ids None
+ in
+ match mux with
+ | None -> []
+ | Some mux ->
+ [Define_const (zencode_name id, smt_ctyp ctyp, mux)]
+
+(* For any complex l-expression we need to turn it into a
+ read-modify-write in the SMT solver. The SSA transform turns CL_id
+ nodes into CL_rmw (read, write, ctyp) nodes when CL_id is wrapped
+ in any other l-expression. The read and write must have the same
+ name but different SSA numbers.
+*)
+let rec rmw_write = function
+ | CL_rmw (_, write, ctyp) -> write, ctyp
+ | CL_id _ -> assert false
+ | CL_tuple (clexp, _) -> rmw_write clexp
+ | clexp ->
+ failwith (Pretty_print_sail.to_string (pp_clexp clexp))
+
+let rmw_read = function
+ | CL_rmw (read, _, _) -> zencode_name read
+ | _ -> assert false
+
+let rmw_modify smt = function
+ | CL_tuple (clexp, n) ->
+ let ctyp = clexp_ctyp clexp in
+ begin match ctyp with
+ | CT_tup ctyps ->
+ let len = List.length ctyps in
+ let set_tup i =
+ if i == n then
+ smt
+ else
+ Fn (Printf.sprintf "tup_%d_%d" len i, [Var (rmw_read clexp)])
+ in
+ Fn ("tup4", List.init len set_tup)
+ | _ ->
+ failwith "Tuple modify does not have tuple type"
+ end
+ | _ -> assert false
+
+(* For a basic block (contained in a control-flow node / cfnode), we
+ turn the instructions into a sequence of define-const and
+ declare-const expressions. Because we are working with a SSA graph,
+ each variable is guaranteed to only be declared once.
+*)
+let smt_instr env =
+ let open Type_check in
+ function
+ | I_aux (I_funcall (CL_id (id, ret_ctyp), _, function_id, args), _) ->
+ if Env.is_extern function_id env "c" then
+ let name = Env.get_extern function_id env "c" in
+ let value = smt_primop env name args ret_ctyp in
+ [define_const id ret_ctyp value]
+ else
+ let smt_args = List.map (smt_cval env) args in
+ [define_const id ret_ctyp (Fn (zencode_id function_id, smt_args))]
+
+ | I_aux (I_init (ctyp, id, cval), _) | I_aux (I_copy (CL_id (id, ctyp), cval), _) ->
+ [define_const id ctyp
+ (smt_conversion (cval_ctyp cval) ctyp (smt_cval env cval))]
+
+ | I_aux (I_copy (clexp, cval), _) ->
+ let smt = smt_cval env cval in
+ let write, ctyp = rmw_write clexp in
+ [define_const write ctyp (rmw_modify smt clexp)]
+
+ | I_aux (I_decl (ctyp, id), _) ->
+ [declare_const id ctyp]
+
+ | I_aux (I_end id, _) ->
+ [Assert (Var (zencode_name id))]
+
+ | I_aux (I_clear _, _) -> []
+
+ | I_aux (I_match_failure, _) -> []
+
+ | instr ->
+ failwith ("Cannot translate: " ^ Pretty_print_sail.to_string (pp_instr instr))
+
+let smt_cfnode all_cdefs env =
+ let open Jib_ssa in
+ function
+ | CF_start inits ->
+ let smt_reg_decs = generate_reg_decs inits all_cdefs in
+ let smt_start (id, ctyp) =
+ match id with
+ | Have_exception _ -> define_const id ctyp (Bool_lit false)
+ | _ -> declare_const id ctyp
+ in
+ smt_reg_decs @ List.map smt_start (NameMap.bindings inits)
+ | CF_block instrs ->
+ List.concat (List.map (smt_instr env) instrs)
+ (* We can ignore any non basic-block/start control-flow nodes *)
+ | _ -> []
+
+let rec find_function id = function
+ | CDEF_fundef (id', heap_return, args, body) :: _ when Id.compare id id' = 0 ->
+ Some (heap_return, args, body)
+ | _ :: cdefs ->
+ find_function id cdefs
+ | [] -> None
+
+let smt_cdef out_chan env all_cdefs = function
+ | CDEF_spec (function_id, arg_ctyps, ret_ctyp)
+ when string_of_id function_id = "check_sat" ->
+ begin match find_function function_id all_cdefs with
+ | Some (None, args, instrs) ->
+ let open Jib_ssa in
+ let smt_args =
+ List.map2 (fun id ctyp -> declare_const (Name (id, 0)) ctyp) args arg_ctyps
+ in
+ output_smt_defs out_chan smt_args;
+
+ let instrs =
+ let open Jib_optimize in
+ instrs
+ (* |> optimize_unit *)
+ |> inline all_cdefs (fun _ -> true)
+ |> flatten_instrs
+ in
+
+ let str = Pretty_print_sail.to_string PPrint.(separate_map hardline Jib_util.pp_instr instrs) in
+ prerr_endline str;
+
+ let start, cfg = ssa instrs in
+ let chan = open_out "smt_ssa.gv" in
+ make_dot chan cfg;
+ close_out chan;
+
+ let visit_order = topsort cfg in
+
+ List.iter (fun n ->
+ begin match get_vertex cfg n with
+ | None -> ()
+ | Some ((ssanodes, cfnode), preds, succs) ->
+ let muxers =
+ ssanodes |> List.map (smt_ssanode env cfg preds) |> List.concat
+ in
+ let basic_block = smt_cfnode all_cdefs env cfnode in
+ output_smt_defs out_chan muxers;
+ output_smt_defs out_chan basic_block;
+ end
+ ) visit_order
+
+ | _ -> failwith "Bad function body"
+ end
+ | _ -> ()
+
+let rec smt_cdefs out_chan env ast =
+ function
+ | cdef :: cdefs ->
+ smt_cdef out_chan env ast cdef;
+ smt_cdefs out_chan env ast cdefs
+ | [] -> ()
+
+let generate_smt out_chan env ast =
+ try
+ let open Jib_compile in
+ let ctx =
+ initial_ctx
+ ~convert_typ:ctyp_of_typ
+ ~optimize_anf:(fun ctx aexp -> c_literals ctx aexp)
+ env
+ in
+ let t = Profile.start () in
+ let cdefs, ctx = compile_ast { ctx with specialize_calls = true; ignore_64 = true } ast in
+ Profile.finish "Compiling to Jib IR" t;
+
+ (* output_string out_chan "(set-option :produce-models true)\n"; *)
+ output_string out_chan "(set-logic QF_AUFBVDT)\n";
+ output_smt_defs out_chan
+ [declare_datatypes (mk_enum "Unit" ["unit"]);
+ Declare_tuple 2;
+ Declare_tuple 3;
+ Declare_tuple 4;
+ Declare_tuple 5;
+ declare_datatypes (mk_record "Bits" [("len", Bitvec !lbits_index);
+ ("contents", Bitvec (lbits_size ()))])
+
+ ];
+
+ let smt_type_defs = List.concat (generate_ctype_defs cdefs) in
+ output_string out_chan "\n; Sail type definitions\n";
+ output_smt_defs out_chan smt_type_defs;
+
+ output_string out_chan "\n; Sail function\n";
+ smt_cdefs out_chan env cdefs cdefs;
+ output_string out_chan "(check-sat)\n"
+ with
+ | Type_check.Type_error (_, l, err) ->
+ raise (Reporting.err_typ l (Type_error.string_of_type_error err));
diff --git a/src/jib/jib_ssa.ml b/src/jib/jib_ssa.ml
index a086f0b9..e79c88a1 100644
--- a/src/jib/jib_ssa.ml
+++ b/src/jib/jib_ssa.ml
@@ -76,7 +76,7 @@ let iter_graph f graph =
| Some (x, y, z) -> f x y z
| None -> ()
done
-
+
(** Add a vertex to a graph, returning the node index *)
let add_vertex data graph =
let n = graph.next in
@@ -124,6 +124,52 @@ let reachable roots graph =
in
IntSet.iter reachable' roots; !visited
+let topsort graph =
+ let marked = ref IntSet.empty in
+ let temp_marked = ref IntSet.empty in
+ let list = ref [] in
+
+ let rec visit node =
+ if IntSet.mem node !temp_marked then
+ failwith "Not a DAG"
+ else if IntSet.mem node !marked then
+ ()
+ else
+ begin match get_vertex graph node with
+ | None -> failwith "Node does not exist in topsort"
+ | Some (_, _, succs) ->
+ temp_marked := IntSet.add node !temp_marked;
+ IntSet.iter visit succs;
+ marked := IntSet.add node !marked;
+ temp_marked := IntSet.remove node !temp_marked;
+ list := node :: !list
+ end
+ in
+
+ let find_unmarked () =
+ let unmarked = ref (-1) in
+ let i = ref 0 in
+ while !unmarked = -1 && !i < Array.length graph.nodes do
+ begin match get_vertex graph !i with
+ | None -> ()
+ | Some _ ->
+ if not (IntSet.mem !i !marked) then
+ unmarked := !i
+ end;
+ incr i
+ done;
+ !unmarked
+ in
+
+ let rec topsort' () =
+ let unmarked = find_unmarked () in
+ if unmarked = -1 then
+ ()
+ else
+ (visit unmarked; topsort' ())
+ in
+ topsort' (); !list
+
let prune visited graph =
for i = 0 to graph.next - 1 do
match graph.nodes.(i) with
@@ -143,7 +189,7 @@ type cf_node =
| CF_label of string
| CF_block of instr list
| CF_guard of cval
- | CF_start
+ | CF_start of ctyp NameMap.t
let cval_not (f, ctyp) = (F_unary ("!", f), ctyp)
@@ -212,7 +258,7 @@ let control_flow_graph instrs =
| [] -> preds
in
- let start = add_vertex ([], CF_start) graph in
+ let start = add_vertex ([], CF_start NameMap.empty) graph in
let finish = cfg [start] instrs in
let visited = reachable (IntSet.singleton start) graph in
@@ -426,23 +472,28 @@ let rename_variables graph root children =
let counts = ref NameMap.empty in
let stacks = ref NameMap.empty in
+ let phi_zeros = ref NameMap.empty in
+
+ let ssa_name i = function
+ | Name (id, _) -> Name (id, i)
+ | Have_exception _ -> Have_exception i
+ | Current_exception _ -> Current_exception i
+ | Return _ -> Return i
+ in
+
let get_count id =
match NameMap.find_opt id !counts with Some n -> n | None -> 0
in
let top_stack id =
- match NameMap.find_opt id !stacks with Some (x :: _) -> x | (Some [] | None) -> 0
+ match NameMap.find_opt id !stacks with Some (x :: _) -> x | Some [] -> 0 | None -> 0
+ in
+ let top_stack_phi id ctyp =
+ match NameMap.find_opt id !stacks with Some (x :: _) -> x | Some [] -> 0 | None -> (phi_zeros := NameMap.add (ssa_name 0 id) ctyp !phi_zeros; 0)
in
let push_stack id n =
stacks := NameMap.add id (n :: match NameMap.find_opt id !stacks with Some s -> s | None -> []) !stacks
in
- let ssa_name i = function
- | Name (id, _) -> Name (id, i)
- | Have_exception _ -> Have_exception i
- | Current_exception _ -> Current_exception i
- | Return _ -> Return i
- in
-
let rec fold_frag = function
| F_id id ->
let i = top_stack id in
@@ -455,21 +506,29 @@ let rename_variables graph root children =
| F_unary (op, f) -> F_unary (op, fold_frag f)
| F_call (id, fs) -> F_call (id, List.map fold_frag fs)
| F_field (f, field) -> F_field (fold_frag f, field)
+ | F_tuple_member (f, len, n) -> F_tuple_member (fold_frag f, len, n)
| F_raw str -> F_raw str
| F_ctor_kind (f, ctor, unifiers, ctyp) -> F_ctor_kind (fold_frag f, ctor, unifiers, ctyp)
| F_ctor_unwrap (ctor, unifiers, f) -> F_ctor_unwrap (ctor, unifiers, fold_frag f)
| F_poly f -> F_poly (fold_frag f)
in
- let rec fold_clexp = function
+ let rec fold_clexp rmw = function
+ | CL_id (id, ctyp) when rmw ->
+ let i = top_stack id in
+ let j = get_count id + 1 in
+ counts := NameMap.add id j !counts;
+ push_stack id j;
+ CL_rmw (ssa_name i id, ssa_name j id, ctyp)
| CL_id (id, ctyp) ->
let i = get_count id + 1 in
counts := NameMap.add id i !counts;
push_stack id i;
CL_id (ssa_name i id, ctyp)
- | CL_field (clexp, field) -> CL_field (fold_clexp clexp, field)
- | CL_addr clexp -> CL_addr (fold_clexp clexp)
- | CL_tuple (clexp, n) -> CL_tuple (fold_clexp clexp, n)
+ | CL_rmw _ -> assert false
+ | CL_field (clexp, field) -> CL_field (fold_clexp true clexp, field)
+ | CL_addr clexp -> CL_addr (fold_clexp true clexp)
+ | CL_tuple (clexp, n) -> CL_tuple (fold_clexp true clexp, n)
| CL_void -> CL_void
in
@@ -479,10 +538,10 @@ let rename_variables graph root children =
let aux = match aux with
| I_funcall (clexp, extern, id, args) ->
let args = List.map fold_cval args in
- I_funcall (fold_clexp clexp, extern, id, args)
+ I_funcall (fold_clexp false clexp, extern, id, args)
| I_copy (clexp, cval) ->
let cval = fold_cval cval in
- I_copy (fold_clexp clexp, cval)
+ I_copy (fold_clexp false clexp, cval)
| I_decl (ctyp, id) ->
let i = get_count id + 1 in
counts := NameMap.add id i !counts;
@@ -505,7 +564,7 @@ let rename_variables graph root children =
in
let ssa_cfnode = function
- | CF_start -> CF_start
+ | CF_start inits -> CF_start inits
| CF_block instrs -> CF_block (List.map ssa_instr instrs)
| CF_label label -> CF_label label
| CF_guard cval -> CF_guard (fold_cval cval)
@@ -524,7 +583,7 @@ let rename_variables graph root children =
| Phi (id, ctyp, ids) ->
let fix_arg k a =
if k = j then
- let i = top_stack a in
+ let i = top_stack_phi a ctyp in
ssa_name i a
else a
in
@@ -556,7 +615,11 @@ let rename_variables graph root children =
IntSet.iter (fun child -> rename child) (children.(n));
stacks := old_stacks
in
- rename root
+ rename root;
+ match graph.nodes.(root) with
+ | Some ((ssa, CF_start _), preds, succs) ->
+ graph.nodes.(root) <- Some ((ssa, CF_start !phi_zeros), preds, succs)
+ | _ -> failwith "root node is not CF_start"
let place_pi_functions graph start idom children =
let get_guard = function
@@ -631,11 +694,11 @@ let string_of_phis = function
let string_of_node = function
| (phis, CF_label label) -> string_of_phis phis ^ label
| (phis, CF_block instrs) -> string_of_phis phis ^ Util.string_of_list "\\l" (fun instr -> String.escaped (Pretty_print_sail.to_string (pp_instr ~short:true instr))) instrs
- | (phis, CF_start) -> string_of_phis phis ^ "START"
+ | (phis, CF_start inits) -> string_of_phis phis ^ "START"
| (phis, CF_guard cval) -> string_of_phis phis ^ (String.escaped (Pretty_print_sail.to_string (pp_cval cval)))
let vertex_color = function
- | (_, CF_start) -> "peachpuff"
+ | (_, CF_start _) -> "peachpuff"
| (_, CF_block _) -> "white"
| (_, CF_label _) -> "springgreen"
| (_, CF_guard _) -> "yellow"
diff --git a/src/jib/jib_ssa.mli b/src/jib/jib_ssa.mli
index b146861c..88bb46c0 100644
--- a/src/jib/jib_ssa.mli
+++ b/src/jib/jib_ssa.mli
@@ -49,6 +49,7 @@
(**************************************************************************)
open Array
+open Jib_util
(** A mutable array based graph type, with nodes indexed by integers. *)
type 'a array_graph
@@ -58,11 +59,11 @@ type 'a array_graph
val make : initial_size:int -> unit -> 'a array_graph
module IntSet : Set.S with type elt = int
-
+
val get_vertex : 'a array_graph -> int -> ('a * IntSet.t * IntSet.t) option
val iter_graph : ('a -> IntSet.t -> IntSet.t -> unit) -> 'a array_graph -> unit
-
+
(** Add a vertex to a graph, returning the index of the inserted
vertex. If the number of vertices exceeds the size of the
underlying array, then it is dynamically resized. *)
@@ -72,11 +73,13 @@ val add_vertex : 'a -> 'a array_graph -> int
if either of the vertices do not exist. *)
val add_edge : int -> int -> 'a array_graph -> unit
+val topsort : 'a array_graph -> int list
+
type cf_node =
| CF_label of string
| CF_block of Jib.instr list
| CF_guard of Jib.cval
- | CF_start
+ | CF_start of Jib.ctyp NameMap.t
val control_flow_graph : Jib.instr list -> int * int list * ('a list * cf_node) array_graph
diff --git a/src/jib/jib_util.ml b/src/jib/jib_util.ml
index 904e0209..2eabdc57 100644
--- a/src/jib/jib_util.ml
+++ b/src/jib/jib_util.ml
@@ -164,6 +164,7 @@ let rec frag_rename from_id to_id = function
| F_op (f1, op, f2) -> F_op (frag_rename from_id to_id f1, op, frag_rename from_id to_id f2)
| F_unary (op, f) -> F_unary (op, frag_rename from_id to_id f)
| F_field (f, field) -> F_field (frag_rename from_id to_id f, field)
+ | F_tuple_member (f, len, n) -> F_tuple_member (frag_rename from_id to_id f, len, n)
| F_raw raw -> F_raw raw
| F_ctor_kind (f, ctor, unifiers, ctyp) -> F_ctor_kind (frag_rename from_id to_id f, ctor, unifiers, ctyp)
| F_ctor_unwrap (ctor, unifiers, f) -> F_ctor_unwrap (ctor, unifiers, frag_rename from_id to_id f)
@@ -258,7 +259,7 @@ let string_of_value = function
| V_bit Sail2_values.BU -> failwith "Undefined bit found in value"
| V_string str -> "\"" ^ str ^ "\""
-let string_of_name ?zencode:(zencode=true) =
+let string_of_name ?deref_current_exception:(dce=true) ?zencode:(zencode=true) =
let ssa_num n = if n < 0 then "" else ("/" ^ string_of_int n) in
function
| Name (id, n) ->
@@ -267,8 +268,10 @@ let string_of_name ?zencode:(zencode=true) =
"have_exception" ^ ssa_num n
| Return n ->
"return" ^ ssa_num n
- | Current_exception n ->
+ | Current_exception n when dce ->
"(*current_exception)" ^ ssa_num n
+ | Current_exception n ->
+ "current_exception" ^ ssa_num n
let rec string_of_fragment ?zencode:(zencode=true) = function
| F_id id -> string_of_name ~zencode:zencode id
@@ -278,6 +281,8 @@ let rec string_of_fragment ?zencode:(zencode=true) = function
Printf.sprintf "%s(%s)" str (Util.string_of_list ", " (string_of_fragment ~zencode:zencode) frags)
| F_field (f, field) ->
Printf.sprintf "%s.%s" (string_of_fragment' ~zencode:zencode f) field
+ | F_tuple_member (f, _, n) ->
+ Printf.sprintf "%s.ztup%d" (string_of_fragment' ~zencode:zencode f) n
| F_op (f1, op, f2) ->
Printf.sprintf "%s %s %s" (string_of_fragment' ~zencode:zencode f1) op (string_of_fragment' ~zencode:zencode f2)
| F_unary (op, f) ->
@@ -314,6 +319,7 @@ and string_of_ctyp = function
| CT_sbits (n, true) -> "sbits(" ^ string_of_int n ^ ", dec)"
| CT_sbits (n, false) -> "sbits(" ^ string_of_int n ^ ", inc)"
| CT_fint n -> "int(" ^ string_of_int n ^ ")"
+ | CT_constant n -> "constant(" ^ Big_int.to_string n ^ ")"
| CT_bit -> "bit"
| CT_unit -> "unit"
| CT_bool -> "bool"
@@ -348,7 +354,7 @@ and full_string_of_ctyp = function
| ctyp -> string_of_ctyp ctyp
let rec map_ctyp f = function
- | (CT_lint | CT_fint _ | CT_lbits _ | CT_fbits _ | CT_sbits _
+ | (CT_lint | CT_fint _ | CT_constant _ | CT_lbits _ | CT_fbits _ | CT_sbits _
| CT_bit | CT_unit | CT_bool | CT_real | CT_string | CT_poly | CT_enum _) as ctyp -> f ctyp
| CT_tup ctyps -> f (CT_tup (List.map (map_ctyp f) ctyps))
| CT_ref ctyp -> f (CT_ref (map_ctyp f ctyp))
@@ -365,6 +371,7 @@ let rec ctyp_equal ctyp1 ctyp2 =
| CT_fbits (m1, d1), CT_fbits (m2, d2) -> m1 = m2 && d1 = d2
| CT_bit, CT_bit -> true
| CT_fint n, CT_fint m -> n = m
+ | CT_constant n, CT_constant m -> Big_int.equal n m
| CT_unit, CT_unit -> true
| CT_bool, CT_bool -> true
| CT_struct (id1, _), CT_struct (id2, _) -> Id.compare id1 id2 = 0
@@ -391,6 +398,10 @@ let rec ctyp_compare ctyp1 ctyp2 =
| CT_fint _, _ -> 1
| _, CT_fint _ -> -1
+ | CT_constant n, CT_constant m -> Big_int.compare n m
+ | CT_constant _, _ -> 1
+ | _, CT_constant _ -> -1
+
| CT_fbits (n, ord1), CT_fbits (m, ord2) -> lex_ord (compare n m) (compare ord1 ord2)
| CT_fbits _, _ -> 1
| _, CT_fbits _ -> -1
@@ -465,7 +476,7 @@ let rec ctyp_unify ctyp1 ctyp2 =
List.concat (List.map2 ctyp_unify (List.map snd fields1) (List.map snd fields2))
else
raise (Invalid_argument "ctyp_unify")
-
+
| CT_ref ctyp1, CT_ref ctyp2 -> ctyp_unify ctyp1 ctyp2
| CT_poly, _ -> [ctyp2]
@@ -479,6 +490,7 @@ let rec ctyp_suprema = function
| CT_fbits (_, d) -> CT_lbits d
| CT_sbits (_, d) -> CT_lbits d
| CT_fint _ -> CT_lint
+ | CT_constant _ -> CT_lint
| CT_unit -> CT_unit
| CT_bool -> CT_bool
| CT_real -> CT_real
@@ -503,7 +515,7 @@ let rec ctyp_ids = function
IdSet.add id (List.fold_left (fun ids (_, ctyp) -> IdSet.union (ctyp_ids ctyp) ids) IdSet.empty ctors)
| CT_tup ctyps -> List.fold_left (fun ids ctyp -> IdSet.union (ctyp_ids ctyp) ids) IdSet.empty ctyps
| CT_vector (_, ctyp) | CT_list ctyp | CT_ref ctyp -> ctyp_ids ctyp
- | CT_lint | CT_fint _ | CT_lbits _ | CT_fbits _ | CT_sbits _ | CT_unit
+ | CT_lint | CT_fint _ | CT_constant _ | CT_lbits _ | CT_fbits _ | CT_sbits _ | CT_unit
| CT_bool | CT_real | CT_bit | CT_string | CT_poly -> IdSet.empty
let rec unpoly = function
@@ -515,7 +527,7 @@ let rec unpoly = function
| f -> f
let rec is_polymorphic = function
- | CT_lint | CT_fint _ | CT_lbits _ | CT_fbits _ | CT_sbits _ | CT_bit | CT_unit | CT_bool | CT_real | CT_string -> false
+ | CT_lint | CT_fint _ | CT_constant _ | CT_lbits _ | CT_fbits _ | CT_sbits _ | CT_bit | CT_unit | CT_bool | CT_real | CT_string -> false
| CT_tup ctyps -> List.exists is_polymorphic ctyps
| CT_enum _ -> false
| CT_struct (_, ctors) | CT_variant (_, ctors) -> List.exists (fun (_, ctyp) -> is_polymorphic ctyp) ctors
@@ -529,7 +541,7 @@ let pp_name id =
string (string_of_name ~zencode:false id)
let pp_ctyp ctyp =
- string (full_string_of_ctyp ctyp |> Util.yellow |> Util.clear)
+ string (string_of_ctyp ctyp |> Util.yellow |> Util.clear)
let pp_keyword str =
string ((str |> Util.red |> Util.clear) ^ " ")
@@ -539,6 +551,8 @@ let pp_cval (frag, ctyp) =
let rec pp_clexp = function
| CL_id (id, ctyp) -> pp_name id ^^ string " : " ^^ pp_ctyp ctyp
+ | CL_rmw (read, write, ctyp) ->
+ string "rmw" ^^ parens (pp_name read ^^ comma ^^ space ^^ pp_name write) ^^ string " : " ^^ pp_ctyp ctyp
| CL_field (clexp, field) -> parens (pp_clexp clexp) ^^ string "." ^^ string field
| CL_tuple (clexp, n) -> parens (pp_clexp clexp) ^^ string "." ^^ string (string_of_int n)
| CL_addr clexp -> string "*" ^^ pp_clexp clexp
@@ -643,7 +657,7 @@ let pp_cdef = function
let rec fragment_deps = function
| F_id id | F_ref id -> NameSet.singleton id
| F_lit _ -> NameSet.empty
- | F_field (frag, _) | F_unary (_, frag) | F_poly frag -> fragment_deps frag
+ | F_field (frag, _) | F_unary (_, frag) | F_poly frag | F_tuple_member (frag, _, _) -> fragment_deps frag
| F_call (_, frags) -> List.fold_left NameSet.union NameSet.empty (List.map fragment_deps frags)
| F_op (frag1, _, frag2) -> NameSet.union (fragment_deps frag1) (fragment_deps frag2)
| F_ctor_kind (frag, _, _, _) -> fragment_deps frag
@@ -653,11 +667,12 @@ let rec fragment_deps = function
let cval_deps = function (frag, _) -> fragment_deps frag
let rec clexp_deps = function
- | CL_id (id, _) -> NameSet.singleton id
+ | CL_id (id, _) -> NameSet.empty, NameSet.singleton id
+ | CL_rmw (read, write, _) -> NameSet.singleton read, NameSet.singleton write
| CL_field (clexp, _) -> clexp_deps clexp
| CL_tuple (clexp, _) -> clexp_deps clexp
| CL_addr clexp -> clexp_deps clexp
- | CL_void -> NameSet.empty
+ | CL_void -> NameSet.empty, NameSet.empty
(* Return the direct, read/write dependencies of a single instruction *)
let instr_deps = function
@@ -666,8 +681,12 @@ let instr_deps = function
| I_init (ctyp, id, cval) | I_reinit (ctyp, id, cval) -> cval_deps cval, NameSet.singleton id
| I_if (cval, _, _, _) -> cval_deps cval, NameSet.empty
| I_jump (cval, label) -> cval_deps cval, NameSet.empty
- | I_funcall (clexp, _, _, cvals) -> List.fold_left NameSet.union NameSet.empty (List.map cval_deps cvals), clexp_deps clexp
- | I_copy (clexp, cval) -> cval_deps cval, clexp_deps clexp
+ | I_funcall (clexp, _, _, cvals) ->
+ let reads, writes = clexp_deps clexp in
+ List.fold_left NameSet.union reads (List.map cval_deps cvals), writes
+ | I_copy (clexp, cval) ->
+ let reads, writes = clexp_deps clexp in
+ NameSet.union reads (cval_deps cval), writes
| I_clear (_, id) -> NameSet.singleton id, NameSet.empty
| I_throw cval | I_return cval -> cval_deps cval, NameSet.empty
| I_block _ | I_try_block _ -> NameSet.empty, NameSet.empty
@@ -690,6 +709,7 @@ module NameCTMap = Map.Make(NameCT)
let rec clexp_typed_writes = function
| CL_id (id, ctyp) -> NameCTSet.singleton (id, ctyp)
+ | CL_rmw (_, id, ctyp) -> NameCTSet.singleton (id, ctyp)
| CL_field (clexp, _) -> clexp_typed_writes clexp
| CL_tuple (clexp, _) -> clexp_typed_writes clexp
| CL_addr clexp -> clexp_typed_writes clexp
@@ -704,6 +724,7 @@ let instr_typed_writes (I_aux (aux, _)) =
let rec map_clexp_ctyp f = function
| CL_id (id, ctyp) -> CL_id (id, f ctyp)
+ | CL_rmw (read, write, ctyp) -> CL_rmw (read, write, f ctyp)
| CL_field (clexp, field) -> CL_field (map_clexp_ctyp f clexp, field)
| CL_tuple (clexp, n) -> CL_tuple (map_clexp_ctyp f clexp, n)
| CL_addr clexp -> CL_addr (map_clexp_ctyp f clexp)
@@ -837,6 +858,7 @@ let cval_ctyp = function (_, ctyp) -> ctyp
let rec clexp_ctyp = function
| CL_id (_, ctyp) -> ctyp
+ | CL_rmw (_, _, ctyp) -> ctyp
| CL_field (clexp, field) ->
begin match clexp_ctyp clexp with
| CT_struct (id, ctors) ->