summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorThomas Bauereiss2019-03-15 14:51:00 +0000
committerThomas Bauereiss2019-03-15 18:47:30 +0000
commitabab0b23aef8404fc62d4f856df74597a5d86a18 (patch)
tree8b536af58d4f2e57f5509da650aa692cc3b22dfa /src
parent541c1880d31a47302fea48725bd7247d374828d6 (diff)
Various monomorphisation tweaks and fixes
Diffstat (limited to 'src')
-rw-r--r--src/monomorphise.ml155
-rw-r--r--src/pretty_print_lem.ml4
-rw-r--r--src/rewrites.ml12
-rw-r--r--src/sail_lib.ml1
-rw-r--r--src/type_check.ml11
-rw-r--r--src/type_check.mli3
6 files changed, 138 insertions, 48 deletions
diff --git a/src/monomorphise.ml b/src/monomorphise.ml
index fdc20932..5a7a72a6 100644
--- a/src/monomorphise.ml
+++ b/src/monomorphise.ml
@@ -690,13 +690,19 @@ let split_defs all_errors splits defs =
| Typ_app (Id_aux (Id "vector",_), [A_aux (A_nexp len,_);_;A_aux (A_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_)]) ->
(match len with
- | Nexp_aux (Nexp_constant sz,_) ->
- let lits = make_vectors (Big_int.to_int sz) in
- List.map (fun lit ->
- P_aux (P_lit lit,(l,annot)),
- [var,E_aux (E_lit lit,(new_l,annot))],[],[]) lits
+ | Nexp_aux (Nexp_constant sz,_) when Big_int.greater_equal sz Big_int.zero ->
+ let sz = Big_int.to_int sz in
+ let num_lits = Big_int.pow_int (Big_int.of_int 2) sz in
+ (* Check that split size is within limits before generating the list of literals *)
+ if (Big_int.less_equal num_lits (Big_int.of_int size_set_limit)) then
+ let lits = make_vectors sz in
+ List.map (fun lit ->
+ P_aux (P_lit lit,(l,annot)),
+ [var,E_aux (E_lit lit,(new_l,annot))],[],[]) lits
+ else
+ cannot ("bitvector length outside limit, " ^ string_of_nexp len)
| _ ->
- cannot ("length not constant, " ^ string_of_nexp len)
+ cannot ("length not constant and positive, " ^ string_of_nexp len)
)
(* set constrained numbers *)
| Typ_app (Id_aux (Id "atom",_), [A_aux (A_nexp (Nexp_aux (value,_) as nexp),_)]) ->
@@ -1289,6 +1295,11 @@ let rewrite_size_parameters env (Defs defs) =
let pat,guard,exp,pannot = destruct_pexp pexp in
let env = env_of_annot (l,ann) in
let _, typ = Env.get_val_spec_orig id env in
+ let already_visible_nexps =
+ NexpSet.union
+ (Pretty_print_lem.lem_nexps_of_typ typ)
+ (Pretty_print_lem.typeclass_nexps typ)
+ in
let types = match typ with
| Typ_aux (Typ_fn (arg_typs,_,_),_) -> List.map (Env.expand_synonyms env) arg_typs
| _ -> raise (Reporting.err_unreachable l __POS__ "Function clause does not have a function type")
@@ -1299,11 +1310,14 @@ let rewrite_size_parameters env (Defs defs) =
Typ_aux (Typ_app(Id_aux (Id "range",_),
[A_aux (A_nexp nexp,_);
A_aux (A_nexp nexp',_)]),_)
- when Nexp.compare nexp nexp' = 0 && not (NexpMap.mem nexp nmap) ->
- NexpMap.add nexp i nmap
+ when Nexp.compare nexp nexp' = 0 && not (NexpMap.mem nexp nmap) &&
+ not (NexpSet.mem nexp already_visible_nexps) ->
+ (* Split integer variables if the nexp is not already available via a bitvector length *)
+ NexpMap.add nexp i nmap
| Typ_aux (Typ_app(Id_aux (Id "atom", _),
[A_aux (A_nexp nexp,_)]), _)
- when not (NexpMap.mem nexp nmap) ->
+ when not (NexpMap.mem nexp nmap) &&
+ not (NexpSet.mem nexp already_visible_nexps) ->
NexpMap.add nexp i nmap
| _ -> nmap
in (i+1,nmap)
@@ -2172,6 +2186,11 @@ let rec analyse_exp fn_id env assigns (E_aux (e,(l,annot)) as exp) =
| E_constraint nc ->
(deps_of_nc env.kid_deps nc, assigns, empty)
in
+ let deps =
+ match destruct_atom_bool (env_of exp) (typ_of exp) with
+ | Some nc -> dmerge deps (deps_of_nc env.kid_deps nc)
+ | None -> deps
+ in
let r =
(* Check for bitvector types with parametrised sizes *)
match destruct_tannot annot with
@@ -2450,11 +2469,14 @@ let rec sets_from_assert e =
| None -> KBindings.empty)
| _ -> KBindings.empty
in
- match e with
- | E_aux (E_app (Id_aux (Id "and_bool",_),[e1;e2]),_) ->
- merge_set_asserts_by_kid (sets_from_assert e1) (sets_from_assert e2)
- | E_aux (E_constraint nc,_) -> sets_from_nc nc
- | _ -> set_from_or_exps e
+ match destruct_atom_bool (env_of e) (typ_of e) with
+ | Some nc -> sets_from_nc nc
+ | None ->
+ match e with
+ | E_aux (E_app (Id_aux (Id "and_bool",_),[e1;e2]),_) ->
+ merge_set_asserts_by_kid (sets_from_assert e1) (sets_from_assert e2)
+ | E_aux (E_constraint nc,_) -> sets_from_nc nc
+ | _ -> set_from_or_exps e
(* Find all the easily reached set assertions in a function body, to use as
case splits. Note that this should be mirrored in stop_at_false_assertions,
@@ -2670,12 +2692,17 @@ let rec rewrite_app env typ (id,args) =
let is_append = is_id env (Id "append") in
let is_subrange = is_id env (Id "vector_subrange") in
let is_slice = is_id env (Id "slice") in
- let is_zeros = is_id env (Id "Zeros") in
+ let is_zeros id =
+ is_id env (Id "Zeros") id || is_id env (Id "zeros") id ||
+ is_id env (Id "sail_zeros") id
+ in
+ let is_ones = is_id env (Id "Ones") in
let is_zero_extend =
is_id env (Id "ZeroExtend") id ||
is_id env (Id "zero_extend") id || is_id env (Id "sail_zero_extend") id ||
is_id env (Id "mips_zero_extend") id
in
+ let is_truncate = is_id env (Id "truncate") id in
let mk_exp e = E_aux (e, (Unknown, empty_tannot)) in
let try_cast_to_typ (E_aux (e,(l, _)) as exp) =
let (size,order,bittyp) = vector_typ_args_of (Env.base_typ_of env typ) in
@@ -2777,6 +2804,17 @@ let rec rewrite_app env typ (id,args) =
(E_aux (E_app (mk_id "slice_slice_concat",
[vector1; start1; length1; vector2; start2; length2]),(Unknown,empty_tannot)))
+ (* variable-slice @ local-var *)
+ | [E_aux (E_app (slice1,
+ [vector1; start1; length1]),_);
+ (E_aux (E_id _,_) as vector2)]
+ when is_slice slice1 && not (is_constant length1) ->
+ let start2 = mk_exp (E_lit (mk_lit (L_num Big_int.zero))) in
+ let length2 = mk_exp (E_app (mk_id "length", [vector2])) in
+ try_cast_to_typ
+ (E_aux (E_app (mk_id "slice_slice_concat",
+ [vector1; start1; length1; vector2; start2; length2]),(Unknown,empty_tannot)))
+
| [E_aux (E_app (append1,
[e1;
E_aux (E_app (slice1, [vector1; start1; length1]),_)]),_);
@@ -2805,13 +2843,24 @@ let rec rewrite_app env typ (id,args) =
[vector1; start1; length1; length2]),(Unknown,empty_tannot))]),
(Unknown,empty_tannot)))
end
+
+ (* known-length @ (known-length @ var-length) *)
+ | [e1; E_aux (E_app (append1, [e2; e3]), _)]
+ when is_append append1 && is_constant_vec_typ env (typ_of e1) &&
+ is_constant_vec_typ env (typ_of e2) &&
+ not (is_constant_vec_typ env (typ_of e3)) ->
+ let (size1,order,bittyp) = vector_typ_args_of (Env.base_typ_of env (typ_of e1)) in
+ let (size2,_,_) = vector_typ_args_of (Env.base_typ_of env (typ_of e2)) in
+ let size12 = nexp_simp (nsum size1 size2) in
+ let tannot12 = mk_tannot env (vector_typ size12 order bittyp) no_effect in
+ E_app (id, [E_aux (E_app (append1, [e1; e2]), (Unknown, tannot12)); e3])
+
| _ -> E_app (id,args)
- else if is_id env (Id "eq_vec") id || is_id env (Id "neq_vec") id then
+ else if is_id env (Id "eq_bits") id || is_id env (Id "neq_bits") id then
(* variable-range == variable_range *)
- let is_subrange = is_id env (Id "vector_subrange") in
let wrap e =
- if is_id env (Id "neq_vec") id
+ if is_id env (Id "neq_bits") id
then E_app (mk_id "not_bool", [mk_exp e])
else e
in
@@ -2867,11 +2916,7 @@ let rec rewrite_app env typ (id,args) =
E_app (mk_id "is_ones_slice", [vector1; start1; len1])
| _ -> E_app (id,args)
- else if is_zero_extend then
- let is_subrange = is_id env (Id "vector_subrange") in
- let is_slice = is_id env (Id "slice") in
- let is_zeros = is_id env (Id "Zeros") in
- let is_ones = is_id env (Id "Ones") in
+ else if is_zero_extend || is_truncate then
let length_arg = List.filter (fun arg -> is_number (typ_of arg)) args in
match List.filter (fun arg -> not (is_number (typ_of arg))) args with
| [E_aux (E_app (append1,
@@ -2881,10 +2926,18 @@ let rec rewrite_app env typ (id,args) =
-> try_cast_to_typ (rewrap (E_app (mk_id "place_subrange", length_arg @ [vector1; start1; end1; len1])))
| [E_aux (E_app (append1,
- [E_aux (E_app (slice1, [vector1; start1; length1]), _);
+ [vector1;
E_aux (E_app (zeros1, [length2]),_)]),_)]
- when is_slice slice1 && is_zeros zeros1 && is_append append1
- -> try_cast_to_typ (rewrap (E_app (mk_id "place_slice", length_arg @ [vector1; start1; length1; length2])))
+ when is_constant_vec_typ env (typ_of vector1) && is_zeros zeros1 && is_append append1
+ -> let (vector1, start1, length1) =
+ match vector1 with
+ | E_aux (E_app (slice1, [vector1; start1; length1]), _) ->
+ (vector1, start1, length1)
+ | _ ->
+ let (length1,_,_) = vector_typ_args_of (Env.base_typ_of env (typ_of vector1)) in
+ (vector1, mk_exp (E_lit (mk_lit (L_num (Big_int.zero)))), mk_exp (E_sizeof length1))
+ in
+ try_cast_to_typ (rewrap (E_app (mk_id "place_slice", length_arg @ [vector1; start1; length1; length2])))
(* If we've already rewritten to slice_slice_concat or subrange_subrange_concat,
we can just drop the zero extension because those functions can do it
@@ -2902,10 +2955,19 @@ let rec rewrite_app env typ (id,args) =
| [E_aux (E_app (ones, [len1]),_)] when is_ones ones ->
try_cast_to_typ (rewrap (E_app (mk_id "zext_ones", length_arg @ [len1])))
+ | [E_aux (E_app (replicate_bits, [E_aux (E_lit (L_aux (L_bin "1", _)), _); len1]), _)]
+ when is_id env (Id "replicate_bits") replicate_bits ->
+ let start1 = mk_exp (E_lit (mk_lit (L_num Big_int.zero))) in
+ try_cast_to_typ (rewrap (E_app (mk_id "slice_mask", length_arg @ [start1; len1])))
+
+ | [E_aux (E_app (zeros, [len1]),_)]
+ | [E_aux (E_cast (_, E_aux (E_app (zeros, [len1]),_)), _)]
+ when is_zeros zeros ->
+ try_cast_to_typ (rewrap (E_app (id, length_arg)))
+
| _ -> E_app (id,args)
else if is_id env (Id "SignExtend") id || is_id env (Id "sign_extend") id then
- let is_slice = is_id env (Id "slice") in
let length_arg = List.filter (fun arg -> is_number (typ_of arg)) args in
match List.filter (fun arg -> not (is_number (typ_of arg))) args with
| [E_aux (E_app (slice1, [vector1; start1; length1]),_)]
@@ -2947,8 +3009,6 @@ let rec rewrite_app env typ (id,args) =
| _ -> E_app (id, args)
else if is_id env (Id "UInt") id || is_id env (Id "unsigned") id then
- let is_slice = is_id env (Id "slice") in
- let is_subrange = is_id env (Id "vector_subrange") in
match args with
| [E_aux (E_app (slice1, [vector1; start1; length1]),_)]
when is_slice slice1 && not (is_constant length1) ->
@@ -3032,7 +3092,7 @@ let check_for_spec env name =
(* These functions add cast functions across case splits, so that when a
bitvector size becomes known in sail, the generated Lem code contains a
function call to change mword 'n to (say) mword ty16, and vice versa. *)
-let make_bitvector_cast_fns cast_name env quant_kids src_typ target_typ =
+let make_bitvector_cast_fns cast_name top_env env quant_kids src_typ target_typ =
let genunk = Generated Unknown in
let fresh =
let counter = ref 0 in
@@ -3056,7 +3116,7 @@ let make_bitvector_cast_fns cast_name env quant_kids src_typ target_typ =
Typ_app (Id_aux (Id "vector",_) as t_id,
[A_aux (A_nexp size',l_size'); t_ord;
A_aux (A_typ (Typ_aux (Typ_id (Id_aux (Id "bit",_)),_)),_) as t_bit]) -> begin
- match simplify_size_nexp env quant_kids size, simplify_size_nexp env quant_kids size' with
+ match simplify_size_nexp env quant_kids size, simplify_size_nexp top_env quant_kids size' with
| Some size, Some size' when Nexp.compare size size' <> 0 ->
let var = fresh () in
let tar_typ' = Typ_aux (Typ_app (t_id, [A_aux (A_nexp size',l_size');t_ord;t_bit]),
@@ -3112,7 +3172,7 @@ let make_bitvector_cast_fns cast_name env quant_kids src_typ target_typ =
(* TODO: bound vars *)
let make_bitvector_env_casts env quant_kids (kid,i) exp =
- let mk_cast var typ exp = (fst (make_bitvector_cast_fns "bitvector_cast_in" env quant_kids typ (subst_kids_typ (KBindings.singleton kid (nconstant i)) typ))) var exp in
+ let mk_cast var typ exp = (fst (make_bitvector_cast_fns "bitvector_cast_in" env env quant_kids typ (subst_kids_typ (KBindings.singleton kid (nconstant i)) typ))) var exp in
let locals = Env.get_locals env in
Bindings.fold (fun var (mut,typ) exp ->
if mut = Immutable then mk_cast var typ exp else exp) locals exp
@@ -3157,7 +3217,7 @@ let make_bitvector_cast_exp cast_name cast_env quant_kids typ target_typ exp =
let tgt_arg_typ = infer_arg_typ (env_of exp) f l target_typ in
E_aux (E_app (f,[aux arg (src_arg_typ, tgt_arg_typ)]),(l,ann))
| _ ->
- (snd (make_bitvector_cast_fns cast_name cast_env quant_kids typ target_typ)) exp
+ (snd (make_bitvector_cast_fns cast_name cast_env (env_of exp) quant_kids typ target_typ)) exp
in
aux exp (typ, target_typ)
@@ -3287,9 +3347,10 @@ let add_bitvector_casts (Defs defs) =
{ id_exp_alg with
e_aux = rewrite_aux } exp
in
- let rewrite_funcl (FCL_aux (FCL_Funcl (id,pexp),fcl_ann)) =
+ let rewrite_funcl (FCL_aux (FCL_Funcl (id,pexp),((l,_) as fcl_ann))) =
let fcl_env = env_of_annot fcl_ann in
let (tq,typ) = Env.get_val_spec_orig id fcl_env in
+ let fun_env = add_typquant l tq fcl_env in
let quant_kids = List.map kopt_kid (List.filter is_int_kopt (quant_kopts tq)) in
let ret_typ =
match typ with
@@ -3300,11 +3361,10 @@ let add_bitvector_casts (Defs defs) =
" is not a function type"))
in
let pat,guard,body,annot = destruct_pexp pexp in
- let body_env = env_of body in
- let body = rewrite_body id quant_kids body_env ret_typ body in
+ let body = rewrite_body id quant_kids fun_env ret_typ body in
(* Also add a cast around the entire function clause body, if necessary *)
let body =
- make_bitvector_cast_exp "bitvector_cast_out" fcl_env quant_kids (fill_in_type body_env (typ_of body)) ret_typ body
+ make_bitvector_cast_exp "bitvector_cast_out" fun_env quant_kids (fill_in_type (env_of body) (typ_of body)) ret_typ body
in
let pexp = construct_pexp (pat,guard,body,annot) in
FCL_aux (FCL_Funcl (id,pexp),fcl_ann)
@@ -3470,7 +3530,7 @@ let rewrite_toplevel_nexps (Defs defs) =
in
(* Changing types in the body confuses simple sizeof rewriting, so turn it
off for now *)
- (* let rewrite_typ_in_body env nexp_map typ =
+ let rewrite_typ_in_body env nexp_map typ =
let rec aux (Typ_aux (t,l) as typ_full) =
match t with
| Typ_tup typs -> Typ_aux (Typ_tup (List.map aux typs),l)
@@ -3515,10 +3575,17 @@ let rewrite_toplevel_nexps (Defs defs) =
| P_typ (typ,p') -> P_aux (P_typ (rewrite_typ_in_body (env_of_annot ann) nexp_map typ,p'),ann)
| _ -> P_aux (p,ann)
in
+ let rewrite_one_lexp nexp_map (lexp, ann) =
+ match lexp with
+ | LEXP_cast (typ, id) ->
+ LEXP_aux (LEXP_cast (rewrite_typ_in_body (env_of_annot ann) nexp_map typ, id), ann)
+ | _ -> LEXP_aux (lexp, ann)
+ in
let rewrite_body nexp_map pexp =
let open Rewriter in
fold_pexp { id_exp_alg with
e_aux = rewrite_one_exp nexp_map;
+ lEXP_aux = rewrite_one_lexp nexp_map;
pat_alg = { id_pat_alg with p_aux = rewrite_one_pat nexp_map }
} pexp
in
@@ -3526,25 +3593,29 @@ let rewrite_toplevel_nexps (Defs defs) =
match Bindings.find id spec_map with
| nexp_map -> FCL_aux (FCL_Funcl (id,rewrite_body nexp_map pexp),ann)
| exception Not_found -> funcl
- in *)
+ in
let rewrite_def spec_map def =
match def with
| DEF_spec vs -> (match rewrite_valspec vs with
| None -> spec_map, def
| Some (id, nexp_map, vs) -> Bindings.add id nexp_map spec_map, DEF_spec vs)
- (* | DEF_fundef (FD_aux (FD_function (recopt,_,eff,funcls),ann)) ->
+ | DEF_fundef (FD_aux (FD_function (recopt,_,eff,funcls),ann)) ->
(* Type annotations on function definitions will have been turned into
valspecs by type checking, so it should be safe to drop them rather
than updating them. *)
let tann = Typ_annot_opt_aux (Typ_annot_opt_none,Generated Unknown) in
spec_map,
- DEF_fundef (FD_aux (FD_function (recopt,tann,eff,List.map (rewrite_funcl spec_map) funcls),ann)) *)
+ DEF_fundef (FD_aux (FD_function (recopt,tann,eff,List.map (rewrite_funcl spec_map) funcls),ann))
| _ -> spec_map, def
in
let _, defs = List.fold_left (fun (spec_map,t) def ->
let spec_map, def = rewrite_def spec_map def in
(spec_map, def::t)) (Bindings.empty, []) defs
- in Defs (List.rev defs)
+ in
+ (* Allow use of div and mod in nexp rewriting during later typechecking passes
+ to help prove equivalences such as (8 * 'n) = 'p8_times_n# *)
+ Type_check.opt_smt_div := true;
+ Defs (List.rev defs)
type options = {
auto : bool;
diff --git a/src/pretty_print_lem.ml b/src/pretty_print_lem.ml
index 759c7637..933925da 100644
--- a/src/pretty_print_lem.ml
+++ b/src/pretty_print_lem.ml
@@ -927,7 +927,9 @@ let doc_exp_lem, doc_let_lem =
let b = match e1 with E_aux (E_if _,_) -> true | _ -> false in
let middle =
match fst (untyp_pat pat) with
- | P_aux (P_wild,_) | P_aux (P_typ (_, P_aux (P_wild, _)), _) -> string ">>"
+ | P_aux (P_wild,_) | P_aux (P_typ (_, P_aux (P_wild, _)), _)
+ when is_unit_typ (typ_of_pat pat) ->
+ string ">>"
| P_aux (P_tup _, _)
when not (IdSet.mem (mk_id "varstup") (find_e_ids e2)) ->
(* Work around indentation issues in Lem when translating
diff --git a/src/rewrites.ml b/src/rewrites.ml
index 8bfbc351..502b910c 100644
--- a/src/rewrites.ml
+++ b/src/rewrites.ml
@@ -2460,14 +2460,20 @@ let rewrite_defs_letbind_effects env =
k (rewrap (E_throw exp')))
| E_internal_plet _ -> failwith "E_internal_plet should not be here yet" in
- let rewrite_fun _ (FD_aux (FD_function(recopt,tannotopt,effectopt,funcls),fdannot)) =
+ let rewrite_fun _ (FD_aux (FD_function(recopt,tannotopt,effectopt,funcls),fdannot) as fd) =
(* let propagate_funcl_effect (FCL_aux (FCL_Funcl(id, pexp), (l, a))) =
let pexp, eff = propagate_pexp_effect pexp in
FCL_aux (FCL_Funcl(id, pexp), (l, add_effect_annot a eff))
in
let funcls = List.map propagate_funcl_effect funcls in *)
+ let effectful_vs =
+ match Env.get_val_spec (id_of_fundef fd) env with
+ | _, Typ_aux (Typ_fn (_, _, effs), _) -> effectful_effs effs
+ | _, _ -> false
+ | exception Type_error _ -> false
+ in
let effectful_funcl (FCL_aux (FCL_Funcl(_, pexp), _)) = effectful_pexp pexp in
- let newreturn = List.exists effectful_funcl funcls in
+ let newreturn = effectful_vs || List.exists effectful_funcl funcls in
let rewrite_funcl (FCL_aux (FCL_Funcl(id,pexp),annot)) =
let _ = reset_fresh_name_counter () in
FCL_aux (FCL_Funcl (id,n_pexp newreturn pexp (fun x -> x)),annot)
@@ -4664,13 +4670,13 @@ let rewrite_defs_lem = [
("mapping_builtins", rewrite_defs_mapping_patterns);
("mono_rewrites", mono_rewrites);
("recheck_defs", if_mono recheck_defs);
+ ("rewrite_undefined", rewrite_undefined_if_gen false);
("rewrite_toplevel_nexps", if_mono rewrite_toplevel_nexps);
("monomorphise", if_mono monomorphise);
("recheck_defs", if_mwords recheck_defs);
("add_bitvector_casts", if_mwords (fun _ -> Monomorphise.add_bitvector_casts));
("rewrite_atoms_to_singletons", if_mono (fun _ -> Monomorphise.rewrite_atoms_to_singletons));
("recheck_defs", if_mwords recheck_defs);
- ("rewrite_undefined", rewrite_undefined_if_gen false);
("rewrite_defs_vector_string_pats_to_bit_list", rewrite_defs_vector_string_pats_to_bit_list);
("remove_not_pats", rewrite_defs_not_pats);
("remove_impossible_int_cases", Constant_propagation.remove_impossible_int_cases);
diff --git a/src/sail_lib.ml b/src/sail_lib.ml
index d1a21b73..4bb004bf 100644
--- a/src/sail_lib.ml
+++ b/src/sail_lib.ml
@@ -695,6 +695,7 @@ let string_of_zbit = function
| B1 -> "1"
let string_of_znat n = Big_int.to_string n
let string_of_zint n = Big_int.to_string n
+let string_of_zimplicit n = Big_int.to_string n
let string_of_zunit () = "()"
let string_of_zbool = function
| true -> "true"
diff --git a/src/type_check.ml b/src/type_check.ml
index 31a9370f..5aafe601 100644
--- a/src/type_check.ml
+++ b/src/type_check.ml
@@ -76,6 +76,9 @@ let opt_expand_valspec = ref true
the SMT solver to use non-linear arithmetic. *)
let opt_smt_linearize = ref false
+(* Allow use of div and mod when rewriting nexps *)
+let opt_smt_div = ref false
+
let depth = ref 0
let rec indent n = match n with
@@ -1775,9 +1778,9 @@ and unify_nexp l env goals (Nexp_aux (nexp_aux1, _) as nexp1) (Nexp_aux (nexp_au
mod(m, C) = 0 && C != 0 --> (C * n = m <--> n = m / C)
- to help us unify multiplications and divisions.
+ to help us unify multiplications and divisions. *)
let valid n c = prove __POS__ env (nc_eq (napp (mk_id "mod") [n; c]) (nint 0)) && prove __POS__ env (nc_neq c (nint 0)) in
- if KidSet.is_empty (nexp_frees n1b) && valid nexp2 n1b then
+ (*if KidSet.is_empty (nexp_frees n1b) && valid nexp2 n1b then
unify_nexp l env goals n1a (napp (mk_id "div") [nexp2; n1b])
else if KidSet.is_empty (nexp_frees n1a) && valid nexp2 n1a then
unify_nexp l env goals n1b (napp (mk_id "div") [nexp2; n1a]) *)
@@ -1793,6 +1796,8 @@ and unify_nexp l env goals (Nexp_aux (nexp_aux1, _) as nexp1) (Nexp_aux (nexp_au
unify_nexp l env goals n1b (nconstant (Big_int.div c2 c1))
| _ -> unify_error l ("Cannot unify Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2)
end
+ | Nexp_var kid when (not (KidSet.mem kid goals)) && valid nexp2 n1a && !opt_smt_div ->
+ unify_nexp l env goals n1b (napp (mk_id "div") [nexp2; n1a])
| _ -> unify_error l ("Cannot unify Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2)
end
else if KidSet.is_empty (nexp_frees n1b) then
@@ -1800,6 +1805,8 @@ and unify_nexp l env goals (Nexp_aux (nexp_aux1, _) as nexp1) (Nexp_aux (nexp_au
match nexp_aux2 with
| Nexp_times (n2a, n2b) when prove __POS__ env (NC_aux (NC_equal (n1b, n2b), Parse_ast.Unknown)) ->
unify_nexp l env goals n1a n2a
+ | Nexp_var kid when (not (KidSet.mem kid goals)) && valid nexp2 n1b && !opt_smt_div ->
+ unify_nexp l env goals n1a (napp (mk_id "div") [nexp2; n1b])
| _ -> unify_error l ("Cannot unify Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2)
end
else unify_error l ("Cannot unify Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2)
diff --git a/src/type_check.mli b/src/type_check.mli
index 5333d02d..737e714e 100644
--- a/src/type_check.mli
+++ b/src/type_check.mli
@@ -77,6 +77,9 @@ val opt_expand_valspec : bool ref
the SMT solver to use non-linear arithmetic. *)
val opt_smt_linearize : bool ref
+(** Allow use of div and mod when rewriting nexps *)
+val opt_smt_div : bool ref
+
(** {2 Type errors} *)
type type_error =