summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/jib/anf.ml36
-rw-r--r--src/jib/anf.mli5
-rw-r--r--src/jib/c_backend.ml48
-rw-r--r--src/jib/jib_compile.ml2
-rw-r--r--src/jib/jib_optimize.ml9
-rw-r--r--src/jib/jib_smt.ml25
-rw-r--r--src/jib/jib_util.ml7
7 files changed, 100 insertions, 32 deletions
diff --git a/src/jib/anf.ml b/src/jib/anf.ml
index 4bb24032..3edd2cd7 100644
--- a/src/jib/anf.ml
+++ b/src/jib/anf.ml
@@ -153,6 +153,21 @@ let rec apat_rename from_id to_id (AP_aux (apat_aux, env, l)) =
in
AP_aux (apat_aux, env, l)
+let rec aval_typ = function
+ | AV_lit (_, typ) -> typ
+ | AV_id (_, lvar) -> lvar_typ lvar
+ | AV_ref (_, lvar) -> lvar_typ lvar
+ | AV_tuple avals -> tuple_typ (List.map aval_typ avals)
+ | AV_list (_, typ) -> typ
+ | AV_vector (_, typ) -> typ
+ | AV_record (_, typ) -> typ
+ | AV_cval (_, typ) -> typ
+
+let aexp_typ (AE_aux (aux, _, _)) =
+ match aux with
+ | AE_val aval -> aval_typ aval
+ | AE_app (_, _, typ) -> typ
+
let rec aval_rename from_id to_id = function
| AV_lit (lit, typ) -> AV_lit (lit, typ)
| AV_id (id, lvar) when Id.compare id from_id = 0 -> AV_id (to_id, lvar)
@@ -298,6 +313,27 @@ let rec map_functions f (AE_aux (aexp, env, l)) =
in
AE_aux (aexp, env, l)
+let rec fold_aexp f (AE_aux (aexp, env, l)) =
+ let aexp = match aexp with
+ | AE_app (id, vs, typ) -> AE_app (id, vs, typ)
+ | AE_cast (aexp, typ) -> AE_cast (fold_aexp f aexp, typ)
+ | AE_assign (id, typ, aexp) -> AE_assign (id, typ, fold_aexp f aexp)
+ | AE_short_circuit (op, aval, aexp) -> AE_short_circuit (op, aval, fold_aexp f aexp)
+ | AE_let (mut, id, typ1, aexp1, aexp2, typ2) -> AE_let (mut, id, typ1, fold_aexp f aexp1, fold_aexp f aexp2, typ2)
+ | AE_block (aexps, aexp, typ) -> AE_block (List.map (fold_aexp f) aexps, fold_aexp f aexp, typ)
+ | AE_if (aval, aexp1, aexp2, typ) ->
+ AE_if (aval, fold_aexp f aexp1, fold_aexp f aexp2, typ)
+ | AE_loop (loop_typ, aexp1, aexp2) -> AE_loop (loop_typ, fold_aexp f aexp1, fold_aexp f aexp2)
+ | AE_for (id, aexp1, aexp2, aexp3, order, aexp4) ->
+ AE_for (id, fold_aexp f aexp1, fold_aexp f aexp2, fold_aexp f aexp3, order, fold_aexp f aexp4)
+ | AE_case (aval, cases, typ) ->
+ AE_case (aval, List.map (fun (pat, aexp1, aexp2) -> pat, fold_aexp f aexp1, fold_aexp f aexp2) cases, typ)
+ | AE_try (aexp, cases, typ) ->
+ AE_try (fold_aexp f aexp, List.map (fun (pat, aexp1, aexp2) -> pat, fold_aexp f aexp1, fold_aexp f aexp2) cases, typ)
+ | AE_field _ | AE_record_update _ | AE_val _ | AE_return _ | AE_throw _ as v -> v
+ in
+ f (AE_aux (aexp, env, l))
+
(* For debugging we provide a pretty printer for ANF expressions. *)
let pp_lvar lvar doc =
diff --git a/src/jib/anf.mli b/src/jib/anf.mli
index 6bc274e6..571546cb 100644
--- a/src/jib/anf.mli
+++ b/src/jib/anf.mli
@@ -134,12 +134,17 @@ val gensym : unit -> id
(** {2 Functions for transforming ANF expressions} *)
+val aval_typ : typ aval -> typ
+val aexp_typ : typ aexp -> typ
+
(** Map over all values in an ANF expression *)
val map_aval : (Env.t -> Ast.l -> 'a aval -> 'a aval) -> 'a aexp -> 'a aexp
(** Map over all function calls in an ANF expression *)
val map_functions : (Env.t -> Ast.l -> id -> ('a aval) list -> 'a -> 'a aexp_aux) -> 'a aexp -> 'a aexp
+val fold_aexp : ('a aexp -> 'a aexp) -> 'a aexp -> 'a aexp
+
(** Remove all variable shadowing in an ANF expression *)
val no_shadow : IdSet.t -> 'a aexp -> 'a aexp
diff --git a/src/jib/c_backend.ml b/src/jib/c_backend.ml
index f6d8dd80..0fe986e4 100644
--- a/src/jib/c_backend.ml
+++ b/src/jib/c_backend.ml
@@ -405,40 +405,46 @@ let analyze_primop' ctx id args typ =
match extern, args with
(*
- | "eq_bits", [AV_cval (v1, _, CT_fbits _); AV_cval (v2, _, _)] ->
- AE_val (AV_cval (F_op (v1, "==", v2), typ, CT_bool))
- | "eq_bits", [AV_cval (v1, _, CT_sbits _); AV_cval (v2, _, _)] ->
- AE_val (AV_cval (F_call ("eq_sbits", [v1; v2]), typ, CT_bool))
-
- | "neq_bits", [AV_cval (v1, _, CT_fbits _); AV_cval (v2, _, _)] ->
- AE_val (AV_cval (F_op (v1, "!=", v2), typ, CT_bool))
- | "neq_bits", [AV_cval (v1, _, CT_sbits _); AV_cval (v2, _, _)] ->
- AE_val (AV_cval (F_call ("neq_sbits", [v1; v2]), typ, CT_bool))
-
- | "eq_int", [AV_cval (v1, typ1, _); AV_cval (v2, typ2, _)] ->
- AE_val (AV_cval (F_op (v1, "==", v2), typ, CT_bool))
+ | "eq_bits", [AV_cval (v1, _); AV_cval (v2, _)] ->
+ begin match cval_ctyp v1 with
+ | CT_fbits _ ->
+ AE_val (AV_cval (V_op (v1, "==", v2), typ))
+ | CT_sbits _ ->
+ AE_val (AV_cval (V_call ("eq_sbits", [v1; v2]), typ))
+ | _ -> no_change
+ end
- | "zeros", [_] ->
- begin match destruct_vector ctx.tc_env typ with
- | Some (Nexp_aux (Nexp_constant n, _), _, Typ_aux (Typ_id id, _))
- when string_of_id id = "bit" && Big_int.less_equal n (Big_int.of_int 64) ->
- AE_val (AV_C_fragment (F_raw "0x0", typ, CT_fbits (Big_int.to_int n, true)))
+ | "neq_bits", [AV_cval (v1, _); AV_cval (v2, _)] ->
+ begin match cval_ctyp v1 with
+ | CT_fbits _ ->
+ AE_val (AV_cval (V_op (v1, "!=", v2), typ))
+ | CT_sbits _ ->
+ AE_val (AV_cval (V_call ("neq_sbits", [v1; v2]), typ))
| _ -> no_change
end
- | "zero_extend", [AV_C_fragment (v1, _, CT_fbits _); _] ->
+ | "eq_int", [AV_cval (v1, _); AV_cval (v2, _)] ->
+ AE_val (AV_cval (V_op (v1, "==", v2), typ))
+
+ | "zeros", [_] ->
begin match destruct_vector ctx.tc_env typ with
| Some (Nexp_aux (Nexp_constant n, _), _, Typ_aux (Typ_id id, _))
when string_of_id id = "bit" && Big_int.less_equal n (Big_int.of_int 64) ->
- AE_val (AV_C_fragment (v1, typ, CT_fbits (Big_int.to_int n, true)))
+ let n = Big_int.to_int n in
+ AE_val (AV_cval (V_lit (VL_bits (Util.list_init n (fun _ -> Sail2_values.B0), true), CT_fbits (n, true)), typ))
| _ -> no_change
end
- | "zero_extend", [AV_C_fragment (v1, _, CT_sbits _); _] ->
+ | "zero_extend", [AV_cval (v1, _); _] ->
begin match destruct_vector ctx.tc_env typ with
| Some (Nexp_aux (Nexp_constant n, _), _, Typ_aux (Typ_id id, _))
when string_of_id id = "bit" && Big_int.less_equal n (Big_int.of_int 64) ->
- AE_val (AV_C_fragment (F_call ("fast_zero_extend", [v1; v_int (Big_int.to_int n)]), typ, CT_fbits (Big_int.to_int n, true)))
+ begin match cval_ctyp v1 with
+ | CT_fbits _ ->
+ AE_val (AV_C_fragment (v1, typ, CT_fbits (Big_int.to_int n, true)))
+ | CT_sbits _ ->
+ AE_val (AV_C_fragment (F_call ("fast_zero_extend", [v1; v_int (Big_int.to_int n)]), typ, CT_fbits (Big_int.to_int n, true)))
+ end
| _ -> no_change
end
diff --git a/src/jib/jib_compile.ml b/src/jib/jib_compile.ml
index bc3314ec..a8a0c640 100644
--- a/src/jib/jib_compile.ml
+++ b/src/jib/jib_compile.ml
@@ -335,7 +335,7 @@ let rec compile_aval l ctx = function
[icopy l (CL_id (gs, ctyp)) (V_op (V_id (gs, ctyp), "|", V_lit (mask i, ctyp)))]
| _ ->
(* FIXME: Make this work in C *)
- setup @ [iif (V_unary ("bit_to_bool", cval)) [icopy l (CL_id (gs, ctyp)) (V_op (V_id (gs, ctyp), "|", V_lit (mask i, ctyp)))] [] CT_unit] @ cleanup
+ setup @ [iif (V_call ("bit_to_bool", [cval])) [icopy l (CL_id (gs, ctyp)) (V_op (V_id (gs, ctyp), "|", V_lit (mask i, ctyp)))] [] CT_unit] @ cleanup
in
[idecl ctyp gs;
icopy l (CL_id (gs, ctyp)) (V_lit (VL_bits (Util.list_init 64 (fun _ -> Sail2_values.B0), direction), ctyp))]
diff --git a/src/jib/jib_optimize.ml b/src/jib/jib_optimize.ml
index 3fc42aa3..331cf65e 100644
--- a/src/jib/jib_optimize.ml
+++ b/src/jib/jib_optimize.ml
@@ -169,6 +169,8 @@ let rec cval_subst id subst = function
| V_ctor_unwrap (ctor, cval, unifiers, ctyp) -> V_ctor_unwrap (ctor, cval_subst id subst cval, unifiers, ctyp)
| V_struct (fields, ctyp) -> V_struct (List.map (fun (field, cval) -> field, cval_subst id subst cval) fields, ctyp)
| V_poly (cval, ctyp) -> V_poly (cval_subst id subst cval, ctyp)
+ | V_hd cval -> V_hd (cval_subst id subst cval)
+ | V_tl cval -> V_tl (cval_subst id subst cval)
let rec cval_map_id f = function
| V_id (id, ctyp) -> V_id (f id, ctyp)
@@ -231,16 +233,13 @@ let rec instrs_subst id subst =
| [] -> []
let rec clexp_subst id subst = function
- | CL_id (id', ctyp) when Name.compare id id' = 0 ->
- if ctyp_equal ctyp (clexp_ctyp subst) then
- subst
- else
- subst
+ | CL_id (id', ctyp) when Name.compare id id' = 0 -> subst
| CL_id (id', ctyp) -> CL_id (id', ctyp)
| CL_field (clexp, field) -> CL_field (clexp_subst id subst clexp, field)
| CL_addr clexp -> CL_addr (clexp_subst id subst clexp)
| CL_tuple (clexp, n) -> CL_tuple (clexp_subst id subst clexp, n)
| CL_void -> CL_void
+ | CL_rmw _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "Cannot substitute into read-modify-write construct"
let rec find_function fid = function
| CDEF_fundef (fid', heap_return, args, body) :: _ when Id.compare fid fid' = 0 ->
diff --git a/src/jib/jib_smt.ml b/src/jib/jib_smt.ml
index 44f8e24b..76239b35 100644
--- a/src/jib/jib_smt.ml
+++ b/src/jib/jib_smt.ml
@@ -177,7 +177,7 @@ let rec smt_cval env cval =
Fn ("not", [Fn ("=", [smt_cval env frag1; smt_cval env frag2])])
| V_op (frag1, "|", frag2) ->
Fn ("bvor", [smt_cval env frag1; smt_cval env frag2])
- | V_unary ("bit_to_bool", cval) ->
+ | V_call ("bit_to_bool", [cval]) ->
Fn ("=", [smt_cval env cval; Bin "1"])
| V_unary ("!", cval) ->
Fn ("not", [smt_cval env cval])
@@ -1029,6 +1029,27 @@ let c_literals ctx =
in
map_aval c_literal
+let unroll_foreach ctx = function
+ | AE_aux (AE_for (id, from_aexp, to_aexp, by_aexp, order, body), env, l) as aexp ->
+ begin match ctyp_of_typ ctx (aexp_typ from_aexp), ctyp_of_typ ctx (aexp_typ to_aexp), ctyp_of_typ ctx (aexp_typ by_aexp), order with
+ | CT_constant f, CT_constant t, CT_constant b, Ord_aux (Ord_inc, _) ->
+ let i = ref f in
+ let unrolled = ref [] in
+ while Big_int.less_equal !i t do
+ let current_index = AE_aux (AE_val (AV_lit (L_aux (L_num !i, gen_loc l), atom_typ (nconstant !i))), env, gen_loc l) in
+ let iteration = AE_aux (AE_let (Immutable, id, atom_typ (nconstant !i), current_index, body, unit_typ), env, gen_loc l) in
+ unrolled := iteration :: !unrolled;
+ i := Big_int.add !i b
+ done;
+ begin match !unrolled with
+ | last :: iterations ->
+ AE_aux (AE_block (List.rev iterations, last, unit_typ), env, gen_loc l)
+ | [] -> AE_aux (AE_val (AV_lit (L_aux (L_unit, gen_loc l), unit_typ)), env, gen_loc l)
+ end
+ | _ -> aexp
+ end
+ | aexp -> aexp
+
(**************************************************************************)
(* 3. Generating SMT *)
(**************************************************************************)
@@ -1371,7 +1392,7 @@ let generate_smt props name_file env ast =
let ctx =
initial_ctx
~convert_typ:ctyp_of_typ
- ~optimize_anf:(fun ctx aexp -> c_literals ctx aexp)
+ ~optimize_anf:(fun ctx aexp -> fold_aexp (unroll_foreach ctx) (c_literals ctx aexp))
env
in
let t = Profile.start () in
diff --git a/src/jib/jib_util.ml b/src/jib/jib_util.ml
index df2ce369..7b0b4f92 100644
--- a/src/jib/jib_util.ml
+++ b/src/jib/jib_util.ml
@@ -280,9 +280,9 @@ let string_of_name ?deref_current_exception:(dce=true) ?zencode:(zencode=true) =
"current_exception" ^ ssa_num n
let rec string_of_cval ?zencode:(zencode=true) = function
- | V_id (id, ctyp) -> string_of_name ~zencode:zencode id ^ " : " ^ string_of_ctyp ctyp
+ | V_id (id, ctyp) -> string_of_name ~zencode:zencode id
| V_ref (id, _) -> "&" ^ string_of_name ~zencode:zencode id
- | V_lit (vl, ctyp) -> string_of_value vl ^ " : " ^ string_of_ctyp ctyp
+ | V_lit (vl, ctyp) -> string_of_value vl
| V_call (str, cvals) ->
Printf.sprintf "%s(%s)" str (Util.string_of_list ", " (string_of_cval ~zencode:zencode) cvals)
| V_field (f, field) ->
@@ -675,6 +675,7 @@ let rec cval_deps = function
| V_ctor_kind (cval, _, _, _) -> cval_deps cval
| V_ctor_unwrap (_, cval, _, _) -> cval_deps cval
| V_hd cval | V_tl cval -> cval_deps cval
+ | V_struct (fields, ctyp) -> List.fold_left (fun ns (_, cval) -> NameSet.union ns (cval_deps cval)) NameSet.empty fields
let rec clexp_deps = function
| CL_id (id, _) -> NameSet.empty, NameSet.singleton id
@@ -890,7 +891,6 @@ let label str =
let rec infer_unary v = function
| "!" -> CT_bool
- | "bit_to_bool" -> CT_bool
| op -> Reporting.unreachable Parse_ast.Unknown __POS__ ("Could not infer unary " ^ op)
and infer_op v1 v2 = function
@@ -907,6 +907,7 @@ and infer_op v1 v2 = function
| op -> Reporting.unreachable Parse_ast.Unknown __POS__ ("Cannot infer binary op: " ^ op)
and infer_call vs = function
+ | "bit_to_bool" -> CT_bool
| op -> Reporting.unreachable Parse_ast.Unknown __POS__ ("Cannot infer call: " ^ op)
and cval_ctyp = function