diff options
| author | Alasdair Armstrong | 2019-07-16 18:57:46 +0100 |
|---|---|---|
| committer | Alasdair Armstrong | 2019-07-16 18:57:46 +0100 |
| commit | cd909e15b97739b10214023af04b2fbbb4d20cf7 (patch) | |
| tree | 9a418c7cafa915c29e93242848a1411cbd8b8f7c /src | |
| parent | 6d3a6edcd616621eb40420cfb16a34762a32c5c1 (diff) | |
| parent | 170543faa031d90186e6b45612ed8374f1c25f7b (diff) | |
Merge remote-tracking branch 'origin/sail2' into separate_bv
Diffstat (limited to 'src')
39 files changed, 856 insertions, 354 deletions
diff --git a/src/ast_util.ml b/src/ast_util.ml index b061600f..ac3c6d2b 100644 --- a/src/ast_util.ml +++ b/src/ast_util.ml @@ -335,6 +335,14 @@ let rec constraint_simp (NC_aux (nc_aux, l)) = | _, _ -> NC_bounded_ge (nexp1, nexp2) end + | NC_bounded_gt (nexp1, nexp2) -> + let nexp1, nexp2 = nexp_simp nexp1, nexp_simp nexp2 in + begin match nexp1, nexp2 with + | Nexp_aux (Nexp_constant c1, _), Nexp_aux (Nexp_constant c2, _) -> + if Big_int.greater c1 c2 then NC_true else NC_false + | _, _ -> NC_bounded_gt (nexp1, nexp2) + end + | NC_bounded_le (nexp1, nexp2) -> let nexp1, nexp2 = nexp_simp nexp1, nexp_simp nexp2 in begin match nexp1, nexp2 with @@ -343,6 +351,14 @@ let rec constraint_simp (NC_aux (nc_aux, l)) = | _, _ -> NC_bounded_le (nexp1, nexp2) end + | NC_bounded_lt (nexp1, nexp2) -> + let nexp1, nexp2 = nexp_simp nexp1, nexp_simp nexp2 in + begin match nexp1, nexp2 with + | Nexp_aux (Nexp_constant c1, _), Nexp_aux (Nexp_constant c2, _) -> + if Big_int.less c1 c2 then NC_true else NC_false + | _, _ -> NC_bounded_lt (nexp1, nexp2) + end + | _ -> nc_aux in NC_aux (nc_aux, l) @@ -419,7 +435,9 @@ let nc_int_set kid ints = mk_nc (NC_set (kid, List.map Big_int.of_int ints)) let nc_eq n1 n2 = mk_nc (NC_equal (n1, n2)) let nc_neq n1 n2 = mk_nc (NC_not_equal (n1, n2)) let nc_lteq n1 n2 = NC_aux (NC_bounded_le (n1, n2), Parse_ast.Unknown) +let nc_lt n1 n2 = NC_aux (NC_bounded_lt (n1, n2), Parse_ast.Unknown) let nc_gteq n1 n2 = NC_aux (NC_bounded_ge (n1, n2), Parse_ast.Unknown) +let nc_gt n1 n2 = NC_aux (NC_bounded_gt (n1, n2), Parse_ast.Unknown) let nc_lt n1 n2 = nc_lteq (nsum n1 (nint 1)) n2 let nc_gt n1 n2 = nc_gteq n1 (nsum n2 (nint 1)) let nc_var kid = mk_nc (NC_var kid) @@ -846,7 +864,9 @@ and string_of_n_constraint = function | NC_aux (NC_equal (n1, n2), _) -> string_of_nexp n1 ^ " == " ^ string_of_nexp n2 | NC_aux (NC_not_equal (n1, n2), _) -> string_of_nexp n1 ^ " != " ^ string_of_nexp n2 | NC_aux (NC_bounded_ge (n1, n2), _) -> string_of_nexp n1 ^ " >= " ^ string_of_nexp n2 + | NC_aux (NC_bounded_gt (n1, n2), _) -> string_of_nexp n1 ^ " > " ^ string_of_nexp n2 | NC_aux (NC_bounded_le (n1, n2), _) -> string_of_nexp n1 ^ " <= " ^ string_of_nexp n2 + | NC_aux (NC_bounded_lt (n1, n2), _) -> string_of_nexp n1 ^ " < " ^ string_of_nexp n2 | NC_aux (NC_or (nc1, nc2), _) -> "(" ^ string_of_n_constraint nc1 ^ " | " ^ string_of_n_constraint nc2 ^ ")" | NC_aux (NC_and (nc1, nc2), _) -> @@ -1118,7 +1138,9 @@ let rec nc_compare (NC_aux (nc1,_)) (NC_aux (nc2,_)) = match nc1, nc2 with | NC_equal (n1,n2), NC_equal (n3,n4) | NC_bounded_ge (n1,n2), NC_bounded_ge (n3,n4) + | NC_bounded_gt (n1,n2), NC_bounded_gt (n3,n4) | NC_bounded_le (n1,n2), NC_bounded_le (n3,n4) + | NC_bounded_lt (n1,n2), NC_bounded_lt (n3,n4) | NC_not_equal (n1,n2), NC_not_equal (n3,n4) -> lex_ord Nexp.compare Nexp.compare n1 n3 n2 n4 | NC_set (k1,s1), NC_set (k2,s2) -> @@ -1135,7 +1157,9 @@ let rec nc_compare (NC_aux (nc1,_)) (NC_aux (nc2,_)) = -> 0 | NC_equal _, _ -> -1 | _, NC_equal _ -> 1 | NC_bounded_ge _, _ -> -1 | _, NC_bounded_ge _ -> 1 + | NC_bounded_gt _, _ -> -1 | _, NC_bounded_gt _ -> 1 | NC_bounded_le _, _ -> -1 | _, NC_bounded_le _ -> 1 + | NC_bounded_lt _, _ -> -1 | _, NC_bounded_lt _ -> 1 | NC_not_equal _, _ -> -1 | _, NC_not_equal _ -> 1 | NC_set _, _ -> -1 | _, NC_set _ -> 1 | NC_or _, _ -> -1 | _, NC_or _ -> 1 @@ -1343,7 +1367,9 @@ let rec kopts_of_constraint (NC_aux (nc, _)) = match nc with | NC_equal (nexp1, nexp2) | NC_bounded_ge (nexp1, nexp2) + | NC_bounded_gt (nexp1, nexp2) | NC_bounded_le (nexp1, nexp2) + | NC_bounded_lt (nexp1, nexp2) | NC_not_equal (nexp1, nexp2) -> KOptSet.union (kopts_of_nexp nexp1) (kopts_of_nexp nexp2) | NC_set (kid, _) -> KOptSet.singleton (mk_kopt K_int kid) @@ -1400,7 +1426,9 @@ let rec tyvars_of_constraint (NC_aux (nc, _)) = match nc with | NC_equal (nexp1, nexp2) | NC_bounded_ge (nexp1, nexp2) + | NC_bounded_gt (nexp1, nexp2) | NC_bounded_le (nexp1, nexp2) + | NC_bounded_lt (nexp1, nexp2) | NC_not_equal (nexp1, nexp2) -> KidSet.union (tyvars_of_nexp nexp1) (tyvars_of_nexp nexp2) | NC_set (kid, _) -> KidSet.singleton kid @@ -1686,7 +1714,9 @@ let rec locate_nc f (NC_aux (nc_aux, l)) = let nc_aux = match nc_aux with | NC_equal (nexp1, nexp2) -> NC_equal (locate_nexp f nexp1, locate_nexp f nexp2) | NC_bounded_ge (nexp1, nexp2) -> NC_bounded_ge (locate_nexp f nexp1, locate_nexp f nexp2) + | NC_bounded_gt (nexp1, nexp2) -> NC_bounded_gt (locate_nexp f nexp1, locate_nexp f nexp2) | NC_bounded_le (nexp1, nexp2) -> NC_bounded_le (locate_nexp f nexp1, locate_nexp f nexp2) + | NC_bounded_lt (nexp1, nexp2) -> NC_bounded_lt (locate_nexp f nexp1, locate_nexp f nexp2) | NC_not_equal (nexp1, nexp2) -> NC_not_equal (locate_nexp f nexp1, locate_nexp f nexp2) | NC_set (kid, nums) -> NC_set (locate_kid f kid, nums) | NC_or (nc1, nc2) -> NC_or (locate_nc f nc1, locate_nc f nc2) @@ -1862,7 +1892,13 @@ let order_subst_aux sv subst = function let order_subst sv subst (Ord_aux (ord, l)) = Ord_aux (order_subst_aux sv subst ord, l) -let rec nexp_subst sv subst (Nexp_aux (nexp, l)) = Nexp_aux (nexp_subst_aux sv subst nexp, l) +let rec nexp_subst sv subst = function + | (Nexp_aux (Nexp_var kid, l)) as nexp -> + begin match subst with + | A_aux (A_nexp n, _) when Kid.compare kid sv = 0 -> n + | _ -> nexp + end + | Nexp_aux (nexp, l) -> Nexp_aux (nexp_subst_aux sv subst nexp, l) and nexp_subst_aux sv subst = function | Nexp_var kid -> begin match subst with @@ -1887,7 +1923,9 @@ let rec constraint_subst sv subst (NC_aux (nc, l)) = NC_aux (constraint_subst_au and constraint_subst_aux l sv subst = function | NC_equal (n1, n2) -> NC_equal (nexp_subst sv subst n1, nexp_subst sv subst n2) | NC_bounded_ge (n1, n2) -> NC_bounded_ge (nexp_subst sv subst n1, nexp_subst sv subst n2) + | NC_bounded_gt (n1, n2) -> NC_bounded_gt (nexp_subst sv subst n1, nexp_subst sv subst n2) | NC_bounded_le (n1, n2) -> NC_bounded_le (nexp_subst sv subst n1, nexp_subst sv subst n2) + | NC_bounded_lt (n1, n2) -> NC_bounded_lt (nexp_subst sv subst n1, nexp_subst sv subst n2) | NC_not_equal (n1, n2) -> NC_not_equal (nexp_subst sv subst n1, nexp_subst sv subst n2) | NC_set (kid, ints) as set_nc -> begin match subst with @@ -1988,7 +2026,9 @@ let subst_kids_nc, subst_kids_typ, subst_kids_typ_arg = match nc with | NC_equal (n1,n2) -> re (NC_equal (snexp n1, snexp n2)) | NC_bounded_ge (n1,n2) -> re (NC_bounded_ge (snexp n1, snexp n2)) + | NC_bounded_gt (n1,n2) -> re (NC_bounded_gt (snexp n1, snexp n2)) | NC_bounded_le (n1,n2) -> re (NC_bounded_le (snexp n1, snexp n2)) + | NC_bounded_lt (n1,n2) -> re (NC_bounded_lt (snexp n1, snexp n2)) | NC_not_equal (n1,n2) -> re (NC_not_equal (snexp n1, snexp n2)) | NC_set (kid,is) -> begin diff --git a/src/constant_fold.ml b/src/constant_fold.ml index abedcf35..7a7067ef 100644 --- a/src/constant_fold.ml +++ b/src/constant_fold.ml @@ -191,7 +191,7 @@ let rec run frame = let initial_state ast env = Interpreter.initial_state ~registers:false ast env safe_primops -let rw_exp ok not_ok istate = +let rw_exp target ok not_ok istate = let evaluate e_aux annot = let initial_monad = Interpreter.return (E_aux (e_aux, annot)) in try @@ -220,7 +220,16 @@ let rw_exp ok not_ok istate = ok (); E_aux (E_lit (L_aux (L_unit, fst annot)), annot) | E_app (id, args) when List.for_all is_constant args -> - evaluate e_aux annot + let env = env_of_annot annot in + (* We want to fold all primitive operations, but avoid folding + non-primitives that are defined in target-specific way. *) + let is_primop = + Env.is_extern id env "interpreter" && StringMap.mem (Env.get_extern id env "interpreter") safe_primops + in + if not (Env.is_extern id env target) || is_primop then + evaluate e_aux annot + else + E_aux (e_aux, annot) | E_cast (typ, (E_aux (E_lit _, _) as lit)) -> ok (); lit @@ -243,9 +252,9 @@ let rw_exp ok not_ok istate = in fold_exp { id_exp_alg with e_aux = (fun (e_aux, annot) -> rw_funcall e_aux annot)} -let rewrite_exp_once = rw_exp (fun _ -> ()) (fun _ -> ()) +let rewrite_exp_once target = rw_exp target (fun _ -> ()) (fun _ -> ()) -let rec rewrite_constant_function_calls' ast = +let rec rewrite_constant_function_calls' target ast = let rewrite_count = ref 0 in let ok () = incr rewrite_count in let not_ok () = decr rewrite_count in @@ -253,16 +262,16 @@ let rec rewrite_constant_function_calls' ast = let rw_defs = { rewriters_base with - rewrite_exp = (fun _ -> rw_exp ok not_ok istate) + rewrite_exp = (fun _ -> rw_exp target ok not_ok istate) } in let ast = rewrite_defs_base rw_defs ast in (* We keep iterating until we have no more re-writes to do *) if !rewrite_count > 0 - then rewrite_constant_function_calls' ast + then rewrite_constant_function_calls' target ast else ast -let rewrite_constant_function_calls ast = +let rewrite_constant_function_calls target ast = if !optimize_constant_fold then - rewrite_constant_function_calls' ast + rewrite_constant_function_calls' target ast else ast diff --git a/src/constant_propagation.ml b/src/constant_propagation.ml index 201e43e7..00b3d192 100644 --- a/src/constant_propagation.ml +++ b/src/constant_propagation.ml @@ -301,7 +301,7 @@ let is_env_inconsistent env ksubsts = module StringSet = Set.Make(String) module StringMap = Map.Make(String) -let const_props defs ref_vars = +let const_props target defs ref_vars = let const_fold exp = (* Constant-fold function applications with constant arguments *) let interpreter_istate = @@ -316,7 +316,7 @@ let const_props defs ref_vars = try strip_exp exp |> infer_exp (env_of exp) - |> Constant_fold.rewrite_exp_once interpreter_istate + |> Constant_fold.rewrite_exp_once target interpreter_istate |> keep_undef_typ with | _ -> exp @@ -603,7 +603,7 @@ let const_props defs ref_vars = | E_assert (e1,e2) -> let e1',e2',assigns = non_det_exp_2 e1 e2 in re (E_assert (e1',e2')) assigns - + | E_app_infix _ | E_var _ | E_internal_plet _ @@ -803,15 +803,15 @@ let const_props defs ref_vars = | DoesMatch (subst,ksubst) -> Some (exp,subst,ksubst) | GiveUp -> None in findpat_generic (string_of_exp exp0) assigns cases - + and can_match exp = let env = Type_check.env_of exp in can_match_with_env env exp in (const_prop_exp, const_prop_pexp) -let const_prop d r = fst (const_props d r) -let const_prop_pexp d r = snd (const_props d r) +let const_prop target d r = fst (const_props target d r) +let const_prop_pexp target d r = snd (const_props target d r) let referenced_vars exp = let open Rewriter in diff --git a/src/constant_propagation.mli b/src/constant_propagation.mli index 437492c6..9c182cb0 100644 --- a/src/constant_propagation.mli +++ b/src/constant_propagation.mli @@ -59,6 +59,7 @@ open Type_check (and hence we cannot reliably track). *) val const_prop : + string -> tannot defs -> IdSet.t -> tannot exp Bindings.t * nexp KBindings.t -> diff --git a/src/constant_propagation_mutrec.ml b/src/constant_propagation_mutrec.ml index 285ba45d..6cc6d28c 100644 --- a/src/constant_propagation_mutrec.ml +++ b/src/constant_propagation_mutrec.ml @@ -130,7 +130,7 @@ let generate_val_spec env id args l annot = | _, Typ_aux (_, l) -> raise (Reporting.err_unreachable l __POS__ "Function val spec is not a function type") -let const_prop defs substs ksubsts exp = +let const_prop target defs substs ksubsts exp = (* Constant_propagation currently only supports nexps for kid substitutions *) let nexp_substs = KBindings.bindings ksubsts @@ -139,6 +139,7 @@ let const_prop defs substs ksubsts exp = |> List.fold_left (fun s (v,i) -> KBindings.add v i s) KBindings.empty in Constant_propagation.const_prop + target (Defs defs) (Constant_propagation.referenced_vars exp) (substs, nexp_substs) @@ -147,7 +148,7 @@ let const_prop defs substs ksubsts exp = |> fst (* Propagate constant arguments into function clause pexp *) -let prop_args_pexp defs ksubsts args pexp = +let prop_args_pexp target defs ksubsts args pexp = let pat, guard, exp, annot = destruct_pexp pexp in let pats = match pat with | P_aux (P_tup pats, _) -> pats @@ -164,14 +165,14 @@ let prop_args_pexp defs ksubsts args pexp = else (pat :: pats, substs) in let pats, substs = List.fold_right2 match_arg args pats ([], Bindings.empty) in - let exp' = const_prop defs substs ksubsts exp in + let exp' = const_prop target defs substs ksubsts exp in let pat' = match pats with | [pat] -> pat | _ -> P_aux (P_tup pats, (Parse_ast.Unknown, empty_tannot)) in construct_pexp (pat', guard, exp', annot) -let rewrite_defs env (Defs defs) = +let rewrite_defs target env (Defs defs) = let rec rewrite = function | [] -> [] | DEF_internal_mutrec mutrecs :: ds -> @@ -194,7 +195,7 @@ let rewrite_defs env (Defs defs) = let valspec, ksubsts = generate_val_spec env id args l annot in let const_prop_funcl (FCL_aux (FCL_Funcl (_, pexp), (l, _))) = let pexp' = - prop_args_pexp defs ksubsts args pexp + prop_args_pexp target defs ksubsts args pexp |> rewrite_pexp |> strip_pexp in @@ -215,7 +216,7 @@ let rewrite_defs env (Defs defs) = let pexp' = if List.exists (fun id' -> Id.compare id id' = 0) !targets then let pat, guard, body, annot = destruct_pexp pexp in - let body' = const_prop defs Bindings.empty KBindings.empty body in + let body' = const_prop target defs Bindings.empty KBindings.empty body in rewrite_pexp (construct_pexp (pat, guard, recheck_exp body', annot)) else pexp in FCL_aux (FCL_Funcl (id, pexp'), a) diff --git a/src/constraint.ml b/src/constraint.ml index 1bd6a437..6c34bc9b 100644 --- a/src/constraint.ml +++ b/src/constraint.ml @@ -179,17 +179,22 @@ let to_smt l vars constr = | Nexp_times (nexp1, nexp2) -> sfun "*" [smt_nexp nexp1; smt_nexp nexp2] | Nexp_sum (nexp1, nexp2) -> sfun "+" [smt_nexp nexp1; smt_nexp nexp2] | Nexp_minus (nexp1, nexp2) -> sfun "-" [smt_nexp nexp1; smt_nexp nexp2] - | Nexp_exp (Nexp_aux (Nexp_constant c, _)) when Big_int.greater c Big_int.zero -> - Atom (Big_int.to_string (Big_int.pow_int_positive 2 (Big_int.to_int c))) - | Nexp_exp nexp when !opt_solver.uninterpret_power -> sfun "sailexp" [smt_nexp nexp] - | Nexp_exp nexp -> sfun "^" [Atom "2"; smt_nexp nexp] + | Nexp_exp nexp -> + begin match nexp_simp nexp with + | Nexp_aux (Nexp_constant c, _) when Big_int.greater_equal c Big_int.zero -> + Atom (Big_int.to_string (Big_int.pow_int_positive 2 (Big_int.to_int c))) + | nexp when !opt_solver.uninterpret_power -> sfun "sailexp" [smt_nexp nexp] + | nexp -> sfun "^" [Atom "2"; smt_nexp nexp] + end | Nexp_neg nexp -> sfun "-" [smt_nexp nexp] in let rec smt_constraint (NC_aux (aux, l) : n_constraint) : sexpr = match aux with | NC_equal (nexp1, nexp2) -> sfun "=" [smt_nexp nexp1; smt_nexp nexp2] | NC_bounded_le (nexp1, nexp2) -> sfun "<=" [smt_nexp nexp1; smt_nexp nexp2] + | NC_bounded_lt (nexp1, nexp2) -> sfun "<" [smt_nexp nexp1; smt_nexp nexp2] | NC_bounded_ge (nexp1, nexp2) -> sfun ">=" [smt_nexp nexp1; smt_nexp nexp2] + | NC_bounded_gt (nexp1, nexp2) -> sfun ">" [smt_nexp nexp1; smt_nexp nexp2] | NC_not_equal (nexp1, nexp2) -> sfun "not" [sfun "=" [smt_nexp nexp1; smt_nexp nexp2]] | NC_set (v, ints) -> sfun "or" (List.map (fun i -> sfun "=" [smt_var v; Atom (Big_int.to_string i)]) ints) diff --git a/src/gen_lib/sail2_operators.lem b/src/gen_lib/sail2_operators.lem index 547160d3..43a9812e 100644 --- a/src/gen_lib/sail2_operators.lem +++ b/src/gen_lib/sail2_operators.lem @@ -163,9 +163,9 @@ let arith_op_bv_no0 op sign size l r = Maybe.bind (int_of_bv sign r) (fun r' -> if r' = 0 then Nothing else Just (of_int (length l * size) (op l' r')))) -let mod_bv = arith_op_bv_no0 hardware_mod false 1 -let quot_bv = arith_op_bv_no0 hardware_quot false 1 -let quots_bv = arith_op_bv_no0 hardware_quot true 1 +let mod_bv = arith_op_bv_no0 tmod_int false 1 +let quot_bv = arith_op_bv_no0 tdiv_int false 1 +let quots_bv = arith_op_bv_no0 tdiv_int true 1 let mod_mword = Machine_word.modulo let quot_mword = Machine_word.unsignedDivide @@ -174,8 +174,8 @@ let quots_mword = Machine_word.signedDivide let arith_op_bv_int_no0 op sign size l r = arith_op_bv_no0 op sign size l (of_int (length l) r) -let quot_bv_int = arith_op_bv_int_no0 hardware_quot false 1 -let mod_bv_int = arith_op_bv_int_no0 hardware_mod false 1 +let quot_bv_int = arith_op_bv_int_no0 tdiv_int false 1 +let mod_bv_int = arith_op_bv_int_no0 tmod_int false 1 let mod_mword_int l r = Machine_word.modulo l (wordFromInteger r) let quot_mword_int l r = Machine_word.unsignedDivide l (wordFromInteger r) diff --git a/src/gen_lib/sail2_operators_bitlists.lem b/src/gen_lib/sail2_operators_bitlists.lem index 8b75fa38..c9892e4c 100644 --- a/src/gen_lib/sail2_operators_bitlists.lem +++ b/src/gen_lib/sail2_operators_bitlists.lem @@ -304,3 +304,5 @@ val eq_vec : list bitU -> list bitU -> bool val neq_vec : list bitU -> list bitU -> bool let eq_vec = eq_bv let neq_vec = neq_bv + +let inline count_leading_zeros v = count_leading_zero_bits v diff --git a/src/gen_lib/sail2_operators_mwords.lem b/src/gen_lib/sail2_operators_mwords.lem index 181fa149..c8524e16 100644 --- a/src/gen_lib/sail2_operators_mwords.lem +++ b/src/gen_lib/sail2_operators_mwords.lem @@ -329,3 +329,6 @@ val eq_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> bool val neq_vec : forall 'a. Size 'a => mword 'a -> mword 'a -> bool let inline eq_vec = eq_mword let inline neq_vec = neq_mword + +val count_leading_zeros : forall 'a. Size 'a => mword 'a -> integer +let count_leading_zeros v = count_leading_zeros_bv v diff --git a/src/gen_lib/sail2_values.lem b/src/gen_lib/sail2_values.lem index 5e6537a8..f657803f 100644 --- a/src/gen_lib/sail2_values.lem +++ b/src/gen_lib/sail2_values.lem @@ -104,21 +104,25 @@ let upper n = n (* Modulus operation corresponding to quot below -- result has sign of dividend. *) -let hardware_mod (a: integer) (b:integer) : integer = +let tmod_int (a: integer) (b:integer) : integer = let m = (abs a) mod (abs b) in if a < 0 then ~m else m +let hardware_mod = tmod_int + (* There are different possible answers for integer divide regarding rounding behaviour on negative operands. Positive operands always round down so derive the one we want (trucation towards zero) from that *) -let hardware_quot (a:integer) (b:integer) : integer = +let tdiv_int (a:integer) (b:integer) : integer = let q = (abs a) / (abs b) in if ((a<0) = (b<0)) then q (* same sign -- result positive *) else ~q (* different sign -- result negative *) +let hardware_quot = tdiv_int + let max_64u = (integerPow 2 64) - 1 let max_64 = (integerPow 2 63) - 1 let min_64 = 0 - (integerPow 2 63) @@ -652,6 +656,16 @@ let int_of_bit b = | _ -> failwith "int_of_bit saw unknown" end +val count_leading_zero_bits : list bitU -> integer +let rec count_leading_zero_bits v = + match v with + | B0 :: v' -> count_leading_zero_bits v' + 1 + | _ -> 0 + end + +val count_leading_zeros_bv : forall 'a. Bitvector 'a => 'a -> integer +let count_leading_zeros_bv v = count_leading_zero_bits (bits_of v) + val decimal_string_of_bv : forall 'a. Bitvector 'a => 'a -> string let decimal_string_of_bv bv = let place_values = diff --git a/src/initial_check.ml b/src/initial_check.ml index b2e3dc79..f41033ca 100644 --- a/src/initial_check.ml +++ b/src/initial_check.ml @@ -238,8 +238,8 @@ and to_ast_constraint ctx (P.ATyp_aux (aux, l) as atyp) = | "!=" -> NC_not_equal (to_ast_nexp ctx t1, to_ast_nexp ctx t2) | ">=" -> NC_bounded_ge (to_ast_nexp ctx t1, to_ast_nexp ctx t2) | "<=" -> NC_bounded_le (to_ast_nexp ctx t1, to_ast_nexp ctx t2) - | ">" -> NC_bounded_ge (to_ast_nexp ctx t1, nsum (to_ast_nexp ctx t2) (nint 1)) - | "<" -> NC_bounded_le (nsum (to_ast_nexp ctx t1) (nint 1), to_ast_nexp ctx t2) + | ">" -> NC_bounded_gt (to_ast_nexp ctx t1, to_ast_nexp ctx t2) + | "<" -> NC_bounded_lt (to_ast_nexp ctx t1, to_ast_nexp ctx t2) | "&" -> NC_and (to_ast_constraint ctx t1, to_ast_constraint ctx t2) | "|" -> NC_or (to_ast_constraint ctx t1, to_ast_constraint ctx t2) | _ -> diff --git a/src/isail.ml b/src/isail.ml index 9e9b6236..88408dcd 100644 --- a/src/isail.ml +++ b/src/isail.ml @@ -302,7 +302,8 @@ let rec describe_rewrite = | String_rewriter rw -> "<string>" :: describe_rewrite (rw "") | Bool_rewriter rw -> "<bool>" :: describe_rewrite (rw false) | Literal_rewriter rw -> "(ocaml|lem|all)" :: describe_rewrite (rw (fun _ -> true)) - | Basic_rewriter rw -> [] + | Basic_rewriter _ + | Checking_rewriter _ -> [] type session = { id : string; @@ -592,7 +593,9 @@ let handle_input' input = failwith "Must provide the name of a rewrite, use :list_rewrites for a list of possible rewrites" end | ":rewrites" -> - Interactive.ast := Process_file.rewrite_ast_target arg !Interactive.env !Interactive.ast; + let new_ast, new_env = Process_file.rewrite_ast_target arg !Interactive.env !Interactive.ast in + Interactive.ast := new_ast; + Interactive.env := new_env; interactive_state := initial_state !Interactive.ast !Interactive.env Value.primops | ":prover_regstate" -> let env, ast = prover_regstate (Some arg) !Interactive.ast !Interactive.env in diff --git a/src/jib/anf.ml b/src/jib/anf.ml index 5165904d..dbbb10e0 100644 --- a/src/jib/anf.ml +++ b/src/jib/anf.ml @@ -670,7 +670,7 @@ let rec anf (E_aux (e_aux, ((l, _) as exp_annot)) as exp) = let aexp2 = anf exp2 in let aval1, wrap1 = to_aval aexp1 in let aval2, wrap2 = to_aval aexp2 in - wrap1 (wrap2 (mk_aexp (AE_app (mk_id "cons", [aval1; aval2], unit_typ)))) + wrap1 (wrap2 (mk_aexp (AE_app (mk_id "sail_cons", [aval1; aval2], unit_typ)))) | E_id id -> let lvar = Env.lookup_id id (env_of exp) in diff --git a/src/jib/c_backend.ml b/src/jib/c_backend.ml index 2c9c11ee..52061444 100644 --- a/src/jib/c_backend.ml +++ b/src/jib/c_backend.ml @@ -397,14 +397,14 @@ let analyze_primop' ctx id args typ = c_debug (lazy ("Analyzing primop " ^ extern ^ "(" ^ Util.string_of_list ", " (fun aval -> Pretty_print_sail.to_string (pp_aval aval)) args ^ ")")); match extern, args with - | "eq_bits", [AV_cval (v1, _); AV_cval (v2, _)] -> + | "eq_bits", [AV_cval (v1, _); AV_cval (v2, _)] when ctyp_equal (cval_ctyp v1) (cval_ctyp v2) -> begin match cval_ctyp v1 with | CT_fbits _ | CT_sbits _ -> AE_val (AV_cval (V_call (Eq, [v1; v2]), typ)) | _ -> no_change end - | "neq_bits", [AV_cval (v1, _); AV_cval (v2, _)] -> + | "neq_bits", [AV_cval (v1, _); AV_cval (v2, _)] when ctyp_equal (cval_ctyp v1) (cval_ctyp v2) -> begin match cval_ctyp v1 with | CT_fbits _ | CT_sbits _ -> AE_val (AV_cval (V_call (Neq, [v1; v2]), typ)) @@ -461,19 +461,19 @@ let analyze_primop' ctx id args typ = | "not_bits", [AV_cval (v, _)] -> AE_val (AV_cval (V_call (Bvnot, [v]), typ)) - | "add_bits", [AV_cval (v1, _); AV_cval (v2, _)] -> + | "add_bits", [AV_cval (v1, _); AV_cval (v2, _)] when ctyp_equal (cval_ctyp v1) (cval_ctyp v2) -> AE_val (AV_cval (V_call (Bvadd, [v1; v2]), typ)) - | "sub_bits", [AV_cval (v1, _); AV_cval (v2, _)] -> + | "sub_bits", [AV_cval (v1, _); AV_cval (v2, _)] when ctyp_equal (cval_ctyp v1) (cval_ctyp v2) -> AE_val (AV_cval (V_call (Bvsub, [v1; v2]), typ)) - | "and_bits", [AV_cval (v1, _); AV_cval (v2, _)] -> + | "and_bits", [AV_cval (v1, _); AV_cval (v2, _)] when ctyp_equal (cval_ctyp v1) (cval_ctyp v2) -> AE_val (AV_cval (V_call (Bvand, [v1; v2]), typ)) - | "or_bits", [AV_cval (v1, _); AV_cval (v2, _)] -> + | "or_bits", [AV_cval (v1, _); AV_cval (v2, _)] when ctyp_equal (cval_ctyp v1) (cval_ctyp v2) -> AE_val (AV_cval (V_call (Bvor, [v1; v2]), typ)) - | "xor_bits", [AV_cval (v1, _); AV_cval (v2, _)] -> + | "xor_bits", [AV_cval (v1, _); AV_cval (v2, _)] when ctyp_equal (cval_ctyp v1) (cval_ctyp v2) -> AE_val (AV_cval (V_call (Bvxor, [v1; v2]), typ)) | "vector_subrange", [AV_cval (vec, _); AV_cval (f, _); AV_cval (t, _)] -> @@ -2185,7 +2185,7 @@ let compile_ast env output_chan c_includes ast = " CREATE(zexception)(current_exception);" ], [ " KILL(zexception)(current_exception);"; " free(current_exception);"; - " if (have_exception) fprintf(stderr, \"Exiting due to uncaught exception\\n\");" ]) + " if (have_exception) {fprintf(stderr, \"Exiting due to uncaught exception\\n\"); exit(EXIT_FAILURE);}" ]) in let letbind_initializers = @@ -2230,9 +2230,9 @@ let compile_ast env output_chan c_includes ast = @ letbind_finalizers @ List.concat (List.map (fun r -> snd (register_init_clear r)) regs) @ finish cdefs + @ [ " cleanup_rts();" ] @ snd exn_boilerplate - @ [ " cleanup_rts();"; - "}" ] )) + @ [ "}" ] )) in let model_default_main = separate hardline (List.map string diff --git a/src/jib/jib_compile.ml b/src/jib/jib_compile.ml index 0518b9c5..5e49af7f 100644 --- a/src/jib/jib_compile.ml +++ b/src/jib/jib_compile.ml @@ -193,8 +193,10 @@ let rec compile_aval l ctx = function let ctyp = cval_ctyp cval in let ctyp' = ctyp_of_typ ctx typ in if not (ctyp_equal ctyp ctyp') then - raise (Reporting.err_unreachable l __POS__ (string_of_ctyp ctyp ^ " != " ^ string_of_ctyp ctyp')); - [], cval, [] + let gs = ngensym () in + [iinit ctyp' gs cval], V_id (gs, ctyp'), [iclear ctyp' gs] + else + [], cval, [] | AV_id (id, typ) -> begin @@ -1568,8 +1570,9 @@ let sort_ctype_defs cdefs = let compile_ast ctx (Defs defs) = let assert_vs = Initial_check.extern_of_string (mk_id "sail_assert") "(bool, string) -> unit" in let exit_vs = Initial_check.extern_of_string (mk_id "sail_exit") "unit -> unit" in + let cons_vs = Initial_check.extern_of_string (mk_id "sail_cons") "forall ('a : Type). ('a, list('a)) -> list('a)" in - let ctx = { ctx with tc_env = snd (Type_error.check ctx.tc_env (Defs [assert_vs; exit_vs])) } in + let ctx = { ctx with tc_env = snd (Type_error.check ctx.tc_env (Defs [assert_vs; exit_vs; cons_vs])) } in if !opt_memo_cache then (try diff --git a/src/jib/jib_smt.ml b/src/jib/jib_smt.ml index 07cab66b..9fb77055 100644 --- a/src/jib/jib_smt.ml +++ b/src/jib/jib_smt.ml @@ -418,7 +418,7 @@ let unsigned_size ?checked:(checked=true) ctx n m smt = else if n > m then Fn ("concat", [bvzero (n - m); smt]) else - failwith "bad arguments to unsigned_size" + Extract (n - 1, 0, smt) let int_size ctx = function | CT_constant n -> required_width n @@ -536,6 +536,11 @@ let builtin_min_int ctx v1 v2 ret_ctyp = smt1, smt2) +let bvmask ctx len = + let all_ones = bvones (lbits_size ctx) in + let shift = Fn ("concat", [bvzero (lbits_size ctx - ctx.lbits_index); len]) in + bvnot (bvshl all_ones shift) + let builtin_eq_bits ctx v1 v2 = match cval_ctyp v1, cval_ctyp v2 with | CT_fbits (n, _), CT_fbits (m, _) -> @@ -545,15 +550,20 @@ let builtin_eq_bits ctx v1 v2 = Fn ("=", [smt1; smt2]) | CT_lbits _, CT_lbits _ -> - Fn ("=", [smt_cval ctx v1; smt_cval ctx v2]) + let len1 = Fn ("len", [smt_cval ctx v1]) in + let contents1 = Fn ("contents", [smt_cval ctx v1]) in + let len2 = Fn ("len", [smt_cval ctx v1]) in + let contents2 = Fn ("contents", [smt_cval ctx v1]) in + Fn ("and", [Fn ("=", [len1; len2]); + Fn ("=", [Fn ("bvand", [bvmask ctx len1; contents1]); Fn ("bvand", [bvmask ctx len2; contents2])])]) | CT_lbits _, CT_fbits (n, _) -> - let smt2 = unsigned_size ctx (lbits_size ctx) n (smt_cval ctx v2) in - Fn ("=", [Fn ("contents", [smt_cval ctx v1]); smt2]) + let smt1 = unsigned_size ctx n (lbits_size ctx) (Fn ("contents", [smt_cval ctx v1])) in + Fn ("=", [smt1; smt_cval ctx v2]) | CT_fbits (n, _), CT_lbits _ -> - let smt1 = unsigned_size ctx (lbits_size ctx) n (smt_cval ctx v1) in - Fn ("=", [smt1; Fn ("contents", [smt_cval ctx v2])]) + let smt2 = unsigned_size ctx n (lbits_size ctx) (Fn ("contents", [smt_cval ctx v2])) in + Fn ("=", [smt_cval ctx v1; smt2]) | _ -> builtin_type_error ctx "eq_bits" [v1; v2] None @@ -566,11 +576,6 @@ let builtin_zeros ctx v ret_ctyp = Fn ("Bits", [extract (ctx.lbits_index - 1) 0 (smt_cval ctx v); bvzero (lbits_size ctx)]) | _ -> builtin_type_error ctx "zeros" [v] (Some ret_ctyp) -let bvmask ctx len = - let all_ones = bvones (lbits_size ctx) in - let shift = Fn ("concat", [bvzero (lbits_size ctx - ctx.lbits_index); len]) in - bvnot (bvshl all_ones shift) - let builtin_ones ctx cval = function | CT_fbits (n, _) -> bvones n | CT_lbits _ -> @@ -691,12 +696,23 @@ let builtin_append ctx v1 v2 ret_ctyp = Fn ("Bits", [bvadd (bvint ctx.lbits_index (Big_int.of_int n)) (Fn ("len", [smt2])); bvor (bvshl x shift) (Fn ("contents", [smt2]))]) + | CT_lbits _, CT_fbits (n, _), CT_fbits (m, _) -> + let smt1 = smt_cval ctx v1 in + let smt2 = smt_cval ctx v2 in + Extract (m - 1, 0, Fn ("concat", [Fn ("contents", [smt1]); smt2])) + | CT_lbits _, CT_fbits (n, _), CT_lbits _ -> let smt1 = smt_cval ctx v1 in let smt2 = smt_cval ctx v2 in Fn ("Bits", [bvadd (bvint ctx.lbits_index (Big_int.of_int n)) (Fn ("len", [smt1])); Extract (lbits_size ctx - 1, 0, Fn ("concat", [Fn ("contents", [smt1]); smt2]))]) + | CT_fbits (n, _), CT_fbits (m, _), CT_lbits _ -> + let smt1 = smt_cval ctx v1 in + let smt2 = smt_cval ctx v2 in + Fn ("Bits", [bvint ctx.lbits_index (Big_int.of_int (n + m)); + unsigned_size ctx (lbits_size ctx) (n + m) (Fn ("concat", [smt1; smt2]))]) + | CT_lbits _, CT_lbits _, CT_lbits _ -> let smt1 = smt_cval ctx v1 in let smt2 = smt_cval ctx v2 in @@ -704,6 +720,13 @@ let builtin_append ctx v1 v2 ret_ctyp = let shift = Fn ("concat", [bvzero (lbits_size ctx - ctx.lbits_index); Fn ("len", [smt2])]) in Fn ("Bits", [bvadd (Fn ("len", [smt1])) (Fn ("len", [smt2])); bvor (bvshl x shift) (Fn ("contents", [smt2]))]) + | CT_lbits _, CT_lbits _, CT_fbits (n, _) -> + let smt1 = smt_cval ctx v1 in + let smt2 = smt_cval ctx v2 in + let x = Fn ("contents", [smt1]) in + let shift = Fn ("concat", [bvzero (lbits_size ctx - ctx.lbits_index); Fn ("len", [smt2])]) in + unsigned_size ctx n (lbits_size ctx) (bvor (bvshl x shift) (Fn ("contents", [smt2]))) + | _ -> builtin_type_error ctx "append" [v1; v2] (Some ret_ctyp) let builtin_length ctx v ret_ctyp = @@ -729,6 +752,9 @@ let builtin_vector_subrange ctx vec i j ret_ctyp = | CT_fbits (n, _), CT_constant i, CT_constant j -> Extract (Big_int.to_int i, Big_int.to_int j, smt_cval ctx vec) + | CT_lbits _, CT_constant i, CT_constant j -> + Extract (Big_int.to_int i, Big_int.to_int j, Fn ("contents", [smt_cval ctx vec])) + | _ -> builtin_type_error ctx "vector_subrange" [vec; i; j] (Some ret_ctyp) let builtin_vector_access ctx vec i ret_ctyp = @@ -783,9 +809,12 @@ let builtin_unsigned ctx v ret_ctyp = let builtin_signed ctx v ret_ctyp = match cval_ctyp v, ret_ctyp with - | CT_fbits (n, _), CT_fint m when m > n -> + | CT_fbits (n, _), CT_fint m when m >= n -> SignExtend(m - n, smt_cval ctx v) + | CT_fbits (n, _), CT_lint -> + SignExtend(ctx.lint_size - n, smt_cval ctx v) + | ctyp, _ -> builtin_type_error ctx "signed" [v] (Some ret_ctyp) let builtin_add_bits ctx v1 v2 ret_ctyp = @@ -1141,6 +1170,8 @@ let rec smt_conversion ctx from_ctyp to_ctyp x = bvint ctx.lint_size c | CT_fint sz, CT_lint -> force_size ctx ctx.lint_size sz x + | CT_lbits _, CT_fbits (n, _) -> + unsigned_size ctx n (lbits_size ctx) (Fn ("contents", [x])) | _, _ -> failwith (Printf.sprintf "Cannot perform conversion from %s to %s" (string_of_ctyp from_ctyp) (string_of_ctyp to_ctyp)) let define_const ctx id ctyp exp = Define_const (zencode_name id, smt_ctyp ctx ctyp, exp) @@ -2030,6 +2061,22 @@ let compile env ast = let rmap = build_register_map CTMap.empty cdefs in cdefs, { (initial_ctx ()) with tc_env = env; register_map = rmap; ast = ast } +let serialize_smt_model file env ast = + let cdefs, ctx = compile env ast in + let out_chan = open_out file in + Marshal.to_channel out_chan cdefs []; + Marshal.to_channel out_chan (Type_check.Env.set_prover None ctx.tc_env) []; + Marshal.to_channel out_chan ctx.register_map []; + close_out out_chan + +let deserialize_smt_model file = + let in_chan = open_in file in + let cdefs = (Marshal.from_channel in_chan : cdef list) in + let env = (Marshal.from_channel in_chan : Type_check.env) in + let rmap = (Marshal.from_channel in_chan : id list CTMap.t) in + close_in in_chan; + (cdefs, { (initial_ctx ()) with tc_env = env; register_map = rmap }) + let generate_smt props name_file env ast = try let cdefs, ctx = compile env ast in diff --git a/src/jib/jib_smt.mli b/src/jib/jib_smt.mli index 2680f937..cdaf7e39 100644 --- a/src/jib/jib_smt.mli +++ b/src/jib/jib_smt.mli @@ -139,6 +139,12 @@ module Make_optimizer(S : Sequence) : sig val optimize : smt_def Stack.t -> smt_def S.t end +val serialize_smt_model : + string -> Type_check.Env.t -> Type_check.tannot defs -> unit + +val deserialize_smt_model : + string -> cdef list * ctx + (** Generate SMT for all the $property and $counterexample pragmas in an AST, and write it to appropriately named files. *) val generate_smt : diff --git a/src/libsail.mllib b/src/libsail.mllib index fb3d1264..2d1f568f 100644 --- a/src/libsail.mllib +++ b/src/libsail.mllib @@ -52,6 +52,7 @@ Sail Sail2_values Sail_lib Scattered +Smtlib Spec_analysis Specialize State diff --git a/src/monomorphise.ml b/src/monomorphise.ml index 93579420..7a43ca6c 100644 --- a/src/monomorphise.ml +++ b/src/monomorphise.ml @@ -439,7 +439,7 @@ type split = | VarSplit of (tannot pat * (* pattern for this case *) (id * tannot Ast.exp) list * (* substitutions for arguments *) pat_choice list * (* optional locations of constraints/case expressions to reduce *) - (kid * nexp) list) (* substitutions for type variables *) + nexp KBindings.t) (* substitutions for type variables *) list | ConstrSplit of (tannot pat * nexp KBindings.t) list @@ -620,7 +620,7 @@ let apply_pat_choices choices = e_assert = rewrite_assert; e_case = rewrite_case } -let split_defs all_errors splits env defs = +let split_defs target all_errors splits env defs = let no_errors_happened = ref true in let error_opt = if all_errors then Some no_errors_happened else None in let split_constructors (Defs defs) = @@ -651,7 +651,7 @@ let split_defs all_errors splits env defs = let subst_exp ref_vars substs ksubsts exp = let substs = bindings_from_list substs, ksubsts in - fst (Constant_propagation.const_prop defs ref_vars substs Bindings.empty exp) + fst (Constant_propagation.const_prop target defs ref_vars substs Bindings.empty exp) in (* Split a variable pattern into every possible value *) @@ -672,26 +672,26 @@ let split_defs all_errors splits env defs = in if all_errors then (no_errors_happened := false; print_error error; - [P_aux (P_id var,(pat_l,annot)),[],[],[]]) + [P_aux (P_id var,(pat_l,annot)),[],[],KBindings.empty]) else raise (Fatal_error error) in match ty with | Typ_id (Id_aux (Id "bool",_)) | Typ_app (Id_aux (Id "atom_bool", _), [_]) -> - [P_aux (P_lit (L_aux (L_true,new_l)),(l,annot)),[var, E_aux (E_lit (L_aux (L_true,new_l)),(new_l,annot))],[],[]; - P_aux (P_lit (L_aux (L_false,new_l)),(l,annot)),[var, E_aux (E_lit (L_aux (L_false,new_l)),(new_l,annot))],[],[]] + [P_aux (P_lit (L_aux (L_true,new_l)),(l,annot)),[var, E_aux (E_lit (L_aux (L_true,new_l)),(new_l,annot))],[],KBindings.empty; + P_aux (P_lit (L_aux (L_false,new_l)),(l,annot)),[var, E_aux (E_lit (L_aux (L_false,new_l)),(new_l,annot))],[],KBindings.empty] | Typ_id id -> (try (* enumerations *) let ns = Env.get_enum id env in List.map (fun n -> (P_aux (P_id (renew_id n),(l,annot)), - [var,E_aux (E_id (renew_id n),(new_l,annot))],[],[])) ns + [var,E_aux (E_id (renew_id n),(new_l,annot))],[],KBindings.empty)) ns with Type_error _ -> match id with | Id_aux (Id "bit",_) -> List.map (fun b -> P_aux (P_lit (L_aux (b,new_l)),(l,annot)), - [var,E_aux (E_lit (L_aux (b,new_l)),(new_l, annot))],[],[]) + [var,E_aux (E_lit (L_aux (b,new_l)),(new_l, annot))],[],KBindings.empty) [L_zero; L_one] | _ -> cannot ("don't know about type " ^ string_of_id id)) @@ -705,7 +705,7 @@ let split_defs all_errors splits env defs = let lits = make_vectors sz in List.map (fun lit -> P_aux (P_lit lit,(l,annot)), - [var,E_aux (E_lit lit,(new_l,annot))],[],[]) lits + [var,E_aux (E_lit lit,(new_l,annot))],[],KBindings.empty) lits else cannot ("bitvector length outside limit, " ^ string_of_nexp len) | _ -> @@ -718,7 +718,8 @@ let split_defs all_errors splits env defs = let lit = L_aux (L_num i,new_l) in P_aux (P_lit lit,(l,annot)), [var,E_aux (E_lit lit,(new_l,annot))],[], - match kid with None -> [] | Some k -> [(k,nconstant i)] + match kid with None -> KBindings.empty + | Some k -> KBindings.singleton k (nconstant i) in match value with | Nexp_constant i -> [mk_lit None i] @@ -761,18 +762,25 @@ let split_defs all_errors splits env defs = | h::t -> let t' = match list f t with - | None -> [t,[],[],[]] + | None -> [t,[],[],KBindings.empty] | Some t' -> t' in let h' = match f h with - | None -> [h,[],[],[]] + | None -> [h,[],[],KBindings.empty] | Some ps -> ps in + let merge (h,hsubs,hpchoices,hksubs) (t,tsubs,tpchoices,tksubs) = + if KBindings.for_all (fun kid nexp -> + match KBindings.find_opt kid tksubs with + | None -> true + | Some nexp' -> Nexp.compare nexp nexp' == 0) hksubs + then Some (h::t, hsubs@tsubs, hpchoices@tpchoices, + KBindings.union (fun k a _ -> Some a) hksubs tksubs) + else None + in Some (List.concat - (List.map (fun (h,hsubs,hpchoices,hksubs) -> - List.map (fun (t,tsubs,tpchoices,tksubs) -> - (h::t, hsubs@tsubs, hpchoices@tpchoices, hksubs@tksubs)) t') h')) + (List.map (fun h -> Util.map_filter (merge h) t') h')) in let rec spl (P_aux (p,(l,annot))) = let relist f ctx ps = @@ -784,6 +792,12 @@ let split_defs all_errors splits env defs = optmap (spl p) (fun ps -> List.map (fun (p,sub,pchoices,ksub) -> (P_aux (f p,(l,annot)), sub, pchoices, ksub)) ps) in + let re2 f ctx p1 p2 = + (* Todo: I am not proud of this abuse of relist - but creating a special + * version of re just for two entries did not seem worth it + *) + relist f (fun [p1'; p2'] -> ctx p1' p2') [p1; p2] + in let fpat (FP_aux ((FP_Fpat (id,p),annot))) = optmap (spl p) (fun ps -> List.map (fun (p,sub,pchoices,ksub) -> FP_aux (FP_Fpat (id,p), annot), sub, pchoices, ksub) ps) @@ -793,10 +807,7 @@ let split_defs all_errors splits env defs = | P_wild -> None | P_or (p1, p2) -> - (* Todo: I am not proud of this abuse of relist - but creating a special - * version of re just for two entries did not seem worth it - *) - relist spl (fun [p1'; p2'] -> P_or (p1', p2')) [p1; p2] + re2 spl (fun p1' p2' -> P_or (p1', p2')) p1 p2 | P_not p -> (* todo: not sure that I can't split - but can't figure out how at * the moment *) @@ -815,10 +826,10 @@ let split_defs all_errors splits env defs = let kids = Spec_analysis.equal_kids (env_of_pat p') kid in Some (List.map (fun (p,sub,pchoices,ksub) -> P_aux (P_var (p,tp),(l,annot)), sub, pchoices, - List.concat - (List.map - (fun (k,nexp) -> if KidSet.mem k kids then [(kid,nexp);(k,nexp)] else [(k,nexp)]) - ksub)) ps)) + match List.find_opt (fun k -> KBindings.mem k ksub) (KidSet.elements kids) with + | None -> ksub + | Some k -> KBindings.add kid (KBindings.find k ksub) ksub + ) ps)) | P_var (p',tp) -> re (fun p -> P_var (p,tp)) p' | P_id id -> (match id_match id with @@ -849,19 +860,19 @@ let split_defs all_errors splits env defs = (Typ_app (Id_aux (Id "atom",_), [A_aux (A_nexp (Nexp_aux (Nexp_var var,_)),_)]),_) -> - [var,nconstant j] - | _ -> [] + KBindings.singleton var (nconstant j) + | _ -> KBindings.empty in p,[id,E_aux (E_lit lit,(Generated pl,pannot))],[l,(i,max,[])],kid_subst | P_aux (p',(pl,pannot)) when lit_like p' -> - p,[id,to_exp p],[l,(i,max,[])],[] + p,[id,to_exp p],[l,(i,max,[])],KBindings.empty | _ -> let p',subst = freshen_pat_bindings p in match p' with | P_aux (P_wild,_) -> - P_aux (P_id id,(l,annot)),[],[l,(i,max,subst)],[] + P_aux (P_id id,(l,annot)),[],[l,(i,max,subst)],KBindings.empty | _ -> - P_aux (P_as (p',id),(l,annot)),[],[l,(i,max,subst)],[]) + P_aux (P_as (p',id),(l,annot)),[],[l,(i,max,subst)],KBindings.empty) pats) ) | P_app (id,ps) -> @@ -879,14 +890,7 @@ let split_defs all_errors splits env defs = | P_list ps -> relist spl (fun ps -> P_list ps) ps | P_cons (p1,p2) -> - match spl p1, spl p2 with - | None, None -> None - | p1', p2' -> - let p1' = match p1' with None -> [p1,[],[],[]] | Some p1' -> p1' in - let p2' = match p2' with None -> [p2,[],[],[]] | Some p2' -> p2' in - let ps = List.map (fun (p1',subs1,pchoices1,ksub1) -> List.map (fun (p2',subs2,pchoices2,ksub2) -> - P_aux (P_cons (p1',p2'),(l,annot)),subs1@subs2,pchoices1@pchoices2,ksub1@ksub2) p2') p1' in - Some (List.concat ps) + re2 spl (fun p1' p2' -> P_cons (p1', p2')) p1 p2 in spl p in @@ -1028,7 +1032,6 @@ let split_defs all_errors splits env defs = | VarSplit patsubsts -> if check_split_size patsubsts (pat_loc p) then List.map (fun (pat',substs,pchoices,ksubsts) -> - let ksubsts = kbindings_from_list ksubsts in let exp' = Spec_analysis.nexp_subst_exp ksubsts e in let exp' = subst_exp ref_vars substs ksubsts exp' in let exp' = apply_pat_choices pchoices exp' in @@ -1049,7 +1052,6 @@ let split_defs all_errors splits env defs = | VarSplit patsubsts -> if check_split_size patsubsts (pat_loc p) then List.map (fun (pat',substs,pchoices,ksubsts) -> - let ksubsts = kbindings_from_list ksubsts in let exp1' = Spec_analysis.nexp_subst_exp ksubsts e1 in let exp1' = subst_exp ref_vars substs ksubsts exp1' in let exp1' = apply_pat_choices pchoices exp1' in @@ -1866,7 +1868,9 @@ let rec deps_of_nc kid_deps (NC_aux (nc,l)) = match nc with | NC_equal (nexp1,nexp2) | NC_bounded_ge (nexp1,nexp2) + | NC_bounded_gt (nexp1,nexp2) | NC_bounded_le (nexp1,nexp2) + | NC_bounded_lt (nexp1,nexp2) | NC_not_equal (nexp1,nexp2) -> dmerge (deps_of_nexp l kid_deps [] nexp1) (deps_of_nexp l kid_deps [] nexp2) | NC_set (kid,_) -> @@ -1980,8 +1984,11 @@ let simplify_size_nexp env typ_env (Nexp_aux (ne,l) as nexp) = | Nexp_constant _ -> nexp | _ -> match List.find is_equal env.top_kids with - | kid -> Nexp_aux (Nexp_var kid,Generated l) - | exception Not_found -> nexp + | kid -> Nexp_aux (Nexp_var kid, Generated l) + | exception Not_found -> + match KBindings.find_first_opt is_equal (Env.get_typ_vars typ_env) with + | Some (kid,_) -> Nexp_aux (Nexp_var kid, Generated l) + | None -> nexp let simplify_size_typ_arg env typ_env = function | A_aux (A_nexp nexp, l) -> A_aux (A_nexp (simplify_size_nexp env typ_env nexp), l) @@ -2513,14 +2520,11 @@ let rec sets_from_assert e = | None -> KBindings.empty) | _ -> KBindings.empty in - match destruct_atom_bool (env_of e) (typ_of e) with - | Some nc -> sets_from_nc nc - | None -> - match e with - | E_aux (E_app (Id_aux (Id "and_bool",_),[e1;e2]),_) -> - merge_set_asserts_by_kid (sets_from_assert e1) (sets_from_assert e2) - | E_aux (E_constraint nc,_) -> sets_from_nc nc - | _ -> set_from_or_exps e + match e with + | E_aux (E_app (Id_aux (Id "and_bool",_),[e1;e2]),_) -> + merge_set_asserts_by_kid (sets_from_assert e1) (sets_from_assert e2) + | E_aux (E_constraint nc,_) -> sets_from_nc nc + | _ -> set_from_or_exps e (* Find all the easily reached set assertions in a function body, to use as case splits. Note that this should be mirrored in stop_at_false_assertions, @@ -2545,7 +2549,7 @@ let print_set_assertions set_assertions = else begin print_endline "Top-level set assertions found:"; KBindings.iter (fun k (l,is) -> - print_endline (string_of_kid k ^ " " ^ + print_endline (string_of_kid k ^ " @ " ^ simple_string_of_loc l ^ " " ^ String.concat "," (List.map Big_int.to_string is))) set_assertions end @@ -2746,11 +2750,12 @@ let rec rewrite_app env typ (id,args) = is_id env (Id "Zeros") id || is_id env (Id "zeros") id || is_id env (Id "sail_zeros") id in - let is_ones = is_id env (Id "Ones") in + let is_ones id = is_id env (Id "Ones") id || is_id env (Id "ones") id || + is_id env (Id "sail_ones") id in let is_zero_extend = is_id env (Id "ZeroExtend") id || is_id env (Id "zero_extend") id || is_id env (Id "sail_zero_extend") id || - is_id env (Id "mips_zero_extend") id + is_id env (Id "mips_zero_extend") id || is_id env (Id "EXTZ") id in let is_truncate = is_id env (Id "truncate") id in let mk_exp e = E_aux (e, (Unknown, empty_tannot)) in @@ -2971,13 +2976,17 @@ let rec rewrite_app env typ (id,args) = match List.filter (fun arg -> not (is_number (typ_of arg))) args with | [E_aux (E_app (append1, [E_aux (E_app (subrange1, [vector1; start1; end1]), _); - E_aux (E_app (zeros1, [len1]),_)]),_)] + (E_aux (E_app (zeros1, [len1]),_) | + E_aux (E_cast (_,E_aux (E_app (zeros1, [len1]),_)),_)) + ]),_)] when is_subrange subrange1 && is_zeros zeros1 && is_append append1 -> try_cast_to_typ (rewrap (E_app (mk_id "place_subrange", length_arg @ [vector1; start1; end1; len1]))) | [E_aux (E_app (append1, [vector1; - E_aux (E_app (zeros1, [length2]),_)]),_)] + (E_aux (E_app (zeros1, [length2]),_) | + E_aux (E_cast (_, E_aux (E_app (zeros1, [length2]),_)),_)) + ]),_)] when is_constant_vec_typ env (typ_of vector1) && is_zeros zeros1 && is_append append1 -> let (vector1, start1, length1) = match vector1 with @@ -3025,8 +3034,19 @@ let rec rewrite_app env typ (id,args) = try_cast_to_typ (rewrap (E_app (mk_id "sext_slice", length_arg @ [vector1; start1; length1]))) | [E_aux (E_app (append, + [E_aux (E_app (subrange1, [vector1; start1; end1]), _); + (E_aux (E_app (zeros2, [len2]), _) | + E_aux (E_cast (_, E_aux (E_app (zeros2, [len2]), _)),_)) + ]), _)] + when is_append append && is_subrange subrange1 && is_zeros zeros2 && + not (is_constant len2) -> + E_app (mk_id "place_subrange_signed", length_arg @ [vector1; start1; end1; len2]) + + | [E_aux (E_app (append, [E_aux (E_app (slice1, [vector1; start1; len1]), _); - E_aux (E_app (zeros2, [len2]), _)]), _)] + (E_aux (E_app (zeros2, [len2]), _) | + E_aux (E_cast (_, E_aux (E_app (zeros2, [len2]), _)),_)) + ]), _)] when is_append append && is_slice slice1 && is_zeros zeros2 && not (is_constant len1 && is_constant len2) -> E_app (mk_id "place_slice_signed", length_arg @ [vector1; start1; len1; len2]) @@ -3085,6 +3105,18 @@ let rewrite_aux = function E_aux (rewrite_app env ty (id,args), (l, tannot)) | None -> E_aux (E_app (id, args), (l, tannot)) end + | E_assign ( + LEXP_aux (LEXP_vector_range (LEXP_aux (LEXP_id id1,(l_id1,_)), start1, end1),_), + E_aux (E_app (subrange2, [vector2; start2; end2]),(l_assign,_))), + annot + when is_id (env_of_annot annot) (Id "vector_subrange") subrange2 && + not (is_constant_range (start1, end1)) -> + E_aux (E_assign (LEXP_aux (LEXP_id id1,(l_id1,empty_tannot)), + E_aux (E_app (mk_id "vector_update_subrange_from_subrange", [ + E_aux (E_id id1,(Generated l_id1,empty_tannot)); + start1; end1; + vector2; start2; end2]),(Unknown,empty_tannot))), + (l_assign, empty_tannot)) | exp,annot -> E_aux (exp,annot) let mono_rewrite defs = @@ -3196,6 +3228,7 @@ let make_bitvector_cast_fns cast_name top_env env quant_kids src_typ target_typ check_for_spec env cast_name; let src_ann = mk_tannot env src_typ no_effect in let tar_ann = mk_tannot env target_typ no_effect in + let asg_ann = mk_tannot env unit_typ no_effect in match src_typ' with (* Simple case with just the bitvector; don't need to pull apart value *) | Typ_aux (Typ_app _,_) -> @@ -3205,9 +3238,15 @@ let make_bitvector_cast_fns cast_name top_env env quant_kids src_typ target_typ E_aux (E_app (Id_aux (Id cast_name,genunk), [E_aux (E_id var,(genunk,src_ann))]),(genunk,tar_ann))),(genunk,tar_ann)), exp),(genunk,exp_ann))), + (fun var -> + [E_aux (E_assign (LEXP_aux (LEXP_cast (one_target_typ, var),(genunk,tar_ann)), + E_aux (E_app (Id_aux (Id cast_name,genunk), + [E_aux (E_id var,(genunk,src_ann))]),(genunk,tar_ann) + )),(genunk,asg_ann))]), (fun (E_aux (_,(exp_l,exp_ann)) as exp) -> E_aux (E_cast (one_target_typ, E_aux (E_app (Id_aux (Id cast_name, genunk), [exp]), (Generated exp_l,tar_ann))), + (Generated exp_l,tar_ann))) | _ -> (fun var exp -> @@ -3215,17 +3254,58 @@ let make_bitvector_cast_fns cast_name top_env env quant_kids src_typ target_typ E_aux (E_let (LB_aux (LB_val (pat, E_aux (E_id var,(genunk,src_ann))),(genunk,src_ann)), E_aux (E_let (LB_aux (LB_val (P_aux (P_id var,(genunk,tar_ann)),e'),(genunk,tar_ann)), exp),(genunk,exp_ann))),(genunk,exp_ann))), + (fun var -> + [E_aux (E_let (LB_aux (LB_val (pat, E_aux (E_id var,(genunk,src_ann))),(genunk,src_ann)), + E_aux (E_assign (LEXP_aux (LEXP_cast (one_target_typ, var),(genunk,tar_ann)), + e'),(genunk,asg_ann))),(genunk,asg_ann))]), (fun (E_aux (_,(exp_l,exp_ann)) as exp) -> E_aux (E_let (LB_aux (LB_val (pat, exp),(Generated exp_l,exp_ann)), e'),(Generated exp_l,tar_ann))) end - | None -> (fun _ e -> e),(fun e -> e) + | None -> (fun _ e -> e),(fun _ -> []),(fun e -> e) +let make_bitvector_cast_let cast_name top_env env quant_kids src_typ target_typ = + let f,_,_ = make_bitvector_cast_fns cast_name top_env env quant_kids src_typ target_typ + in f +let make_bitvector_cast_assign cast_name top_env env quant_kids src_typ target_typ = + let _,f,_ = make_bitvector_cast_fns cast_name top_env env quant_kids src_typ target_typ + in f +let make_bitvector_cast_cast cast_name top_env env quant_kids src_typ target_typ = + let _,_,f = make_bitvector_cast_fns cast_name top_env env quant_kids src_typ target_typ + in f + +let ids_in_exp exp = + let open Rewriter in + fold_exp { + (pure_exp_alg IdSet.empty IdSet.union) with + e_id = IdSet.singleton; + lEXP_id = IdSet.singleton; + lEXP_memory = (fun (id,s) -> List.fold_left IdSet.union (IdSet.singleton id) s); + lEXP_cast = (fun (_,id) -> IdSet.singleton id) + } exp -(* TODO: bound vars *) let make_bitvector_env_casts env quant_kids (kid,i) exp = - let mk_cast var typ exp = (fst (make_bitvector_cast_fns "bitvector_cast_in" env env quant_kids typ (subst_kids_typ (KBindings.singleton kid (nconstant i)) typ))) var exp in + let mk_cast var typ exp = (make_bitvector_cast_let "bitvector_cast_in" env env quant_kids typ (subst_kids_typ (KBindings.singleton kid (nconstant i)) typ)) var exp in + let mk_assign_in var typ = + make_bitvector_cast_assign "bitvector_cast_in" env env quant_kids typ + (subst_kids_typ (KBindings.singleton kid (nconstant i)) typ) var + in + let mk_assign_out var typ = + make_bitvector_cast_assign "bitvector_cast_out" env env quant_kids + (subst_kids_typ (KBindings.singleton kid (nconstant i)) typ) typ var + in let locals = Env.get_locals env in + let used_ids = ids_in_exp exp in + let locals = Bindings.filter (fun id _ -> IdSet.mem id used_ids) locals in + let immutables,mutables = Bindings.partition (fun _ (mut,_) -> mut = Immutable) locals in + let assigns_in = Bindings.fold (fun var (_,typ) acc -> mk_assign_in var typ @ acc) mutables [] in + let assigns_out = Bindings.fold (fun var (_,typ) acc -> mk_assign_out var typ @ acc) mutables [] in + let exp = match assigns_in, exp with + | [], _ -> exp + | _::_, E_aux (E_block es,ann) -> E_aux (E_block (assigns_in @ es @ assigns_out),ann) + | _::_, E_aux (_,(l,ann)) -> + E_aux (E_block (assigns_in @ [exp] @ assigns_out), (Generated l,ann)) + in Bindings.fold (fun var (mut,typ) exp -> - if mut = Immutable then mk_cast var typ exp else exp) locals exp + if mut = Immutable then mk_cast var typ exp else exp) immutables exp let make_bitvector_cast_exp cast_name cast_env quant_kids typ target_typ exp = if alpha_equivalent cast_env typ target_typ then exp else @@ -3269,7 +3349,7 @@ let make_bitvector_cast_exp cast_name cast_env quant_kids typ target_typ exp = let tgt_arg_typ = infer_arg_typ (env_of exp) f l target_typ in E_aux (E_app (f,[aux arg (src_arg_typ, tgt_arg_typ)]),(l,ann)) | _ -> - (snd (make_bitvector_cast_fns cast_name cast_env (env_of exp) quant_kids typ target_typ)) exp + (make_bitvector_cast_cast cast_name cast_env (env_of exp) quant_kids typ target_typ) exp in aux exp (typ, target_typ) @@ -3298,6 +3378,27 @@ let fill_in_type env typ = | Some n -> KBindings.add kid (nconstant n) subst)) tyvars KBindings.empty in subst_kids_typ subst typ +(* Extract the instantiations of kids resulting from an if or assert guard *) +let rec extract (E_aux (e,_)) = + match e with + | E_app (op, + ([E_aux (E_sizeof (Nexp_aux (Nexp_var kid,_)),_); y] | + [y; E_aux (E_sizeof (Nexp_aux (Nexp_var kid,_)),_)])) + when string_of_id op = "eq_int" -> + (match destruct_atom_nexp (env_of y) (typ_of y) with + | Some (Nexp_aux (Nexp_constant i,_)) -> [(kid,i)] + | _ -> []) + | E_app (op,[x;y]) + when string_of_id op = "eq_int" -> + (match destruct_atom_nexp (env_of x) (typ_of x), destruct_atom_nexp (env_of y) (typ_of y) with + | Some (Nexp_aux (Nexp_var kid,_)), Some (Nexp_aux (Nexp_constant i,_)) + | Some (Nexp_aux (Nexp_constant i,_)), Some (Nexp_aux (Nexp_var kid,_)) + -> [(kid,i)] + | _ -> []) + | E_app (op, [x;y]) when string_of_id op = "and_bool" -> + extract x @ extract y + | _ -> [] + (* TODO: top-level patterns *) (* TODO: proper environment tracking for variables. Currently we pretend that we can print the type of a variable in the top-level environment, but in @@ -3342,26 +3443,6 @@ let add_bitvector_casts (Defs defs) = | E_if (e1,e2,e3) -> let env = env_of_annot ann in let result_typ = Env.base_typ_of env (typ_of_annot ann) in - let rec extract (E_aux (e,_)) = - match e with - | E_app (op, - ([E_aux (E_sizeof (Nexp_aux (Nexp_var kid,_)),_); y] | - [y; E_aux (E_sizeof (Nexp_aux (Nexp_var kid,_)),_)])) - when string_of_id op = "eq_int" -> - (match destruct_atom_nexp (env_of y) (typ_of y) with - | Some (Nexp_aux (Nexp_constant i,_)) -> [(kid,i)] - | _ -> []) - | E_app (op,[x;y]) - when string_of_id op = "eq_int" -> - (match destruct_atom_nexp (env_of x) (typ_of x), destruct_atom_nexp (env_of y) (typ_of y) with - | Some (Nexp_aux (Nexp_var kid,_)), Some (Nexp_aux (Nexp_constant i,_)) - | Some (Nexp_aux (Nexp_constant i,_)), Some (Nexp_aux (Nexp_var kid,_)) - -> [(kid,i)] - | _ -> []) - | E_app (op, [x;y]) when string_of_id op = "and_bool" -> - extract x @ extract y - | _ -> [] - in let insts = extract e1 in let e2' = List.fold_left (fun body inst -> make_bitvector_env_casts env quant_kids inst body) e2 insts in @@ -3369,29 +3450,48 @@ let add_bitvector_casts (Defs defs) = KBindings.add kid (nconstant i) insts) KBindings.empty insts in let src_typ = subst_kids_typ insts result_typ in let e2' = make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ result_typ e2' in - E_aux (E_if (e1,e2',e3), ann) + (* Ask the type checker if only one value remains for any of kids in + the else branch. *) + let env3 = env_of e3 in + let insts3 = KBindings.fold (fun kid _ i3 -> + match Type_check.solve_unique env3 (nvar kid) with + | None -> i3 + | Some c -> (kid, c)::i3) + insts [] + in + let e3' = List.fold_left (fun body inst -> + make_bitvector_env_casts env quant_kids inst body) e3 insts3 in + let insts3 = List.fold_left (fun insts (kid,i) -> + KBindings.add kid (nconstant i) insts) KBindings.empty insts3 in + let src_typ3 = subst_kids_typ insts3 result_typ in + let e3' = make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ3 result_typ e3' in + E_aux (E_if (e1,e2',e3'), ann) | E_return e' -> E_aux (E_return (make_bitvector_cast_exp "bitvector_cast_out" top_env quant_kids (fill_in_type (env_of e') (typ_of e')) ret_typ e'),ann) - | E_assign (LEXP_aux (_,lexp_annot) as lexp,e') -> begin - (* The type in the lexp_annot might come from e' rather than being the - type of the storage, so ask the type checker what it really is. *) - match infer_lexp (env_of_annot lexp_annot) (strip_lexp lexp) with - | LEXP_aux (_,lexp_annot') -> - E_aux (E_assign (lexp, - make_bitvector_cast_exp "bitvector_cast_out" top_env quant_kids (fill_in_type (env_of e') (typ_of e')) - (typ_of_annot lexp_annot') e'),ann) - | exception _ -> E_aux (e,ann) - end - | E_id id -> begin - let env = env_of_annot ann in - match Env.lookup_id id env with - | Local (Mutable, vtyp) -> - make_bitvector_cast_exp "bitvector_cast_in" top_env quant_kids - (fill_in_type (env_of_annot ann) (typ_of_annot ann)) - vtyp - (E_aux (e,ann)) - | _ -> E_aux (e,ann) - end + | E_block es -> + let env = env_of_annot ann in + let result_typ = Env.base_typ_of env (typ_of_annot ann) in + let rec aux = function + | [] -> [] + | (E_aux (E_assert (assert_exp,msg),ann) as h)::t -> + let insts = extract assert_exp in + begin match insts with + | [] -> h::(aux t) + | _ -> + let t' = aux t in + let et = E_aux (E_block t',ann) in + let env = env_of h in + let et = List.fold_left (fun body inst -> + make_bitvector_env_casts env quant_kids inst body) et insts in + let insts = List.fold_left (fun insts (kid,i) -> + KBindings.add kid (nconstant i) insts) KBindings.empty insts in + let src_typ = subst_kids_typ insts result_typ in + let et = make_bitvector_cast_exp "bitvector_cast_out" env quant_kids src_typ result_typ et in + + [h; et] + end + | h::t -> h::(aux t) + in E_aux (E_block (aux es),ann) | _ -> E_aux (e,ann) in let open Rewriter in @@ -3439,7 +3539,7 @@ let add_bitvector_casts (Defs defs) = mk_val_spec (VS_val_spec (ts,name,[("_", "zeroExtend")],false)) in let defs = List.map mkfn (IdSet.elements !specs_required) in - check Env.empty (Defs defs) + check initial_env (Defs defs) in Defs (cast_specs @ defs) end @@ -3498,7 +3598,10 @@ let fresh_nexp_kid nexp = | Nexp_app (id,args) -> string_of_id id ^ "_" ^ String.concat "_" (List.map mangle_nexp args) in - mk_kid (mangle_nexp nexp ^ "#") + (* TODO: I'd like to add a # to distinguish it from user-provided names, but + the rewriter currently uses them as a hint that they're not printable in + types, which these are explicitly supposed to be. *) + mk_kid (mangle_nexp nexp (*^ "#"*)) let rewrite_toplevel_nexps (Defs defs) = let find_nexp env nexp_map nexp = @@ -3604,7 +3707,9 @@ let rewrite_toplevel_nexps (Defs defs) = match nc with | NC_equal (n1, n2) -> rewrap (NC_equal (aux_nexp n1, aux_nexp n2)) | NC_bounded_ge (n1, n2) -> rewrap (NC_bounded_ge (aux_nexp n1, aux_nexp n2)) + | NC_bounded_gt (n1, n2) -> rewrap (NC_bounded_gt (aux_nexp n1, aux_nexp n2)) | NC_bounded_le (n1, n2) -> rewrap (NC_bounded_le (aux_nexp n1, aux_nexp n2)) + | NC_bounded_lt (n1, n2) -> rewrap (NC_bounded_lt (aux_nexp n1, aux_nexp n2)) | NC_not_equal (n1, n2) -> rewrap (NC_not_equal (aux_nexp n1, aux_nexp n2)) | NC_or (nc1, nc2) -> rewrap (NC_or (aux_nconstraint nc1, aux_nconstraint nc2)) | NC_and (nc1, nc2) -> rewrap (NC_and (aux_nconstraint nc1, aux_nconstraint nc2)) @@ -3684,7 +3789,7 @@ let recheck defs = let mono_rewrites = MonoRewrites.mono_rewrite -let monomorphise opts splits defs = +let monomorphise target opts splits defs = let defs, env = Type_check.check Type_check.initial_env defs in let ok_analysis, new_splits, extra_splits = if opts.auto @@ -3701,7 +3806,7 @@ let monomorphise opts splits defs = then () else raise (Reporting.err_general Unknown "Unable to monomorphise program") in - let ok_split, defs = split_defs opts.all_split_errors splits env defs in + let ok_split, defs = split_defs target opts.all_split_errors splits env defs in let () = if (ok_analysis && ok_extras && ok_split) || opts.continue_anyway then () else raise (Reporting.err_general Unknown "Unable to monomorphise program") diff --git a/src/monomorphise.mli b/src/monomorphise.mli index 1a82c8d0..39d89461 100644 --- a/src/monomorphise.mli +++ b/src/monomorphise.mli @@ -56,6 +56,7 @@ type options = { } val monomorphise : + string -> (* Target backend *) options -> ((string * int) * string) list -> (* List of splits from the command line *) Type_check.tannot Ast.defs -> diff --git a/src/ocaml_backend.ml b/src/ocaml_backend.ml index 28ce43d3..618de5e6 100644 --- a/src/ocaml_backend.ml +++ b/src/ocaml_backend.ml @@ -965,7 +965,7 @@ let ocaml_main spec sail_dir = @ [ " zinitializze_registers ();"; if !opt_trace_ocaml then " Sail_lib.opt_trace := true;" else " ();"; " Printexc.record_backtrace true;"; - " try zmain () with exn -> prerr_endline(\"Exiting due to uncaught exception:\\n\" ^ Printexc.to_string exn)\n";]) + " try zmain () with exn -> (prerr_endline(\"Exiting due to uncaught exception:\\n\" ^ Printexc.to_string exn); exit 1)\n";]) |> String.concat "\n" let ocaml_pp_defs f defs generator_types = diff --git a/src/pattern_completeness.ml b/src/pattern_completeness.ml index 3e26502d..3de0058f 100644 --- a/src/pattern_completeness.ml +++ b/src/pattern_completeness.ml @@ -286,6 +286,13 @@ let shrink_loc = function Lexing.(Parse_ast.Range (n, { n with pos_cnum = n.pos_cnum + 5 })) | l -> l +let is_complete ctx cases = + match cases_to_pats cases with + | [] -> false + | (_, pat) :: pats -> + let top_pat = List.fold_left (combine ctx) (generalize ctx pat) pats in + is_wild top_pat + let check l ctx cases = match cases_to_pats cases with | [] -> Reporting.warn "No non-guarded patterns at" (shrink_loc l) "" diff --git a/src/pattern_completeness.mli b/src/pattern_completeness.mli index 83d6d54c..3084bdf4 100644 --- a/src/pattern_completeness.mli +++ b/src/pattern_completeness.mli @@ -57,4 +57,6 @@ type ctx = variants : IdSet.t Bindings.t } +val is_complete : ctx -> 'a pexp list -> bool + val check : Parse_ast.l -> ctx -> 'a pexp list -> unit diff --git a/src/pretty_print_coq.ml b/src/pretty_print_coq.ml index f0947315..1fea72ea 100644 --- a/src/pretty_print_coq.ml +++ b/src/pretty_print_coq.ml @@ -303,7 +303,9 @@ let rec orig_nc (NC_aux (nc, l) as full_nc) = match nc with | NC_equal (nexp1, nexp2) -> rewrap (NC_equal (orig_nexp nexp1, orig_nexp nexp2)) | NC_bounded_ge (nexp1, nexp2) -> rewrap (NC_bounded_ge (orig_nexp nexp1, orig_nexp nexp2)) + | NC_bounded_gt (nexp1, nexp2) -> rewrap (NC_bounded_gt (orig_nexp nexp1, orig_nexp nexp2)) | NC_bounded_le (nexp1, nexp2) -> rewrap (NC_bounded_le (orig_nexp nexp1, orig_nexp nexp2)) + | NC_bounded_lt (nexp1, nexp2) -> rewrap (NC_bounded_lt (orig_nexp nexp1, orig_nexp nexp2)) | NC_not_equal (nexp1, nexp2) -> rewrap (NC_not_equal (orig_nexp nexp1, orig_nexp nexp2)) | NC_set (kid,s) -> rewrap (NC_set (orig_kid kid, s)) | NC_or (nc1, nc2) -> rewrap (NC_or (orig_nc nc1, orig_nc nc2)) @@ -431,7 +433,9 @@ let rec count_nc_vars (NC_aux (nc,_)) = -> KBindings.singleton kid 1 | NC_equal (n1,n2) | NC_bounded_ge (n1,n2) + | NC_bounded_gt (n1,n2) | NC_bounded_le (n1,n2) + | NC_bounded_lt (n1,n2) | NC_not_equal (n1,n2) -> merge_kid_count (count_nexp_vars n1) (count_nexp_vars n2) | NC_true | NC_false @@ -462,8 +466,12 @@ let simplify_atom_bool l kopts nc atom_nc = | NC_equal (_, Nexp_aux (Nexp_var kid,_)) when KBindings.mem kid lin_ty_vars -> Some kid | NC_bounded_ge (Nexp_aux (Nexp_var kid,_), _) when KBindings.mem kid lin_ty_vars -> Some kid | NC_bounded_ge (_, Nexp_aux (Nexp_var kid,_)) when KBindings.mem kid lin_ty_vars -> Some kid + | NC_bounded_gt (Nexp_aux (Nexp_var kid,_), _) when KBindings.mem kid lin_ty_vars -> Some kid + | NC_bounded_gt (_, Nexp_aux (Nexp_var kid,_)) when KBindings.mem kid lin_ty_vars -> Some kid | NC_bounded_le (Nexp_aux (Nexp_var kid,_), _) when KBindings.mem kid lin_ty_vars -> Some kid | NC_bounded_le (_, Nexp_aux (Nexp_var kid,_)) when KBindings.mem kid lin_ty_vars -> Some kid + | NC_bounded_lt (Nexp_aux (Nexp_var kid,_), _) when KBindings.mem kid lin_ty_vars -> Some kid + | NC_bounded_lt (_, Nexp_aux (Nexp_var kid,_)) when KBindings.mem kid lin_ty_vars -> Some kid | NC_not_equal (Nexp_aux (Nexp_var kid,_), _) when KBindings.mem kid lin_ty_vars -> Some kid | NC_not_equal (_, Nexp_aux (Nexp_var kid,_)) when KBindings.mem kid lin_ty_vars -> Some kid | NC_set (kid, _::_) when KBindings.mem kid lin_ty_vars -> Some kid @@ -738,6 +746,7 @@ and doc_arithfact ctxt env ?(exists = []) ?extra nc = (* Follows Coq precedence levels *) and doc_nc_prop ?(top = true) ctx env nc = let locals = Env.get_locals env |> Bindings.bindings in + let nc = Env.expand_constraint_synonyms env nc in let nc_id_map = List.fold_left (fun m (v,(_,Typ_aux (typ,_))) -> @@ -768,7 +777,9 @@ and doc_nc_prop ?(top = true) ctx env nc = | NC_equal (ne1, ne2) -> doc_op equals (doc_nexp ctx ne1) (doc_nexp ctx ne2) | NC_var kid -> doc_op equals (doc_nexp ctx (nvar kid)) (string "true") | NC_bounded_ge (ne1, ne2) -> doc_op (string ">=") (doc_nexp ctx ne1) (doc_nexp ctx ne2) + | NC_bounded_gt (ne1, ne2) -> doc_op (string ">") (doc_nexp ctx ne1) (doc_nexp ctx ne2) | NC_bounded_le (ne1, ne2) -> doc_op (string "<=") (doc_nexp ctx ne1) (doc_nexp ctx ne2) + | NC_bounded_lt (ne1, ne2) -> doc_op (string "<") (doc_nexp ctx ne1) (doc_nexp ctx ne2) | NC_not_equal (ne1, ne2) -> doc_op (string "<>") (doc_nexp ctx ne1) (doc_nexp ctx ne2) | _ -> l10 nc_full and l10 (NC_aux (nc,_) as nc_full) = @@ -790,7 +801,9 @@ and doc_nc_prop ?(top = true) ctx env nc = | NC_and _ | NC_equal _ | NC_bounded_ge _ + | NC_bounded_gt _ | NC_bounded_le _ + | NC_bounded_lt _ | NC_not_equal _ -> parens (l85 nc_full) in if top then newnc l85 nc else newnc l0 nc @@ -819,7 +832,9 @@ let rec doc_nc_exp ctx env nc = match nc with | NC_equal (ne1, ne2) -> doc_op (string "=?") (doc_nexp ctx ne1) (doc_nexp ctx ne2) | NC_bounded_ge (ne1, ne2) -> doc_op (string ">=?") (doc_nexp ctx ne1) (doc_nexp ctx ne2) + | NC_bounded_gt (ne1, ne2) -> doc_op (string ">?") (doc_nexp ctx ne1) (doc_nexp ctx ne2) | NC_bounded_le (ne1, ne2) -> doc_op (string "<=?") (doc_nexp ctx ne1) (doc_nexp ctx ne2) + | NC_bounded_lt (ne1, ne2) -> doc_op (string "<?") (doc_nexp ctx ne1) (doc_nexp ctx ne2) | _ -> l50 nc_full and l50 (NC_aux (nc,_) as nc_full) = match nc with @@ -842,7 +857,9 @@ let rec doc_nc_exp ctx env nc = | NC_var kid -> doc_nexp ctx (nvar kid) | NC_equal _ | NC_bounded_ge _ + | NC_bounded_gt _ | NC_bounded_le _ + | NC_bounded_lt _ | NC_or _ | NC_and _ -> parens (l70 nc_full) in newnc l70 nc @@ -950,11 +967,11 @@ let quant_item_id_name ctx (QI_aux (qi,_)) = | QI_constraint nc -> None | QI_constant _ -> None -let doc_quant_item_constr ctx delimit (QI_aux (qi,_)) = +let doc_quant_item_constr ctx env delimit (QI_aux (qi,_)) = match qi with | QI_id _ -> None | QI_constant _ -> None - | QI_constraint nc -> Some (bquote ^^ braces (doc_arithfact ctx Env.empty nc)) + | QI_constraint nc -> Some (bquote ^^ braces (doc_arithfact ctx env nc)) (* At the moment these are all anonymous - when used we rely on Coq to fill them in. *) @@ -964,18 +981,18 @@ let quant_item_constr_name ctx (QI_aux (qi,_)) = | QI_constant _ -> None | QI_constraint nc -> Some underscore -let doc_typquant_items ctx delimit (TypQ_aux (tq,_)) = +let doc_typquant_items ctx env delimit (TypQ_aux (tq,_)) = match tq with | TypQ_tq qis -> separate_opt space (doc_quant_item_id ctx delimit) qis ^^ - separate_opt space (doc_quant_item_constr ctx delimit) qis + separate_opt space (doc_quant_item_constr ctx env delimit) qis | TypQ_no_forall -> empty -let doc_typquant_items_separate ctx delimit (TypQ_aux (tq,_)) = +let doc_typquant_items_separate ctx env delimit (TypQ_aux (tq,_)) = match tq with | TypQ_tq qis -> Util.map_filter (doc_quant_item_id ctx delimit) qis, - Util.map_filter (doc_quant_item_constr ctx delimit) qis + Util.map_filter (doc_quant_item_constr ctx env delimit) qis | TypQ_no_forall -> [], [] let typquant_names_separate ctx (TypQ_aux (tq,_)) = @@ -986,10 +1003,10 @@ let typquant_names_separate ctx (TypQ_aux (tq,_)) = | TypQ_no_forall -> [], [] -let doc_typquant ctx (TypQ_aux(tq,_)) typ = match tq with +let doc_typquant ctx env (TypQ_aux(tq,_)) typ = match tq with | TypQ_tq ((_ :: _) as qs) -> string "forall " ^^ separate_opt space (doc_quant_item_id ctx braces) qs ^/^ - separate_opt space (doc_quant_item_constr ctx parens) qs ^^ string ", " ^^ typ + separate_opt space (doc_quant_item_constr ctx env parens) qs ^^ string ", " ^^ typ | _ -> typ (* Produce Size type constraints for bitvector sizes when using @@ -1016,9 +1033,9 @@ let rec typeclass_nexps (Typ_aux(t,l)) = | Typ_bidir _ -> unreachable l __POS__ "Coq doesn't support bidir types" | Typ_internal_unknown -> unreachable l __POS__ "escaped Typ_internal_unknown" -let doc_typschm ctx quants (TypSchm_aux(TypSchm_ts(tq,t),_)) = - let pt = doc_typ ctx Env.empty t in - if quants then doc_typquant ctx tq pt else pt +let doc_typschm ctx env quants (TypSchm_aux(TypSchm_ts(tq,t),_)) = + let pt = doc_typ ctx env t in + if quants then doc_typquant ctx env tq pt else pt let is_ctor env id = match Env.lookup_id id env with | Enum _ -> true @@ -1944,6 +1961,10 @@ let doc_exp, doc_let = if effects then match cast_ex, outer_ex with | ExGeneral, ExNone -> string "projT1_m" ^/^ parens epp + | ExGeneral, ExGeneral -> + if alpha_equivalent env cast_typ outer_typ + then epp + else string "derive_m" ^/^ parens epp | _ -> epp else match cast_ex with | ExGeneral -> string "projT1" ^/^ parens epp @@ -2357,22 +2378,24 @@ let rec doc_range ctxt (BF_aux(r,_)) = match r with *) (* TODO: check use of empty_ctxt below *) -let doc_typdef generic_eq_types (TD_aux(td, (l, annot))) = match td with +let doc_typdef generic_eq_types (TD_aux(td, (l, annot))) = + match td with | TD_abbrev(id,typq,A_aux (A_typ typ, _)) -> let typschm = TypSchm_aux (TypSchm_ts (typq, typ), l) in doc_op coloneq (separate space [string "Definition"; doc_id_type id; - doc_typquant_items empty_ctxt parens typq; + doc_typquant_items empty_ctxt Env.empty parens typq; colon; string "Type"]) - (doc_typschm empty_ctxt false typschm) ^^ dot + (doc_typschm empty_ctxt Env.empty false typschm) ^^ dot ^^ twice hardline | TD_abbrev(id,typq,A_aux (A_nexp nexp,_)) -> let idpp = doc_id_type id in doc_op coloneq (separate space [string "Definition"; idpp; - doc_typquant_items empty_ctxt parens typq; + doc_typquant_items empty_ctxt Env.empty parens typq; colon; string "Z"]) (doc_nexp empty_ctxt nexp) ^^ dot ^^ hardline ^^ - separate space [string "Hint Unfold"; idpp; colon; string "sail."] + separate space [string "Hint Unfold"; idpp; colon; string "sail."] ^^ + twice hardline | TD_abbrev _ -> empty (* TODO? *) | TD_bitfield _ -> empty (* TODO? *) | TD_record(id,typq,fs,_) -> @@ -2394,13 +2417,18 @@ let doc_typdef generic_eq_types (TD_aux(td, (l, annot))) = match td with let doc_update_field (_,fid) = let idpp = fname fid in let otherfield (_,fid') = - if Id.compare fid fid' == 0 then empty else + if Id.compare fid fid' == 0 then None else let idpp = fname fid' in - separate space [semi; idpp; string ":="; idpp; string "r"] + Some (separate space [idpp; string ":="; idpp; string "r"]) in - string "Notation \"{[ r 'with' '" ^^ idpp ^^ string "' := e ]}\" := ({| " ^^ - idpp ^^ string " := e" ^^ concat (List.map otherfield fs) ^^ - space ^^ string "|})." + match fs with + | [_] -> + string "Notation \"{[ r 'with' '" ^^ idpp ^^ string "' := e ]}\" :=" ^//^ + string "{| " ^^ idpp ^^ string " := e |} (only parsing)." + | _ -> + string "Notation \"{[ r 'with' '" ^^ idpp ^^ string "' := e ]}\" := {|" ^//^ + idpp ^^ string " := e;" ^/^ separate (semi ^^ break 1) (Util.map_filter otherfield fs) ^/^ + string "|}" ^^ dot in let updates_pp = separate hardline (List.map doc_update_field fs) in let id_pp = doc_id_type id in @@ -2421,14 +2449,15 @@ let doc_typdef generic_eq_types (TD_aux(td, (l, annot))) = match td with string ("cmp_record_field x" ^ ns ^ " y" ^ ns ^ "."))) ^^ hardline ^^ string "refine (Build_Decidable _ true _). subst. split; reflexivity." ^^ hardline ^^ - string "Defined." ^^ hardline + string "Defined." ^^ twice hardline else empty in let reset_implicits_pp = doc_reset_implicits id_pp typq in doc_op coloneq - (separate space [string "Record"; id_pp; doc_typquant_items empty_ctxt braces typq]) + (separate space [string "Record"; id_pp; doc_typquant_items empty_ctxt Env.empty braces typq]) ((*doc_typquant typq*) (braces (space ^^ align fs_doc ^^ space))) ^^ - dot ^^ hardline ^^ reset_implicits_pp ^^ hardline ^^ eq_pp ^^ updates_pp + dot ^^ hardline ^^ reset_implicits_pp ^^ hardline ^^ eq_pp ^^ updates_pp ^^ + twice hardline | TD_variant(id,typq,ar,_) -> (match id with | Id_aux ((Id "read_kind"),_) -> empty @@ -2442,14 +2471,14 @@ let doc_typdef generic_eq_types (TD_aux(td, (l, annot))) = match td with | Id_aux ((Id "option"),_) -> empty | _ -> let id_pp = doc_id_type id in - let typ_nm = separate space [id_pp; doc_typquant_items empty_ctxt braces typq] in - let ar_doc = group (separate_map (break 1 ^^ pipe ^^ space) (doc_type_union empty_ctxt id_pp) ar) in + let typ_nm = separate space [id_pp; doc_typquant_items empty_ctxt Env.empty braces typq] in + let ar_doc = group (separate_map (break 1) (fun x -> pipe ^^ space ^^ doc_type_union empty_ctxt id_pp x) ar) in let typ_pp = (doc_op coloneq) (concat [string "Inductive"; space; typ_nm]) ((*doc_typquant typq*) ar_doc) in let reset_implicits_pp = doc_reset_implicits id_pp typq in - typ_pp ^^ dot ^^ hardline ^^ reset_implicits_pp ^^ hardline ^^ hardline) + typ_pp ^^ dot ^^ hardline ^^ reset_implicits_pp ^^ twice hardline) | TD_enum(id,enums,_) -> (match id with | Id_aux ((Id "read_kind"),_) -> empty @@ -2470,7 +2499,7 @@ let doc_typdef generic_eq_types (TD_aux(td, (l, annot))) = match td with let eq2_pp = string "Instance Decidable_eq_" ^^ id_pp ^^ space ^^ colon ^/^ string "forall (x y : " ^^ id_pp ^^ string "), Decidable (x = y) :=" ^/^ string "Decidable_eq_from_dec " ^^ id_pp ^^ string "_eq_dec." in - typ_pp ^^ dot ^^ hardline ^^ eq1_pp ^^ hardline ^^ eq2_pp ^^ hardline) + typ_pp ^^ dot ^^ hardline ^^ eq1_pp ^^ hardline ^^ eq2_pp ^^ twice hardline) let args_of_typ l env typs = let arg i typ = @@ -2530,7 +2559,8 @@ let pat_is_plain_binder env (P_aux (p,_)) = match p with | P_id id | P_typ (_,P_aux (P_id id,_)) - when not (is_enum env id) -> Some id + when not (is_enum env id) -> Some (Some id) + | P_wild -> Some None | _ -> None let demote_all_patterns env i (P_aux (p,p_annot) as pat,typ) = @@ -2538,10 +2568,14 @@ let demote_all_patterns env i (P_aux (p,p_annot) as pat,typ) = | Some id -> if Util.is_none (is_auto_decomposed_exist empty_ctxt env typ) then (pat,typ), fun e -> e - else - (P_aux (P_id id, p_annot),typ), - fun (E_aux (_,e_ann) as e) -> - E_aux (E_let (LB_aux (LB_val (pat, E_aux (E_id id, p_annot)),p_annot),e),e_ann) + else begin + match id with + | Some id -> + (P_aux (P_id id, p_annot),typ), + fun (E_aux (_,e_ann) as e) -> + E_aux (E_let (LB_aux (LB_val (pat, E_aux (E_id id, p_annot)),p_annot),e),e_ann) + | None -> (P_aux (P_wild, p_annot),typ), fun e -> e + end | None -> let id = mk_id ("arg" ^ string_of_int i) in (* TODO: name conflicts *) (P_aux (P_id id, p_annot),typ), @@ -2616,30 +2650,42 @@ let mk_kid_renames ids_to_avoid kids = in snd (KidSet.fold check_kid kids (kids, KBindings.empty)) let merge_kids_atoms pats = - let try_eliminate (gone,map,seen) = function + let try_eliminate (acc,gone,map,seen) (pat,typ) = + let tryon maybe_id env typ = + let merge kid l = + if KidSet.mem kid seen then + let () = + Reporting.print_err l "merge_kids_atoms" + ("want to merge tyvar and argument for " ^ string_of_kid kid ^ + " but rearranging arguments isn't supported yet") in + (pat,typ)::acc,gone,map,seen + else + let pat,id = match maybe_id with + | Some id -> pat,id + (* TODO: name clashes *) + | None -> let id = id_of_kid kid in + P_aux (P_id id,match pat with P_aux (_,ann) -> ann), id + in + (pat,typ)::acc, + KidSet.add kid gone, KBindings.add kid (Some id) map, KidSet.add kid seen + in + match Type_check.destruct_atom_nexp env typ with + | Some (Nexp_aux (Nexp_var kid,l)) -> merge kid l + | _ -> + match Type_check.destruct_atom_bool env typ with + | Some (NC_aux (NC_var kid,l)) -> merge kid l + | _ -> (pat,typ)::acc,gone,map,KidSet.union seen (tyvars_of_typ typ) + in + match pat,typ with | P_aux (P_id id, ann), typ - | P_aux (P_typ (_,P_aux (P_id id, ann)),_), typ -> begin - let merge kid l = - if KidSet.mem kid seen then - let () = - Reporting.print_err l "merge_kids_atoms" - ("want to merge tyvar and argument for " ^ string_of_kid kid ^ - " but rearranging arguments isn't supported yet") in - gone,map,seen - else - KidSet.add kid gone, KBindings.add kid (Some id) map, KidSet.add kid seen - in - match Type_check.destruct_atom_nexp (env_of_annot ann) typ with - | Some (Nexp_aux (Nexp_var kid,l)) -> merge kid l - | _ -> - match Type_check.destruct_atom_bool (env_of_annot ann) typ with - | Some (NC_aux (NC_var kid,l)) -> merge kid l - | _ -> gone,map,KidSet.union seen (tyvars_of_typ typ) - end - | _, typ -> gone,map,KidSet.union seen (tyvars_of_typ typ) + | P_aux (P_typ (_,P_aux (P_id id, ann)),_), typ -> + tryon (Some id) (env_of_annot ann) typ + | P_aux (P_wild, ann), typ -> + tryon None (env_of_annot ann) typ + | _ -> (pat,typ)::acc,gone,map,KidSet.union seen (tyvars_of_typ typ) in - let gone,map,_ = List.fold_left try_eliminate (KidSet.empty, KBindings.empty, KidSet.empty) pats in - gone,map + let r_pats,gone,map,_ = List.fold_left try_eliminate ([],KidSet.empty, KBindings.empty, KidSet.empty) pats in + List.rev r_pats,gone,map let merge_var_patterns map pats = @@ -2671,7 +2717,7 @@ let doc_funcl mutrec rec_opt ?rec_set (FCL_aux(FCL_Funcl(id, pexp), annot)) = | _ -> demote_all_patterns env in let pats, binds = List.split (Util.list_mapi pattern_elim pats) in - let eliminated_kids, kid_to_arg_rename = merge_kids_atoms pats in + let pats, eliminated_kids, kid_to_arg_rename = merge_kids_atoms pats in let kid_to_arg_rename, pats = merge_var_patterns kid_to_arg_rename pats in let kids_used = KidSet.diff bound_kids eliminated_kids in let is_measured, recursive_ids = match rec_opt with @@ -2714,7 +2760,7 @@ let doc_funcl mutrec rec_opt ?rec_set (FCL_aux(FCL_Funcl(id, pexp), annot)) = in (* Put the constraints after pattern matching so that any type variable that's been replaced by one of the term-level arguments is bound. *) - let quantspp, constrspp = doc_typquant_items_separate ctxt braces tq in + let quantspp, constrspp = doc_typquant_items_separate ctxt env braces tq in let exp = List.fold_left (fun body f -> f body) (bind exp) binds in let used_a_pattern = ref false in let doc_binder (P_aux (p,ann) as pat, typ) = @@ -2727,9 +2773,10 @@ let doc_funcl mutrec rec_opt ?rec_set (FCL_aux(FCL_Funcl(id, pexp), annot)) = (* TODO: probably should provide partial environments to doc_typ *) match pat_is_plain_binder env pat with | Some id -> begin - match classify_ex_type ctxt env ~binding:id exp_typ with + let id_pp = match id with Some id -> doc_id id | None -> underscore in + match classify_ex_type ctxt env ?binding:id exp_typ with | ExNone, _, typ' -> - parens (separate space [doc_id id; colon; doc_typ ctxt Env.empty typ']) + parens (separate space [id_pp; colon; doc_typ ctxt Env.empty typ']) | ExGeneral, _, _ -> let full_typ = (expand_range_type exp_typ) in match destruct_exist_plain (Env.expand_synonyms env full_typ) with @@ -2738,17 +2785,22 @@ let doc_funcl mutrec rec_opt ?rec_set (FCL_aux(FCL_Funcl(id, pexp), annot)) = [A_aux (A_nexp (Nexp_aux (Nexp_var kid,_)),_)]),_)) when Kid.compare (kopt_kid kopt) kid == 0 -> let coqty = if tyname = "atom" then "Z" else "bool" in - parens (separate space [doc_id id; colon; string coqty]) + parens (separate space [id_pp; colon; string coqty]) | Some ([kopt], nc, Typ_aux (Typ_app (Id_aux (Id ("atom" | "atom_bool"),_), [A_aux (A_nexp (Nexp_aux (Nexp_var kid,_)),_)]),_)) when Kid.compare (kopt_kid kopt) kid == 0 && not is_measured -> (used_a_pattern := true; - squote ^^ parens (separate space [string "existT"; underscore; doc_id id; underscore; colon; doc_typ ctxt Env.empty typ])) + squote ^^ parens (separate space [string "existT"; underscore; id_pp; underscore; colon; doc_typ ctxt Env.empty typ])) | _ -> - parens (separate space [doc_id id; colon; doc_typ ctxt Env.empty typ]) + parens (separate space [id_pp; colon; doc_typ ctxt Env.empty typ]) end | None -> + let typ = + match classify_ex_type ctxt env ~binding:id exp_typ with + | ExNone, _, typ' -> typ' + | ExGeneral, _, _ -> typ + in (used_a_pattern := true; squote ^^ parens (separate space [doc_pat ctxt true true (pat, exp_typ); colon; doc_typ ctxt Env.empty typ])) in @@ -2767,7 +2819,7 @@ let doc_funcl mutrec rec_opt ?rec_set (FCL_aux(FCL_Funcl(id, pexp), annot)) = let fixupspp = Util.map_filter (fun (pat,typ) -> match pat_is_plain_binder env pat with - | Some id -> begin + | Some (Some id) -> begin match destruct_exist_plain (Env.expand_synonyms env (expand_range_type typ)) with | Some (_, NC_aux (NC_true,_), _) -> None | Some ([KOpt_aux (KOpt_kind (_, kid), _)], nc, @@ -2777,7 +2829,7 @@ let doc_funcl mutrec rec_opt ?rec_set (FCL_aux(FCL_Funcl(id, pexp), annot)) = Some (string "let " ^^ doc_id id ^^ string " := projT1 " ^^ doc_id id ^^ string " in") | _ -> None end - | None -> None) pats + | _ -> None) pats in string "Fixpoint", [parens (string "_acc : Acc (Zwf 0) _reclimit")], @@ -2987,10 +3039,10 @@ let doc_axiom_typschm typ_env l (tqs,typ) = then string "M" ^^ space ^^ parens ret_typ_pp else ret_typ_pp in - let tyvars_pp, constrs_pp = doc_typquant_items_separate empty_ctxt braces tqs in + let tyvars_pp, constrs_pp = doc_typquant_items_separate empty_ctxt typ_env braces tqs in string "forall" ^/^ separate space tyvars_pp ^/^ arg_typs_pp ^/^ separate space constrs_pp ^^ comma ^/^ ret_typ_pp - | _ -> doc_typschm empty_ctxt true (TypSchm_aux (TypSchm_ts (tqs,typ),l)) + | _ -> doc_typschm empty_ctxt typ_env true (TypSchm_aux (TypSchm_ts (tqs,typ),l)) let doc_val_spec unimplemented (VS_aux (VS_val_spec(_,id,_,_),(l,ann)) as vs) = if !opt_undef_axioms && IdSet.mem id unimplemented then @@ -3045,7 +3097,7 @@ let rec doc_def unimplemented generic_eq_types def = | DEF_spec v_spec -> doc_val_spec unimplemented v_spec | DEF_fixity _ -> empty | DEF_overload _ -> empty - | DEF_type t_def -> group (doc_typdef generic_eq_types t_def) ^/^ hardline + | DEF_type t_def -> doc_typdef generic_eq_types t_def | DEF_reg_dec dec -> group (doc_dec dec) | DEF_default df -> empty diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml index 5cea9ba9..836c4fbc 100644 --- a/src/pretty_print_lem.ml +++ b/src/pretty_print_lem.ml @@ -557,6 +557,28 @@ let contains_early_return exp = { (Rewriter.compute_exp_alg false (||)) with e_return = (fun (_, r) -> (true, E_return r)); e_app = e_app } exp) +(* Does the expression have the form of a bitvector cast from the monomorphiser? *) +type is_bitvector_cast = BVC_yes | BVC_allowed | BVC_not +let is_bitvector_cast_out exp = + let merge x y = match x,y with + | BVC_allowed, _ -> y + | _, BVC_allowed -> x + | BVC_not, _ -> BVC_not + | _, BVC_not -> BVC_not + | _ -> BVC_yes + in + let rec aux (E_aux (e,_)) = + match e with + | E_tuple es -> List.fold_left merge BVC_allowed (List.map aux es) + | E_cast (_,e) -> aux e + | E_app (Id_aux (Id "bitvector_cast_out",_),_) -> BVC_yes + | E_id _ -> BVC_allowed + | _ -> BVC_not + in aux exp = BVC_yes + +let replace_env_for_cast_out new_env pat = + map_pat_annot (fun (l,a) -> (l,replace_env new_env a)) pat + let find_e_ids exp = let e_id id = IdSet.singleton id, E_id id in fst (fold_exp @@ -983,6 +1005,7 @@ let doc_exp_lem, doc_let_lem = else_pp and let_exp ctxt (LB_aux(lb,_)) = match lb with | LB_val(pat,e) -> + let pat = if is_bitvector_cast_out e then replace_env_for_cast_out ctxt.top_env pat else pat in prefix 2 1 (separate space [string "let"; doc_pat_lem ctxt true pat; equals]) (top_exp ctxt false e) diff --git a/src/pretty_print_sail.ml b/src/pretty_print_sail.ml index ae1f467c..5dbb6cd5 100644 --- a/src/pretty_print_sail.ml +++ b/src/pretty_print_sail.ml @@ -88,7 +88,7 @@ let rec doc_typ_pat (TP_aux (tpat_aux, _)) = | TP_var kid -> doc_kid kid | TP_app (f, tpats) -> doc_id f ^^ parens (separate_map (comma ^^ space) doc_typ_pat tpats) -let rec doc_nexp = +let rec doc_nexp nexp = let rec atomic_nexp (Nexp_aux (n_aux, _) as nexp) = match n_aux with | Nexp_constant c -> string (Big_int.to_string c) @@ -119,7 +119,7 @@ let rec doc_nexp = | Nexp_exp n -> separate space [string "2"; string "^"; atomic_nexp n] | _ -> atomic_nexp nexp in - nexp0 + nexp0 nexp let doc_effect (Effect_aux (aux, _)) = match aux with @@ -136,7 +136,9 @@ let rec doc_nc nc = | NC_equal (n1, n2) -> nc_op "==" n1 n2 | NC_not_equal (n1, n2) -> nc_op "!=" n1 n2 | NC_bounded_ge (n1, n2) -> nc_op ">=" n1 n2 + | NC_bounded_gt (n1, n2) -> nc_op ">" n1 n2 | NC_bounded_le (n1, n2) -> nc_op "<=" n1 n2 + | NC_bounded_lt (n1, n2) -> nc_op "<" n1 n2 | NC_set (kid, ints) -> separate space [doc_kid kid; string "in"; braces (separate_map (comma ^^ space) doc_int ints)] | NC_app (id, args) -> diff --git a/src/process_file.ml b/src/process_file.ml index 7da3c130..60261196 100644 --- a/src/process_file.ml +++ b/src/process_file.ml @@ -368,9 +368,9 @@ let output libpath out_arg files = output1 libpath out_arg f type_env defs) files -let rewrite_step n total env defs (name, rewriter) = +let rewrite_step n total (defs, env) (name, rewriter) = let t = Profile.start () in - let defs = rewriter env defs in + let defs, env = rewriter env defs in Profile.finish ("rewrite " ^ name) t; let _ = match !(opt_ddump_rewrite_ast) with | Some (f, i) -> @@ -384,15 +384,15 @@ let rewrite_step n total env defs (name, rewriter) = end | _ -> () in Util.progress "Rewrite " name n total; - defs + defs, env let rewrite env rewriters defs = let total = List.length rewriters in - try snd (List.fold_left (fun (n, defs) rw -> n + 1, rewrite_step n total env defs rw) (1, defs) rewriters) with + try snd (List.fold_left (fun (n, defsenv) rw -> n + 1, rewrite_step n total defsenv rw) (1, (defs, env)) rewriters) with | Type_check.Type_error (_, l, err) -> raise (Reporting.err_typ l (Type_error.string_of_type_error err)) -let rewrite_ast_initial env = rewrite env [("initial", fun _ -> Rewriter.rewrite_defs)] +let rewrite_ast_initial env = rewrite env [("initial", fun env defs -> Rewriter.rewrite_defs defs, env)] let rewrite_ast_target tgt env = rewrite env (Rewrites.rewrite_defs_target tgt) diff --git a/src/process_file.mli b/src/process_file.mli index e144727e..91cde014 100644 --- a/src/process_file.mli +++ b/src/process_file.mli @@ -56,9 +56,9 @@ val clear_symbols : unit -> unit val preprocess_ast : (Arg.key * Arg.spec * Arg.doc) list -> Parse_ast.defs -> Parse_ast.defs val check_ast : Type_check.Env.t -> unit Ast.defs -> Type_check.tannot Ast.defs * Type_check.Env.t -val rewrite_ast_initial : Type_check.Env.t -> Type_check.tannot Ast.defs -> Type_check.tannot Ast.defs -val rewrite_ast_target : string -> Type_check.Env.t -> Type_check.tannot Ast.defs -> Type_check.tannot Ast.defs -val rewrite_ast_check : Type_check.Env.t -> Type_check.tannot Ast.defs -> Type_check.tannot Ast.defs +val rewrite_ast_initial : Type_check.Env.t -> Type_check.tannot Ast.defs -> Type_check.tannot Ast.defs * Type_check.Env.t +val rewrite_ast_target : string -> Type_check.Env.t -> Type_check.tannot Ast.defs -> Type_check.tannot Ast.defs * Type_check.Env.t +val rewrite_ast_check : Type_check.Env.t -> Type_check.tannot Ast.defs -> Type_check.tannot Ast.defs * Type_check.Env.t val load_file_no_check : (Arg.key * Arg.spec * Arg.doc) list -> Ast.order -> string -> unit Ast.defs val load_file : (Arg.key * Arg.spec * Arg.doc) list -> Ast.order -> Type_check.Env.t -> string -> Type_check.tannot Ast.defs * Type_check.Env.t diff --git a/src/reporting.ml b/src/reporting.ml index 0b727836..e89ce396 100644 --- a/src/reporting.ml +++ b/src/reporting.ml @@ -180,10 +180,10 @@ let warn str1 l str2 = if !opt_warnings then match simp_loc l with | None -> - prerr_endline (Util.("Warning" |> yellow |> clear) ^ ": " ^ str1 ^ "\n" ^ str2) + prerr_endline (Util.("Warning" |> yellow |> clear) ^ ": " ^ str1 ^ "\n" ^ str2 ^ "\n") | Some (p1, p2) when not (StringSet.mem p1.pos_fname !ignored_files) -> prerr_endline (Util.("Warning" |> yellow |> clear) ^ ": " - ^ str1 ^ (if str1 <> "" then " " else "") ^ loc_to_string l ^ str2) + ^ str1 ^ (if str1 <> "" then " " else "") ^ loc_to_string l ^ str2 ^ "\n") | Some _ -> () else () diff --git a/src/rewrites.ml b/src/rewrites.ml index b1fb0cdd..0f747f59 100644 --- a/src/rewrites.ml +++ b/src/rewrites.ml @@ -771,23 +771,23 @@ and subsumes_fpat (FP_aux (FP_Fpat (id1,pat1),_)) (FP_aux (FP_Fpat (id2,pat2),_) (* A simple check for pattern disjointness; used for optimisation in the guarded pattern rewrite step *) -let rec disjoint_pat (P_aux (p1,annot1) as pat1) (P_aux (p2,annot2) as pat2) = +let rec disjoint_pat env (P_aux (p1,annot1) as pat1) (P_aux (p2,annot2) as pat2) = match p1, p2 with - | P_as (pat1, _), _ -> disjoint_pat pat1 pat2 - | _, P_as (pat2, _) -> disjoint_pat pat1 pat2 - | P_typ (_, pat1), _ -> disjoint_pat pat1 pat2 - | _, P_typ (_, pat2) -> disjoint_pat pat1 pat2 - | P_var (pat1, _), _ -> disjoint_pat pat1 pat2 - | _, P_var (pat2, _) -> disjoint_pat pat1 pat2 - | P_id id, _ when id_is_unbound id (env_of_annot annot1) -> false - | _, P_id id when id_is_unbound id (env_of_annot annot2) -> false + | P_as (pat1, _), _ -> disjoint_pat env pat1 pat2 + | _, P_as (pat2, _) -> disjoint_pat env pat1 pat2 + | P_typ (_, pat1), _ -> disjoint_pat env pat1 pat2 + | _, P_typ (_, pat2) -> disjoint_pat env pat1 pat2 + | P_var (pat1, _), _ -> disjoint_pat env pat1 pat2 + | _, P_var (pat2, _) -> disjoint_pat env pat1 pat2 + | P_id id, _ when id_is_unbound id env -> false + | _, P_id id when id_is_unbound id env -> false | P_id id1, P_id id2 -> Id.compare id1 id2 <> 0 | P_app (id1, args1), P_app (id2, args2) -> - Id.compare id1 id2 <> 0 || List.exists2 disjoint_pat args1 args2 + Id.compare id1 id2 <> 0 || List.exists2 (disjoint_pat env) args1 args2 | P_vector pats1, P_vector pats2 | P_tup pats1, P_tup pats2 | P_list pats1, P_list pats2 -> - List.exists2 disjoint_pat pats1 pats2 + List.exists2 (disjoint_pat env) pats1 pats2 | _ -> false let equiv_pats pat1 pat2 = @@ -846,6 +846,8 @@ let case_exp e t cs = let env = env_of e in let annot = (get_loc_exp e, Some (env_of e, t, no_effect)) in match cs with + | [(P_aux (P_wild, _), body, _)] -> + fix_eff_exp body | [(P_aux (P_id id, pannot) as pat, body, _)] -> fix_eff_exp (annot_exp (E_let (LB_aux (LB_val (pat, e), pannot), body)) l env t) | _ -> @@ -873,7 +875,7 @@ let case_exp e t cs = strategy to ours: group *mutually exclusive* clauses, and try to merge them into a pattern match first instead of an if-then-else cascade. *) -let rewrite_guarded_clauses l cs = +let rewrite_guarded_clauses l env pat_typ typ cs = let rec group fallthrough clauses = let add_clause (pat,cls,annot) c = (pat,cls @ [c],annot) in let rec group_aux current acc = (function @@ -889,10 +891,12 @@ let rewrite_guarded_clauses l cs = let c' = (pat',guard',body',annot) in group_aux (add_clause current c') acc cs | None -> - let pat = remove_wildcards "g__" pat in + let pat = match cs with _::_ -> remove_wildcards "g__" pat | _ -> pat in group_aux (pat,[c],annot) (acc @ [current]) cs) | [] -> acc @ [current]) in let groups = match clauses with + | [(pat,guard,body,annot) as c] -> + [(pat, [c], annot)] | ((pat,guard,body,annot) as c) :: cs -> group_aux (remove_wildcards "g__" pat, [c], annot) [] cs | _ -> @@ -924,7 +928,7 @@ let rewrite_guarded_clauses l cs = (* For singleton clauses with a guard, use fallthrough clauses if the guard is not satisfied, but only those fallthrough clauses that are not disjoint with the current pattern *) - let overlapping_clause (pat, _, _) = not (disjoint_pat current_pat pat) in + let overlapping_clause (pat, _, _) = not (disjoint_pat env current_pat pat) in let fallthrough = List.filter overlapping_clause fallthrough in (match guard, fallthrough with | Some exp, _ :: _ -> @@ -934,7 +938,18 @@ let rewrite_guarded_clauses l cs = | [] -> raise (Reporting.err_unreachable l __POS__ "if_exp given empty list in rewrite_guarded_clauses")) in - group [] cs + let is_complete = Pattern_completeness.is_complete (Env.pattern_completeness_ctx env) (List.map construct_pexp cs) in + let fallthrough = + if not is_complete then + let p = P_aux (P_wild, (gen_loc l, mk_tannot env pat_typ no_effect)) in + let msg = "Pattern match failure at " ^ Reporting.short_loc_to_string l in + let a = mk_exp ~loc:(gen_loc l) (E_assert (mk_lit_exp L_false, mk_lit_exp (L_string msg))) in + let b = mk_exp ~loc:(gen_loc l) (E_exit (mk_lit_exp L_unit)) in + let (E_aux (_, (_, ann)) as e) = check_exp env (mk_exp ~loc:(gen_loc l) (E_block [a; b])) typ in + [(p,None,e,(gen_loc l,ann))] + else [] + in + group [] (cs @ fallthrough) let bitwise_and_exp exp1 exp2 = let (E_aux (_,(l,_))) = exp1 in @@ -1316,7 +1331,7 @@ let rewrite_exp_guarded_pats rewriters (E_aux (exp,(l,annot)) as full_exp) = (pat, None, rewrite_rec body, annot) | Pat_aux (Pat_when (pat, guard, body), annot) -> (pat, Some (rewrite_rec guard), rewrite_rec body, annot) in - let clauses = rewrite_guarded_clauses l (List.map clause ps) in + let clauses = rewrite_guarded_clauses l (env_of full_exp) (typ_of e) (typ_of full_exp) (List.map clause ps) in let e = rewrite_rec e in if (effectful e) then let (E_aux (_,(el,eannot))) = e in @@ -1334,7 +1349,7 @@ let rewrite_exp_guarded_pats rewriters (E_aux (exp,(l,annot)) as full_exp) = (pat, None, rewrite_rec body, annot) | Pat_aux (Pat_when (pat, guard, body), annot) -> (pat, Some (rewrite_rec guard), rewrite_rec body, annot) in - let clauses = rewrite_guarded_clauses l (List.map clause ps) in + let clauses = rewrite_guarded_clauses l (env_of full_exp) (typ_of e) (typ_of full_exp) (List.map clause ps) in let pexp (pat,body,annot) = Pat_aux (Pat_exp (pat,body),annot) in let ps = List.map pexp clauses in fix_eff_exp (annot_exp (E_try (e,ps)) l (env_of full_exp) (typ_of full_exp)) @@ -1342,12 +1357,21 @@ let rewrite_exp_guarded_pats rewriters (E_aux (exp,(l,annot)) as full_exp) = let rewrite_fun_guarded_pats rewriters (FD_aux (FD_function (r,t,e,funcls),(l,fdannot))) = let funcls = match funcls with - | (FCL_aux (FCL_Funcl(id,_),_) :: _) -> + | (FCL_aux (FCL_Funcl(id,pexp), fclannot) :: _) -> let clause (FCL_aux (FCL_Funcl(_,pexp),annot)) = let pat,guard,exp,_ = destruct_pexp pexp in let exp = rewriters.rewrite_exp rewriters exp in (pat,guard,exp,annot) in - let cs = rewrite_guarded_clauses l (List.map clause funcls) in + let pexp_pat_typ, pexp_ret_typ = + let pat, _, exp, _ = destruct_pexp pexp in + (typ_of_pat pat, typ_of exp) + in + let pat_typ, ret_typ = match Env.get_val_spec_orig id (env_of_annot fclannot) with + | (tq, Typ_aux (Typ_fn ([arg_typ], ret_typ, _), _)) -> (arg_typ, ret_typ) + | (tq, Typ_aux (Typ_fn (arg_typs, ret_typ, _), _)) -> (tuple_typ arg_typs, ret_typ) + | _ -> (pexp_pat_typ, pexp_ret_typ) | exception _ -> (pexp_pat_typ, pexp_ret_typ) + in + let cs = rewrite_guarded_clauses l (env_of_annot fclannot) pat_typ ret_typ (List.map clause funcls) in List.map (fun (pat,exp,annot) -> FCL_aux (FCL_Funcl(id,construct_pexp (pat,None,exp,(Parse_ast.Unknown,empty_tannot))),annot)) cs | _ -> funcls (* TODO is the empty list possible here? *) in @@ -1440,7 +1464,7 @@ let rewrite_defs_exp_lift_assign env defs = rewrite_defs_base write_reg_ref (vector_access (GPR, i)) exp *) let rewrite_register_ref_writes (Defs defs) = - let (Defs write_reg_spec) = fst (Type_error.check Env.empty (Defs (List.map gen_vs + let (Defs write_reg_spec) = fst (Type_error.check initial_env (Defs (List.map gen_vs [("write_reg_ref", "forall ('a : Type). (register('a), 'a) -> unit effect {wreg}")]))) in let lexp_ref_exp (LEXP_aux (_, annot) as lexp) = try @@ -1630,7 +1654,7 @@ let rewrite_defs_early_return env (Defs defs) = FD_aux (FD_function (rec_opt, tannot_opt, effect_opt, List.map (rewrite_funcl_early_return rewriters) funcls), a) in - let (Defs early_ret_spec) = fst (Type_error.check Env.empty (Defs [gen_vs + let (Defs early_ret_spec) = fst (Type_error.check initial_env (Defs [gen_vs ("early_return", "forall ('a : Type) ('b : Type). 'a -> 'b effect {escape}")])) in rewrite_defs_base @@ -1870,10 +1894,9 @@ let rewrite_fix_val_specs env (Defs defs) = Rec_aux (Rec_rec, Parse_ast.Unknown) | _ -> recopt in - let tannotopt = match tannotopt, funcls with - | Typ_annot_opt_aux (Typ_annot_opt_some (typq, typ), l), - FCL_aux (FCL_Funcl (_, Pat_aux ((Pat_exp (_, exp) | Pat_when (_, _, exp)), _)), _) :: _ -> - Typ_annot_opt_aux (Typ_annot_opt_some (typq, Env.expand_synonyms (env_of exp) typ), l) + let tannotopt = match tannotopt with + | Typ_annot_opt_aux (Typ_annot_opt_some (typq, typ), l) -> + Typ_annot_opt_aux (Typ_annot_opt_none, l) | _ -> tannotopt in (val_specs, FD_aux (FD_function (recopt, tannotopt, effopt, funcls), a)) in @@ -2459,7 +2482,8 @@ let rewrite_defs_letbind_effects env = rewrap (E_var (lexp,exp1,n_exp exp2 k)))) | E_internal_return exp1 -> n_exp_name exp1 (fun exp1 -> - k (if effectful (propagate_exp_effect exp1) then exp1 else rewrap (E_internal_return exp1))) + let exp1 = fix_eff_exp (propagate_exp_effect exp1) in + k (if effectful exp1 then exp1 else rewrap (E_internal_return exp1))) | E_internal_value v -> k (rewrap (E_internal_value v)) | E_return exp' -> @@ -3628,7 +3652,8 @@ let remove_reference_types exp = let rewrite_defs_remove_superfluous_letbinds env = let e_aux (exp,annot) = match exp with - | E_let (LB_aux (LB_val (pat, exp1), _), exp2) -> + | E_let (LB_aux (LB_val (pat, exp1), _), exp2) + | E_internal_plet (pat, exp1, exp2) -> begin match untyp_pat pat, uncast_exp exp1, uncast_exp exp2 with (* 'let x = EXP1 in x' can be replaced with 'EXP1' *) | (P_aux (P_id id, _), _), _, (E_aux (E_id id', _), _) @@ -3640,22 +3665,22 @@ let rewrite_defs_remove_superfluous_letbinds env = (* "let x = EXP1 in return x" can be replaced with 'return (EXP1)', at least when EXP1 is 'small' enough *) | (P_aux (P_id id, _), _), _, (E_aux (E_internal_return (E_aux (E_id id', _)), _), _) - when Id.compare id id' = 0 && small exp1 -> + when Id.compare id id' = 0 && small exp1 && not (effectful exp1) -> let (E_aux (_,e1annot)) = exp1 in E_aux (E_internal_return (exp1),e1annot) + | _, (E_aux (E_throw e, a), _), _ -> E_aux (E_throw e, a) + | (pat, _), (E_aux (E_assert (c, msg), a) as assert_exp, _), _ -> + begin match typ_of c with + | Typ_aux (Typ_app (Id_aux (Id "atom_bool", _), [A_aux (A_bool nc, _)]), _) + when prove __POS__ (env_of c) (nc_not nc) -> + (* Drop rest of block after an 'assert(false)' *) + let exit_exp = E_aux (E_exit (infer_exp (env_of c) (mk_lit_exp L_unit)), a) in + E_aux (E_internal_plet (pat, assert_exp, exit_exp), annot) + | _ -> + E_aux (exp, annot) + end | _ -> E_aux (exp,annot) end - | E_internal_plet (_, E_aux (E_throw e, a), _) -> E_aux (E_throw e, a) - | E_internal_plet (pat, (E_aux (E_assert (c, msg), a) as assert_exp), _) -> - begin match typ_of c with - | Typ_aux (Typ_app (Id_aux (Id "atom_bool", _), [A_aux (A_bool nc, _)]), _) - when prove __POS__ (env_of c) (nc_not nc) -> - (* Drop rest of block after an 'assert(false)' *) - let exit_exp = E_aux (E_exit (infer_exp (env_of c) (mk_lit_exp L_unit)), a) in - E_aux (E_internal_plet (pat, assert_exp, exit_exp), annot) - | _ -> - E_aux (exp, annot) - end | _ -> E_aux (exp,annot) in let alg = { id_exp_alg with e_aux = e_aux } in @@ -3737,7 +3762,7 @@ let rewrite_defs_remove_superfluous_returns env = when lit = lit' -> add_opt_cast ptyp etyp a exp1 | (P_aux (P_wild,pannot), ptyp), - (E_aux (E_internal_return (E_aux (E_lit (L_aux (L_unit,_)),_)), a), etyp) + (E_aux ((E_internal_return (E_aux (E_lit (L_aux (L_unit,_)),_)) | E_lit (L_aux (L_unit,_))), a), etyp) when is_unit_typ (typ_of exp1) -> add_opt_cast ptyp etyp a exp1 | (P_aux (P_id id,_), ptyp), @@ -3773,7 +3798,7 @@ let rewrite_defs_remove_superfluous_returns env = let rewrite_defs_remove_e_assign env (Defs defs) = - let (Defs loop_specs) = fst (Type_error.check Env.empty (Defs (List.map gen_vs + let (Defs loop_specs) = fst (Type_error.check initial_env (Defs (List.map gen_vs [("foreach#", "forall ('vars : Type). (int, int, int, bool, 'vars, 'vars) -> 'vars"); ("while#", "forall ('vars : Type). (bool, 'vars, 'vars) -> 'vars"); ("until#", "forall ('vars : Type). (bool, 'vars, 'vars) -> 'vars"); @@ -4660,11 +4685,11 @@ let rewrite_loops_with_escape_effect env defs = in rewrite_defs_base { rewriters_base with rewrite_exp } defs -let recheck_defs env defs = fst (Type_error.check initial_env defs) +let recheck_defs env defs = Type_error.check initial_env defs let recheck_defs_without_effects env defs = let old = !opt_no_effects in let () = opt_no_effects := true in - let result,_ = Type_error.check initial_env defs in + let result = Type_error.check initial_env defs in let () = opt_no_effects := old in result @@ -4749,9 +4774,10 @@ let opt_auto_mono = ref false let opt_dall_split_errors = ref false let opt_dmono_continue = ref false -let monomorphise env defs = +let monomorphise target env defs = let open Monomorphise in monomorphise + target { auto = !opt_auto_mono; debug_analysis = !opt_dmono_analysis; all_split_errors = !opt_dall_split_errors; @@ -4764,12 +4790,20 @@ let if_mono f env defs = | [], false -> defs | _, _ -> f env defs +let if_mono_env f env defs = + match !opt_mono_split, !opt_auto_mono with + | [], false -> defs, env + | _, _ -> f env defs + (* Also turn mwords stages on when we're just trying out mono *) let if_mwords f env defs = if !Pretty_print_lem.opt_mwords then f env defs else if_mono f env defs +let if_mwords_env f env defs = + if !Pretty_print_lem.opt_mwords then f env defs else if_mono_env f env defs type rewriter = | Basic_rewriter of (Env.t -> tannot defs -> tannot defs) + | Checking_rewriter of (Env.t -> tannot defs -> tannot defs * Env.t) | Bool_rewriter of (bool -> rewriter) | String_rewriter of (string -> rewriter) | Literal_rewriter of ((lit -> bool) -> rewriter) @@ -4793,6 +4827,8 @@ let instantiate_rewrite rewriter args = match rewriter, arg with | Basic_rewriter rw, If_mono_arg -> Basic_rewriter (if_mono rw) | Basic_rewriter rw, If_mwords_arg -> Basic_rewriter (if_mwords rw) + | Checking_rewriter rw, If_mono_arg -> Checking_rewriter (if_mono_env rw) + | Checking_rewriter rw, If_mwords_arg -> Checking_rewriter (if_mwords_env rw) | Bool_rewriter rw, Bool_arg b -> rw b | String_rewriter rw, String_arg str -> rw str | Literal_rewriter rw, Literal_arg selector -> rw (selector_function selector) @@ -4800,14 +4836,15 @@ let instantiate_rewrite rewriter args = raise (Reporting.err_unreachable Parse_ast.Unknown __POS__ "Invalid rewrite argument") in match List.fold_left instantiate rewriter args with - | Basic_rewriter rw -> rw + | Basic_rewriter rw -> fun env defs -> rw env defs, env + | Checking_rewriter rw -> rw | _ -> raise (Reporting.err_general Parse_ast.Unknown "Rewrite not fully instantiated") let all_rewrites = [ ("no_effect_check", Basic_rewriter (fun _ defs -> opt_no_effects := true; defs)); - ("recheck_defs", Basic_rewriter recheck_defs); - ("recheck_defs_without_effects", Basic_rewriter recheck_defs_without_effects); + ("recheck_defs", Checking_rewriter recheck_defs); + ("recheck_defs_without_effects", Checking_rewriter recheck_defs_without_effects); ("optimize_recheck_defs", Basic_rewriter (fun _ -> Optimize.recheck)); ("realise_mappings", Basic_rewriter rewrite_defs_realise_mappings); ("remove_mapping_valspecs", Basic_rewriter remove_mapping_valspecs); @@ -4816,12 +4853,12 @@ let all_rewrites = [ ("mapping_builtins", Basic_rewriter rewrite_defs_mapping_patterns); ("mono_rewrites", Basic_rewriter mono_rewrites); ("toplevel_nexps", Basic_rewriter rewrite_toplevel_nexps); - ("monomorphise", Basic_rewriter monomorphise); + ("monomorphise", String_rewriter (fun target -> Basic_rewriter (monomorphise target))); ("atoms_to_singletons", Basic_rewriter (fun _ -> Monomorphise.rewrite_atoms_to_singletons)); ("add_bitvector_casts", Basic_rewriter (fun _ -> Monomorphise.add_bitvector_casts)); ("atoms_to_singletons", Basic_rewriter (fun _ -> Monomorphise.rewrite_atoms_to_singletons)); ("remove_impossible_int_cases", Basic_rewriter Constant_propagation.remove_impossible_int_cases); - ("const_prop_mutrec", Basic_rewriter Constant_propagation_mutrec.rewrite_defs); + ("const_prop_mutrec", String_rewriter (fun target -> Basic_rewriter (Constant_propagation_mutrec.rewrite_defs target))); ("make_cases_exhaustive", Basic_rewriter MakeExhaustive.rewrite); ("undefined", Bool_rewriter (fun b -> Basic_rewriter (rewrite_undefined_if_gen b))); ("vector_string_pats_to_bit_list", Basic_rewriter rewrite_defs_vector_string_pats_to_bit_list); @@ -4853,7 +4890,7 @@ let all_rewrites = [ ("simple_types", Basic_rewriter rewrite_simple_types); ("overload_cast", Basic_rewriter rewrite_overload_cast); ("top_sort_defs", Basic_rewriter (fun _ -> top_sort_defs)); - ("constant_fold", Basic_rewriter (fun _ -> Constant_fold.rewrite_constant_function_calls)); + ("constant_fold", String_rewriter (fun target -> Basic_rewriter (fun _ -> Constant_fold.rewrite_constant_function_calls target))); ("split", String_rewriter (fun str -> Basic_rewriter (rewrite_split_fun_ctor_pats str))); ("properties", Basic_rewriter (fun _ -> Property.rewrite)); ] @@ -4868,7 +4905,7 @@ let rewrites_lem = [ ("recheck_defs", [If_mono_arg]); ("undefined", [Bool_arg false]); ("toplevel_nexps", [If_mono_arg]); - ("monomorphise", [If_mono_arg]); + ("monomorphise", [String_arg "lem"; If_mono_arg]); ("recheck_defs", [If_mwords_arg]); ("add_bitvector_casts", [If_mwords_arg]); ("atoms_to_singletons", [If_mono_arg]); @@ -4891,7 +4928,7 @@ let rewrites_lem = [ ("split", [String_arg "execute"]); ("recheck_defs", []); ("top_sort_defs", []); - ("const_prop_mutrec", []); + ("const_prop_mutrec", [String_arg "lem"]); ("vector_string_pats_to_bit_list", []); ("exp_lift_assign", []); ("early_return", []); @@ -4987,7 +5024,7 @@ let rewrites_c = [ ("mono_rewrites", [If_mono_arg]); ("recheck_defs", [If_mono_arg]); ("toplevel_nexps", [If_mono_arg]); - ("monomorphise", [If_mono_arg]); + ("monomorphise", [String_arg "c"; If_mono_arg]); ("atoms_to_singletons", [If_mono_arg]); ("recheck_defs", [If_mono_arg]); ("undefined", [Bool_arg false]); @@ -5002,7 +5039,7 @@ let rewrites_c = [ ("exp_lift_assign", []); ("merge_function_clauses", []); ("optimize_recheck_defs", []); - ("constant_fold", []) + ("constant_fold", [String_arg "c"]) ] let rewrites_interpreter = [ @@ -5025,7 +5062,6 @@ let rewrites_target tgt = | "c" -> rewrites_c | "ir" -> rewrites_c @ [("properties", [])] | "smt" -> rewrites_c @ [("properties", [])] - | "smtfuzz" -> rewrites_c @ [("properties", [])] | "sail" -> [] | "latex" -> [] | "interpreter" -> rewrites_interpreter @@ -5065,5 +5101,5 @@ let rewrite_check_annot = rewrite_pat = (fun _ -> check_pat) } let rewrite_defs_check = [ - ("check_annotations", fun _ -> rewrite_check_annot); + ("check_annotations", fun env defs -> rewrite_check_annot defs, env); ] diff --git a/src/rewrites.mli b/src/rewrites.mli index e30a4206..3b572d51 100644 --- a/src/rewrites.mli +++ b/src/rewrites.mli @@ -70,10 +70,11 @@ val move_loop_measures : 'a defs -> 'a defs val rewrite_undefined : bool -> Env.t -> tannot defs -> tannot defs (* Perform rewrites to create an AST supported for a specific target *) -val rewrite_defs_target : string -> (string * (Env.t -> tannot defs -> tannot defs)) list +val rewrite_defs_target : string -> (string * (Env.t -> tannot defs -> tannot defs * Env.t)) list type rewriter = | Basic_rewriter of (Env.t -> tannot defs -> tannot defs) + | Checking_rewriter of (Env.t -> tannot defs -> tannot defs * Env.t) | Bool_rewriter of (bool -> rewriter) | String_rewriter of (string -> rewriter) | Literal_rewriter of ((lit -> bool) -> rewriter) @@ -96,6 +97,6 @@ val opt_coq_warn_nonexhaustive : bool ref (* This is a special rewriter pass that checks AST invariants without actually doing any re-writing *) -val rewrite_defs_check : (string * (Env.t -> tannot defs -> tannot defs)) list +val rewrite_defs_check : (string * (Env.t -> tannot defs -> tannot defs * Env.t)) list val simple_typ : typ -> typ diff --git a/src/sail.ml b/src/sail.ml index eae7c4cf..e9b1914d 100644 --- a/src/sail.ml +++ b/src/sail.ml @@ -62,11 +62,14 @@ let opt_memo_z3 = ref false let opt_sanity = ref false let opt_includes_c = ref ([]:string list) let opt_specialize_c = ref false +let opt_smt_serialize = ref false +let opt_smt_fuzz = ref false let opt_libs_lem = ref ([]:string list) let opt_libs_coq = ref ([]:string list) let opt_file_arguments = ref ([]:string list) let opt_process_elf : string option ref = ref None let opt_ocaml_generators = ref ([]:string list) +let opt_splice = ref ([]:string list) let set_target name = Arg.Unit (fun _ -> opt_target := Some name) @@ -161,6 +164,9 @@ let options = Arg.align ([ ( "-smt_vector_size", Arg.String (fun n -> Jib_smt.opt_default_vector_index := int_of_string n), "<n> set a bound of 2 ^ n for generic vectors in generated SMT (default 5)"); + ( "-smt_serialize", + Arg.Tuple [set_target "smt"; Arg.Set opt_smt_serialize], + " compile Sail to IR suitable for sail-axiomatic tool"); ( "-c", Arg.Tuple [set_target "c"; Arg.Set Initial_check.opt_undefined_gen], " output a C translated version of the input"); @@ -270,6 +276,9 @@ let options = Arg.align ([ ( "-memo", Arg.Tuple [Arg.Set opt_memo_z3; Arg.Set C_backend.opt_memo_cache], " memoize calls to z3, and intermediate compilation results"); + ( "-splice", + Arg.String (fun s -> opt_splice := s :: !opt_splice), + "<filename> add functions from file, replacing existing definitions where necessary"); ( "-undefined_gen", Arg.Set Initial_check.opt_undefined_gen, " generate undefined_type functions for types in the specification"); @@ -349,7 +358,7 @@ let options = Arg.align ([ Arg.Set Profile.opt_profile, " (debug) provide basic profiling information for rewriting passes within Sail"); ( "-dsmtfuzz", - set_target "smtfuzz", + Arg.Tuple [set_target "smt"; Arg.Set opt_smt_fuzz], " (debug) fuzz sail SMT builtins"); ( "-v", Arg.Set opt_print_version, @@ -400,7 +409,7 @@ let load_files ?check:(check=false) type_envs files = ("out.sail", ast, type_envs) else let ast = Scattered.descatter ast in - let ast = rewrite_ast_initial type_envs ast in + let ast, type_envs = rewrite_ast_initial type_envs ast in let out_name = match !opt_file_out with | None when parsed = [] -> "out.sail" @@ -482,9 +491,21 @@ let target name out_name ast type_envs = flush output_chan; if close then close_out output_chan else () - | Some "smtfuzz" -> + | Some "smt" when !opt_smt_fuzz -> Jib_smt_fuzz.fuzz 0 type_envs ast + | Some "smt" when !opt_smt_serialize -> + let ast_smt, type_envs = Specialize.(specialize typ_ord_specialization type_envs ast) in + let ast_smt, type_envs = Specialize.(specialize_passes 2 int_specialization_with_externs type_envs ast_smt) in + let jib, ctx = Jib_smt.compile type_envs ast_smt in + let name_file = + match !opt_file_out with + | Some f -> f ^ ".smt_model" + | None -> "sail.smt_model" + in + Reporting.opt_warnings := true; + Jib_smt.serialize_smt_model name_file type_envs ast_smt + | Some "smt" -> let open Ast_util in let props = Property.find_properties ast in @@ -530,6 +551,10 @@ let main () = else begin let out_name, ast, type_envs = load_files Type_check.initial_env !opt_file_arguments in + let ast, type_envs = + List.fold_right (fun file (ast,_) -> Splice.splice ast file) + (!opt_splice) (ast, type_envs) + in Reporting.opt_warnings := false; (* Don't show warnings during re-writing for now *) begin match !opt_process_elf, !opt_file_out with @@ -554,7 +579,7 @@ let main () = else (); let type_envs, ast = prover_regstate !opt_target ast type_envs in - let ast = match !opt_target with Some tgt -> rewrite_ast_target tgt type_envs ast | None -> ast in + let ast, type_envs = match !opt_target with Some tgt -> rewrite_ast_target tgt type_envs ast | None -> ast, type_envs in target !opt_target out_name ast type_envs; if !Interactive.opt_interactive then diff --git a/src/sail_lib.ml b/src/sail_lib.ml index 3812e4f7..164bcefa 100644 --- a/src/sail_lib.ml +++ b/src/sail_lib.ml @@ -143,6 +143,12 @@ let rec take n xs = | n, (x :: xs) -> x :: take (n - 1) xs | n, [] -> [] +let count_leading_zeros xs = + let rec aux bs acc = match bs with + | (B0 :: bs') -> aux bs' (acc + 1) + | _ -> acc in + Big_int.of_int (aux xs 0) + let subrange (list, n, m) = let n = Big_int.to_int n in let m = Big_int.to_int m in @@ -746,6 +752,11 @@ let shiftr (x, y) = let rbits = zeros @ x in take (List.length x) rbits +let arith_shiftr (x, y) = + let msbs = replicate_bits (take 1 x, y) in + let rbits = msbs @ x in + take (List.length x) rbits + let shift_bits_right (x, y) = shiftr (x, uint(y)) diff --git a/src/slice.ml b/src/slice.ml index 9dee4761..1bbbca1e 100644 --- a/src/slice.ml +++ b/src/slice.ml @@ -104,7 +104,7 @@ let builtins = let rec constraint_ids' (NC_aux (aux, _)) = match aux with - | NC_equal (n1, n2) | NC_bounded_le (n1, n2) | NC_bounded_ge (n1, n2) | NC_not_equal (n1, n2) -> + | NC_equal (n1, n2) | NC_bounded_le (n1, n2) | NC_bounded_ge (n1, n2) | NC_bounded_lt (n1, n2) | NC_bounded_gt (n1, n2) | NC_not_equal (n1, n2) -> IdSet.union (nexp_ids' n1) (nexp_ids' n2) | NC_or (nc1, nc2) | NC_and (nc1, nc2) -> IdSet.union (constraint_ids' nc1) (constraint_ids' nc2) diff --git a/src/spec_analysis.ml b/src/spec_analysis.ml index 907c62a5..57e415a8 100644 --- a/src/spec_analysis.ml +++ b/src/spec_analysis.ml @@ -186,7 +186,7 @@ let rec fv_of_exp consider_var bound used set (E_aux (e,(_,tannot))) : (Nameset. match e with | E_block es | Ast.E_tuple es | Ast.E_vector es | Ast.E_list es -> list_fv bound used set es - | E_id id -> + | E_id id | E_ref id -> let used = conditional_add_exp bound used id in let used = Nameset.union (free_type_names_tannot consider_var tannot) used in bound,used,set diff --git a/src/specialize.ml b/src/specialize.ml index 815514d1..d749bc53 100644 --- a/src/specialize.ml +++ b/src/specialize.ml @@ -186,7 +186,9 @@ let string_of_instantiation instantiation = | NC_aux (NC_equal (n1, n2), _) -> string_of_nexp n1 ^ " = " ^ string_of_nexp n2 | NC_aux (NC_not_equal (n1, n2), _) -> string_of_nexp n1 ^ " != " ^ string_of_nexp n2 | NC_aux (NC_bounded_ge (n1, n2), _) -> string_of_nexp n1 ^ " >= " ^ string_of_nexp n2 + | NC_aux (NC_bounded_gt (n1, n2), _) -> string_of_nexp n1 ^ " > " ^ string_of_nexp n2 | NC_aux (NC_bounded_le (n1, n2), _) -> string_of_nexp n1 ^ " <= " ^ string_of_nexp n2 + | NC_aux (NC_bounded_lt (n1, n2), _) -> string_of_nexp n1 ^ " < " ^ string_of_nexp n2 | NC_aux (NC_or (nc1, nc2), _) -> "(" ^ string_of_n_constraint nc1 ^ " | " ^ string_of_n_constraint nc2 ^ ")" | NC_aux (NC_and (nc1, nc2), _) -> diff --git a/src/splice.ml b/src/splice.ml new file mode 100644 index 00000000..90488c0a --- /dev/null +++ b/src/splice.ml @@ -0,0 +1,52 @@ +(* Currently limited to: + - functions, no scattered, no preprocessor + - no new undefined functions (but no explicit check here yet) +*) + +open Ast +open Ast_util + +let scan_defs (Defs defs) = + let scan (ids, specs) = function + | DEF_fundef fd -> + IdSet.add (id_of_fundef fd) ids, specs + | DEF_spec (VS_aux (VS_val_spec (_,id,_,_),_) as vs) -> + ids, Bindings.add id vs specs + | d -> raise (Reporting.err_general (def_loc d) + "Definition in splice file isn't a spec or function") + in List.fold_left scan (IdSet.empty, Bindings.empty) defs + +let filter_old_ast repl_ids repl_specs (Defs defs) = + let check (rdefs,spec_found) def = + match def with + | DEF_fundef fd -> + let id = id_of_fundef fd in + if IdSet.mem id repl_ids + then rdefs, spec_found + else def::rdefs, spec_found + | DEF_spec (VS_aux (VS_val_spec (_,id,_,_),_)) -> + (match Bindings.find_opt id repl_specs with + | Some vs -> DEF_spec vs :: rdefs, IdSet.add id spec_found + | None -> def::rdefs, spec_found) + | _ -> def::rdefs, spec_found + in + let rdefs, spec_found = List.fold_left check ([],IdSet.empty) defs in + (List.rev rdefs, spec_found) + +let filter_replacements spec_found (Defs defs) = + let not_found = function + | DEF_spec (VS_aux (VS_val_spec (_,id,_,_),_)) -> not (IdSet.mem id spec_found) + | _ -> true + in List.filter not_found defs + +let splice ast file = + let parsed_ast = Process_file.parse_file file in + let repl_ast = Initial_check.process_ast ~generate:false parsed_ast in + let repl_ast = Rewrites.move_loop_measures repl_ast in + let repl_ast = map_defs_annot (fun (l,_) -> l,Type_check.empty_tannot) repl_ast in + let repl_ids, repl_specs = scan_defs repl_ast in + let defs1, specs_found = filter_old_ast repl_ids repl_specs ast in + let defs2 = filter_replacements specs_found repl_ast in + let new_ast = Defs (defs1 @ defs2) in + Type_error.check Type_check.initial_env new_ast + diff --git a/src/toFromInterp_backend.ml b/src/toFromInterp_backend.ml index fad45412..49739c30 100644 --- a/src/toFromInterp_backend.ml +++ b/src/toFromInterp_backend.ml @@ -95,7 +95,7 @@ let frominterp_typedef (TD_aux (td_aux, (l, _))) = | _ -> string ("NEXP(" ^ string_of_nexp nexp ^ ")") in let rec fromValueTypArg (A_aux (a_aux, _)) = match a_aux with - | A_typ typ -> fromValueTyp typ "" + | A_typ typ -> parens ((string "fun v -> ") ^^ parens (fromValueTyp typ "v")) | A_nexp nexp -> fromValueNexp nexp | A_order order -> string ("Order_" ^ (string_of_order order)) | A_bool _ -> parens (string "boolFromInterpValue") @@ -250,7 +250,7 @@ let tointerp_typedef (TD_aux (td_aux, (l, _))) = | _ -> string ("NEXP(" ^ string_of_nexp nexp ^ ")") in let rec toValueTypArg (A_aux (a_aux, _)) = match a_aux with - | A_typ typ -> toValueTyp typ "" + | A_typ typ -> parens ((string "fun v -> ") ^^ parens (toValueTyp typ "v")) | A_nexp nexp -> toValueNexp nexp | A_order order -> string ("Order_" ^ (string_of_order order)) | A_bool _ -> parens (string "boolToInterpValue") diff --git a/src/type_check.ml b/src/type_check.ml index 1dd806f0..fb98ee1b 100644 --- a/src/type_check.ml +++ b/src/type_check.ml @@ -202,7 +202,9 @@ and strip_nexp = function and strip_n_constraint_aux = function | NC_equal (nexp1, nexp2) -> NC_equal (strip_nexp nexp1, strip_nexp nexp2) | NC_bounded_ge (nexp1, nexp2) -> NC_bounded_ge (strip_nexp nexp1, strip_nexp nexp2) + | NC_bounded_gt (nexp1, nexp2) -> NC_bounded_gt (strip_nexp nexp1, strip_nexp nexp2) | NC_bounded_le (nexp1, nexp2) -> NC_bounded_le (strip_nexp nexp1, strip_nexp nexp2) + | NC_bounded_lt (nexp1, nexp2) -> NC_bounded_lt (strip_nexp nexp1, strip_nexp nexp2) | NC_not_equal (nexp1, nexp2) -> NC_not_equal (strip_nexp nexp1, strip_nexp nexp2) | NC_set (kid, nums) -> NC_set (strip_kid kid, nums) | NC_or (nc1, nc2) -> NC_or (strip_n_constraint nc1, strip_n_constraint nc2) @@ -294,7 +296,7 @@ and typ_arg_nexps (A_aux (typ_arg_aux, l)) = | A_order ord -> [] and constraint_nexps (NC_aux (nc_aux, l)) = match nc_aux with - | NC_equal (n1, n2) | NC_bounded_ge (n1, n2) | NC_bounded_le (n1, n2) | NC_not_equal (n1, n2) -> + | NC_equal (n1, n2) | NC_bounded_ge (n1, n2) | NC_bounded_le (n1, n2) | NC_bounded_gt (n1, n2) | NC_bounded_lt (n1, n2) | NC_not_equal (n1, n2) -> [n1; n2] | NC_set _ | NC_true | NC_false | NC_var _ -> [] | NC_or (nc1, nc2) | NC_and (nc1, nc2) -> constraint_nexps nc1 @ constraint_nexps nc2 @@ -449,6 +451,7 @@ module Env : sig val set_default_order_dec : t -> t val add_enum : id -> id list -> t -> t val get_enum : id -> t -> id list + val is_enum : id -> t -> bool val get_casts : t -> id list val allow_casts : t -> bool val no_casts : t -> t @@ -669,7 +672,9 @@ end = struct | NC_equal (n1, n2) -> NC_aux (NC_equal (expand_nexp_synonyms env n1, expand_nexp_synonyms env n2), l) | NC_not_equal (n1, n2) -> NC_aux (NC_not_equal (expand_nexp_synonyms env n1, expand_nexp_synonyms env n2), l) | NC_bounded_le (n1, n2) -> NC_aux (NC_bounded_le (expand_nexp_synonyms env n1, expand_nexp_synonyms env n2), l) + | NC_bounded_lt (n1, n2) -> NC_aux (NC_bounded_lt (expand_nexp_synonyms env n1, expand_nexp_synonyms env n2), l) | NC_bounded_ge (n1, n2) -> NC_aux (NC_bounded_ge (expand_nexp_synonyms env n1, expand_nexp_synonyms env n2), l) + | NC_bounded_gt (n1, n2) -> NC_aux (NC_bounded_gt (expand_nexp_synonyms env n1, expand_nexp_synonyms env n2), l) | NC_app (id, args) -> (try begin match get_typ_synonym id env l env args with @@ -826,10 +831,11 @@ end = struct | A_typ typ -> wf_typ ~exs:exs env typ | A_order ord -> wf_order env ord | A_bool nc -> wf_constraint ~exs:exs env nc - and wf_nexp ?exs:(exs=KidSet.empty) env (Nexp_aux (nexp_aux, l) as nexp) = + and wf_nexp ?exs:(exs=KidSet.empty) env nexp = wf_debug "nexp" string_of_nexp nexp exs; + let Nexp_aux (nexp_aux, l) = expand_nexp_synonyms env nexp in match nexp_aux with - | Nexp_id _ -> () + | Nexp_id id -> typ_error env l ("Undefined synonym " ^ string_of_id id) | Nexp_var kid when KidSet.mem kid exs -> () | Nexp_var kid -> begin match get_typ_var kid env with @@ -862,7 +868,9 @@ end = struct | NC_equal (n1, n2) -> wf_nexp ~exs:exs env n1; wf_nexp ~exs:exs env n2 | NC_not_equal (n1, n2) -> wf_nexp ~exs:exs env n1; wf_nexp ~exs:exs env n2 | NC_bounded_ge (n1, n2) -> wf_nexp ~exs:exs env n1; wf_nexp ~exs:exs env n2 + | NC_bounded_gt (n1, n2) -> wf_nexp ~exs:exs env n1; wf_nexp ~exs:exs env n2 | NC_bounded_le (n1, n2) -> wf_nexp ~exs:exs env n1; wf_nexp ~exs:exs env n2 + | NC_bounded_lt (n1, n2) -> wf_nexp ~exs:exs env n1; wf_nexp ~exs:exs env n2 | NC_set (kid, _) when KidSet.mem kid exs -> () | NC_set (kid, _) -> begin match get_typ_var kid env with @@ -1050,6 +1058,8 @@ end = struct with | Not_found -> typ_error env (id_loc id) ("Enumeration " ^ string_of_id id ^ " does not exist") + let is_enum id env = Bindings.mem id env.enums + let is_record id env = Bindings.mem id env.records let get_record id env = Bindings.find id env.records @@ -1325,6 +1335,10 @@ let add_typquant l (quant : typquant) (env : Env.t) : Env.t = let expand_bind_synonyms l env (typq, typ) = typq, Env.expand_synonyms (add_typquant l typq env) typ +let wf_typschm env (TypSchm_aux (TypSchm_ts (typq, typ), l)) = + let env = add_typquant l typq env in + Env.wf_typ env typ + (* Create vectors with the default order from the environment *) let default_order_error_string = @@ -1573,7 +1587,9 @@ let rec nc_identical (NC_aux (nc1, _)) (NC_aux (nc2, _)) = | NC_equal (n1a, n1b), NC_equal (n2a, n2b) -> nexp_identical n1a n2a && nexp_identical n1b n2b | NC_not_equal (n1a, n1b), NC_not_equal (n2a, n2b) -> nexp_identical n1a n2a && nexp_identical n1b n2b | NC_bounded_ge (n1a, n1b), NC_bounded_ge (n2a, n2b) -> nexp_identical n1a n2a && nexp_identical n1b n2b + | NC_bounded_gt (n1a, n1b), NC_bounded_gt (n2a, n2b) -> nexp_identical n1a n2a && nexp_identical n1b n2b | NC_bounded_le (n1a, n1b), NC_bounded_le (n2a, n2b) -> nexp_identical n1a n2a && nexp_identical n1b n2b + | NC_bounded_lt (n1a, n1b), NC_bounded_lt (n2a, n2b) -> nexp_identical n1a n2a && nexp_identical n1b n2b | NC_or (nc1a, nc1b), NC_or (nc2a, nc2b) -> nc_identical nc1a nc2a && nc_identical nc1b nc2b | NC_and (nc1a, nc1b), NC_and (nc2a, nc2b) -> nc_identical nc1a nc2a && nc_identical nc1b nc2b | NC_true, NC_true -> true @@ -1581,42 +1597,46 @@ let rec nc_identical (NC_aux (nc1, _)) (NC_aux (nc2, _)) = | NC_set (kid1, ints1), NC_set (kid2, ints2) when List.length ints1 = List.length ints2 -> Kid.compare kid1 kid2 = 0 && List.for_all2 (fun i1 i2 -> i1 = i2) ints1 ints2 | NC_var kid1, NC_var kid2 -> Kid.compare kid1 kid2 = 0 + | NC_app (id1, args1), NC_app (id2, args2) when List.length args1 = List.length args2 -> + Id.compare id1 id2 = 0 && List.for_all2 typ_arg_identical args1 args2 | _, _ -> false -let typ_identical env typ1 typ2 = - let rec typ_identical' (Typ_aux (typ1, _)) (Typ_aux (typ2, _)) = - match typ1, typ2 with - | Typ_id v1, Typ_id v2 -> Id.compare v1 v2 = 0 - | Typ_var kid1, Typ_var kid2 -> Kid.compare kid1 kid2 = 0 - | Typ_fn (arg_typs1, ret_typ1, eff1), Typ_fn (arg_typs2, ret_typ2, eff2) - when List.length arg_typs1 = List.length arg_typs2 -> - List.for_all2 typ_identical' arg_typs1 arg_typs2 - && typ_identical' ret_typ1 ret_typ2 - && strip_effect eff1 = strip_effect eff2 - | Typ_bidir (typ1, typ2), Typ_bidir (typ3, typ4) -> - typ_identical' typ1 typ3 - && typ_identical' typ2 typ4 - | Typ_tup typs1, Typ_tup typs2 -> - begin - try List.for_all2 typ_identical' typs1 typs2 with - | Invalid_argument _ -> false - end - | Typ_app (f1, args1), Typ_app (f2, args2) -> - begin - try Id.compare f1 f2 = 0 && List.for_all2 typ_arg_identical args1 args2 with - | Invalid_argument _ -> false - end - | Typ_exist (kopts1, nc1, typ1), Typ_exist (kopts2, nc2, typ2) when List.length kopts1 = List.length kopts2 -> - List.for_all2 (fun k1 k2 -> KOpt.compare k1 k2 = 0) kopts1 kopts2 && nc_identical nc1 nc2 && typ_identical' typ1 typ2 - | _, _ -> false - and typ_arg_identical (A_aux (arg1, _)) (A_aux (arg2, _)) = - match arg1, arg2 with - | A_nexp n1, A_nexp n2 -> nexp_identical n1 n2 - | A_typ typ1, A_typ typ2 -> typ_identical' typ1 typ2 - | A_order ord1, A_order ord2 -> ord_identical ord1 ord2 - | _, _ -> false - in - typ_identical' (Env.expand_synonyms env typ1) (Env.expand_synonyms env typ2) +and typ_arg_identical (A_aux (arg1, _)) (A_aux (arg2, _)) = + match arg1, arg2 with + | A_nexp n1, A_nexp n2 -> nexp_identical n1 n2 + | A_typ typ1, A_typ typ2 -> typ_identical typ1 typ2 + | A_order ord1, A_order ord2 -> ord_identical ord1 ord2 + | A_bool nc1, A_bool nc2 -> nc_identical nc1 nc2 + | _, _ -> false + +and typ_identical (Typ_aux (typ1, _)) (Typ_aux (typ2, _)) = + match typ1, typ2 with + | Typ_id v1, Typ_id v2 -> Id.compare v1 v2 = 0 + | Typ_var kid1, Typ_var kid2 -> Kid.compare kid1 kid2 = 0 + | Typ_fn (arg_typs1, ret_typ1, eff1), Typ_fn (arg_typs2, ret_typ2, eff2) + when List.length arg_typs1 = List.length arg_typs2 -> + List.for_all2 typ_identical arg_typs1 arg_typs2 + && typ_identical ret_typ1 ret_typ2 + && strip_effect eff1 = strip_effect eff2 + | Typ_bidir (typ1, typ2), Typ_bidir (typ3, typ4) -> + typ_identical typ1 typ3 + && typ_identical typ2 typ4 + | Typ_tup typs1, Typ_tup typs2 -> + begin + try List.for_all2 typ_identical typs1 typs2 with + | Invalid_argument _ -> false + end + | Typ_app (f1, args1), Typ_app (f2, args2) -> + begin + try Id.compare f1 f2 = 0 && List.for_all2 typ_arg_identical args1 args2 with + | Invalid_argument _ -> false + end + | Typ_exist (kopts1, nc1, typ1), Typ_exist (kopts2, nc2, typ2) when List.length kopts1 = List.length kopts2 -> + List.for_all2 (fun k1 k2 -> KOpt.compare k1 k2 = 0) kopts1 kopts2 && nc_identical nc1 nc2 && typ_identical typ1 typ2 + | _, _ -> false + +let expanded_typ_identical env typ1 typ2 = + typ_identical (Env.expand_synonyms env typ1) (Env.expand_synonyms env typ2) exception Unification_error of l * string;; @@ -1712,8 +1732,12 @@ and unify_constraint l env goals (NC_aux (aux1, _) as nc1) (NC_aux (aux2, _) as merge_uvars l (unify_nexp l env goals n1a n1b) (unify_nexp l env goals n2a n2b) | NC_bounded_ge (n1a, n2a), NC_bounded_ge (n1b, n2b) -> merge_uvars l (unify_nexp l env goals n1a n1b) (unify_nexp l env goals n2a n2b) + | NC_bounded_gt (n1a, n2a), NC_bounded_gt (n1b, n2b) -> + merge_uvars l (unify_nexp l env goals n1a n1b) (unify_nexp l env goals n2a n2b) | NC_bounded_le (n1a, n2a), NC_bounded_le (n1b, n2b) -> merge_uvars l (unify_nexp l env goals n1a n1b) (unify_nexp l env goals n2a n2b) + | NC_bounded_lt (n1a, n2a), NC_bounded_lt (n1b, n2b) -> + merge_uvars l (unify_nexp l env goals n1a n1b) (unify_nexp l env goals n2a n2b) | NC_true, NC_true -> KBindings.empty | NC_false, NC_false -> KBindings.empty | _, _ -> unify_error l ("Could not unify constraints " ^ string_of_n_constraint nc1 ^ " and " ^ string_of_n_constraint nc2) @@ -1884,7 +1908,9 @@ and ambiguous_nc_vars (NC_aux (aux, _)) = match aux with | NC_and (nc1, nc2) -> KidSet.union (tyvars_of_constraint nc1) (tyvars_of_constraint nc2) | NC_bounded_le (n1, n2) -> KidSet.union (tyvars_of_nexp n1) (tyvars_of_nexp n2) + | NC_bounded_lt (n1, n2) -> KidSet.union (tyvars_of_nexp n1) (tyvars_of_nexp n2) | NC_bounded_ge (n1, n2) -> KidSet.union (tyvars_of_nexp n1) (tyvars_of_nexp n2) + | NC_bounded_gt (n1, n2) -> KidSet.union (tyvars_of_nexp n1) (tyvars_of_nexp n2) | NC_equal (n1, n2) | NC_not_equal (n1, n2) -> KidSet.union (ambiguous_nexp_vars n1) (ambiguous_nexp_vars n2) | _ -> KidSet.empty @@ -1968,7 +1994,9 @@ and kid_order_constraint kind_map (NC_aux (aux, l) as nc) = ([mk_kopt (unaux_kind (KBindings.find kid kind_map)) kid], KBindings.remove kid kind_map) | NC_var _ | NC_set _ -> ([], kind_map) | NC_true | NC_false -> ([], kind_map) - | NC_equal (n1, n2) | NC_not_equal (n1, n2) | NC_bounded_le (n1, n2) | NC_bounded_ge (n1, n2) -> + | NC_equal (n1, n2) | NC_not_equal (n1, n2) + | NC_bounded_le (n1, n2) | NC_bounded_ge (n1, n2) + | NC_bounded_lt (n1, n2) | NC_bounded_gt (n1, n2) -> let ord1, kind_map = kid_order_nexp kind_map n1 in let ord2, kind_map = kid_order_nexp kind_map n2 in (ord1 @ ord2, kind_map) @@ -2017,7 +2045,7 @@ let rec alpha_equivalent env typ1 typ2 = counter := 0; let typ2 = relabel (Env.expand_synonyms env typ2) in typ_debug (lazy ("Alpha equivalence for " ^ string_of_typ typ1 ^ " and " ^ string_of_typ typ2)); - if typ_identical env typ1 typ2 + if typ_identical typ1 typ2 then (typ_debug (lazy "alpha-equivalent"); true) else (typ_debug (lazy "Not alpha-equivalent"); false) @@ -2257,7 +2285,9 @@ and rewrite_nc_aux l env = let mk_exp exp = mk_exp ~loc:l exp in function | NC_bounded_ge (n1, n2) -> E_app_infix (mk_exp (E_sizeof n1), mk_id ">=", mk_exp (E_sizeof n2)) + | NC_bounded_gt (n1, n2) -> E_app_infix (mk_exp (E_sizeof n1), mk_id ">", mk_exp (E_sizeof n2)) | NC_bounded_le (n1, n2) -> E_app_infix (mk_exp (E_sizeof n1), mk_id "<=", mk_exp (E_sizeof n2)) + | NC_bounded_lt (n1, n2) -> E_app_infix (mk_exp (E_sizeof n1), mk_id "<", mk_exp (E_sizeof n2)) | NC_equal (n1, n2) -> E_app_infix (mk_exp (E_sizeof n1), mk_id "==", mk_exp (E_sizeof n2)) | NC_not_equal (n1, n2) -> E_app_infix (mk_exp (E_sizeof n1), mk_id "!=", mk_exp (E_sizeof n2)) | NC_and (nc1, nc2) -> E_app_infix (rewrite_nc env nc1, mk_id "&", rewrite_nc env nc2) @@ -4864,7 +4894,7 @@ let mk_val_spec env typq typ id = let check_tannotopt env typq ret_typ = function | Typ_annot_opt_aux (Typ_annot_opt_none, _) -> () | Typ_annot_opt_aux (Typ_annot_opt_some (annot_typq, annot_ret_typ), l) -> - if typ_identical env ret_typ annot_ret_typ + if expanded_typ_identical env ret_typ annot_ret_typ then () else typ_error env l (string_of_bind (typq, ret_typ) ^ " and " ^ string_of_bind (annot_typq, annot_ret_typ) ^ " do not match between function and val spec") @@ -4954,7 +4984,7 @@ let check_mapdef env (MD_aux (MD_mapping (id, tannot_opt, mapcls), (l, _)) as md begin match tannot_opt with | Typ_annot_opt_aux (Typ_annot_opt_none, _) -> () | Typ_annot_opt_aux (Typ_annot_opt_some (annot_typq, annot_typ), l) -> - if typ_identical env typ annot_typ then () + if expanded_typ_identical env typ annot_typ then () else typ_error env l (string_of_bind (quant, typ) ^ " and " ^ string_of_bind (annot_typq, annot_typ) ^ " do not match between mapping and val spec") end; typ_debug (lazy ("Checking mapdef " ^ string_of_id id ^ " has type " ^ string_of_bind (quant, typ))); @@ -4973,6 +5003,23 @@ let check_mapdef env (MD_aux (MD_mapping (id, tannot_opt, mapcls), (l, _)) as md else typ_error env l ("Mapping not pure (or escape only): " ^ string_of_effect eff ^ " found") +let rec warn_if_unsafe_cast l env = function + | Typ_aux (Typ_fn (arg_typs, ret_typ, _), _) -> + List.iter (warn_if_unsafe_cast l env) arg_typs; + warn_if_unsafe_cast l env ret_typ + | Typ_aux (Typ_id id, _) when string_of_id id = "bool" -> () + | Typ_aux (Typ_id id, _) when Env.is_enum id env -> () + | Typ_aux (Typ_id id, _) when string_of_id id = "string" -> + Reporting.warn "Unsafe string cast" l + "A cast X -> string is unsafe, as it can cause 'x : X == y : X' to be checked as 'eq_string(cast(x), cast(y))'" + (* If we have a cast to an existential, it's probably done on + purpose and we want to avoid false positives for warnings. *) + | Typ_aux (Typ_exist _, _) -> () + | typ when is_bitvector_typ typ -> () + | typ when is_bit_typ typ -> () + | typ -> + Reporting.warn ("Potentially unsafe cast involving " ^ string_of_typ typ) l "" + (* Checking a val spec simply adds the type as a binding in the context. We have to destructure the various kinds of val specs, but the difference is irrelevant for the typechecker. *) @@ -4981,8 +5028,9 @@ let check_val_spec env (VS_aux (vs, (l, _))) = let vs, id, typq, typ, env = match vs with | VS_val_spec (TypSchm_aux (TypSchm_ts (typq, typ), ts_l) as typschm, id, exts, is_cast) -> typ_print (lazy (Util.("Check val spec " |> cyan |> clear) ^ string_of_id id ^ " : " ^ string_of_typschm typschm)); + wf_typschm env typschm; let env = Env.add_extern id exts env in - let env = if is_cast then Env.add_cast id env else env in + let env = if is_cast then (warn_if_unsafe_cast l env (Env.expand_synonyms env typ); Env.add_cast id env) else env in let typq', typ' = expand_bind_synonyms ts_l env (typq, typ) in (* !opt_expand_valspec controls whether the actual valspec in the AST is expanded, the val_spec type stored in the |
