diff options
| author | Kathy Gray | 2015-08-14 10:30:30 +0100 |
|---|---|---|
| committer | Kathy Gray | 2015-08-14 10:30:30 +0100 |
| commit | d698593f14334811f3230d385737cc2bc96b5a63 (patch) | |
| tree | e78280127bc9e1736c3625168063ef3f3a5d8bd4 /src | |
| parent | d4d2e262f96a8eef543c017c8df08c25f2715118 (diff) | |
Steps towards making constraint solver smarter
Diffstat (limited to 'src')
| -rw-r--r-- | src/pretty_print.ml | 6 | ||||
| -rw-r--r-- | src/type_check.ml | 22 | ||||
| -rw-r--r-- | src/type_internal.ml | 313 | ||||
| -rw-r--r-- | src/type_internal.mli | 3 |
4 files changed, 242 insertions, 102 deletions
diff --git a/src/pretty_print.ml b/src/pretty_print.ml index 266e8141..182c1929 100644 --- a/src/pretty_print.ml +++ b/src/pretty_print.ml @@ -65,14 +65,14 @@ let pp_format_var (Kid_aux(Var v,_)) = v let rec pp_format_l_lem = function | Parse_ast.Unknown -> "Unknown" -(* | _ -> "Unknown"*) - | Parse_ast.Int(s,None) -> "(Int \"" ^ s ^ "\" Nothing)" + | _ -> "Unknown" +(* | Parse_ast.Int(s,None) -> "(Int \"" ^ s ^ "\" Nothing)" | Parse_ast.Int(s,(Some l)) -> "(Int \"" ^ s ^ "\" (Just " ^ (pp_format_l_lem l) ^ "))" | Parse_ast.Range(p1,p2) -> "(Range \"" ^ p1.Lexing.pos_fname ^ "\" " ^ (string_of_int p1.Lexing.pos_lnum) ^ " " ^ (string_of_int (p1.Lexing.pos_cnum - p1.Lexing.pos_bol)) ^ " " ^ (string_of_int p2.Lexing.pos_lnum) ^ " " ^ - (string_of_int (p2.Lexing.pos_cnum - p2.Lexing.pos_bol)) ^ ")" + (string_of_int (p2.Lexing.pos_cnum - p2.Lexing.pos_bol)) ^ ")"*) let pp_lem_l ppf l = base ppf (pp_format_l_lem l) diff --git a/src/type_check.ml b/src/type_check.ml index 8ee47e10..6b3ebbb1 100644 --- a/src/type_check.ml +++ b/src/type_check.ml @@ -818,8 +818,8 @@ let rec check_exp envs (imp_param:nexp option) (expect_t:t) (E_aux(e,(l,annot)): (*TOTHINK Possibly I should first consistency check else and then, with Guarantee, then check against expect_t with Require*) let then_t',then_c' = type_consistent (Expr l) d_env Require true then_t expect_t in let else_t',else_c' = type_consistent (Expr l) d_env Require true else_t expect_t in - let t_cs = CondCons((Expr l),c1,then_c@then_c') in - let e_cs = CondCons((Expr l),[],else_c@else_c') in + let t_cs = CondCons((Expr l),Positive,c1,then_c@then_c') in + let e_cs = CondCons((Expr l),Negative,[],else_c@else_c') in (E_aux(E_if(cond',then',else'),(l,simple_annot expect_t)), expect_t,Envmap.intersect_merge (tannot_merge (Expr l) d_env true) then_env else_env,[t_cs;e_cs], merge_bounds then_bs else_bs, (*TODO Should be an intersecting merge*) @@ -827,8 +827,8 @@ let rec check_exp envs (imp_param:nexp option) (expect_t:t) (E_aux(e,(l,annot)): | _ -> let then',then_t,then_env,then_c,then_bs,then_ef = check_exp envs imp_param expect_t then_ in let else',else_t,else_env,else_c,else_bs,else_ef = check_exp envs imp_param expect_t else_ in - let t_cs = CondCons((Expr l),c1,then_c) in - let e_cs = CondCons((Expr l),[],else_c) in + let t_cs = CondCons((Expr l),Positive,c1,then_c) in + let e_cs = CondCons((Expr l),Negative,[],else_c) in (E_aux(E_if(cond',then',else'),(l,simple_annot expect_t)), expect_t,Envmap.intersect_merge (tannot_merge (Expr l) d_env true) then_env else_env,[t_cs;e_cs], merge_bounds then_bs else_bs, @@ -1225,7 +1225,8 @@ let rec check_exp envs (imp_param:nexp option) (expect_t:t) (E_aux(e,(l,annot)): | Tapp("register",[TA_typ t]) -> t | _ -> t' in (*let _ = Printf.eprintf "Type of pattern after register check %s\n" (t_to_string t') in*) - let (pexps',t,cs',ef') = check_cases envs imp_param t' expect_t pexps in + let (pexps',t,cs',ef') = + check_cases envs imp_param t' expect_t (if (List.length pexps) = 1 then Solo else Switch) pexps in (E_aux(E_case(e',pexps'),(l,simple_annot t)),t,t_env,cs@cs',nob,union_effects ef ef') | E_let(lbind,body) -> let (lb',t_env',cs,b_env',ef) = (check_lbind envs imp_param false Emp_local lbind) in @@ -1260,7 +1261,7 @@ and check_block envs imp_param expect_t exps:((tannot exp) list * tannot * nexp_ merge_bounds b_env' b_env, tp_env)) imp_param expect_t exps in ((e'::exps'),annot',sc@sc',t,union_effects ef ef') -and check_cases envs imp_param check_t expect_t pexps : ((tannot pexp) list * typ * nexp_range list * effect) = +and check_cases envs imp_param check_t expect_t kind pexps : ((tannot pexp) list * typ * nexp_range list * effect) = let (Env(d_env,t_env,b_env,tp_env)) = envs in match pexps with | [] -> raise (Reporting_basic.err_unreachable Parse_ast.Unknown "switch with no cases found") @@ -1270,7 +1271,7 @@ and check_cases envs imp_param check_t expect_t pexps : ((tannot pexp) list * ty check_exp (Env(d_env, Envmap.union_merge (tannot_merge (Expr l) d_env true) t_env env, merge_bounds b_env bounds, tp_env)) imp_param expect_t exp in - let cs = [CondCons(Expr l, cs_p, cs_e)] in + let cs = [CondCons(Expr l,kind, cs_p, cs_e)] in [Pat_aux(Pat_exp(pat',e),(l,cons_ef_annot t cs ef))],t,cs,ef | ((Pat_aux(Pat_exp(pat,exp),(l,annot)))::pexps) -> let pat',env,cs_p,bounds,u = check_pattern envs Emp_local check_t pat in @@ -1278,8 +1279,8 @@ and check_cases envs imp_param check_t expect_t pexps : ((tannot pexp) list * ty check_exp (Env(d_env, Envmap.union_merge (tannot_merge (Expr l) d_env true) t_env env, merge_bounds b_env bounds, tp_env)) imp_param expect_t exp in - let cs = CondCons(Expr l,cs_p,cs_e) in - let (pes,tl,csl,efl) = check_cases envs imp_param check_t expect_t pexps in + let cs = CondCons(Expr l,kind,cs_p,cs_e) in + let (pes,tl,csl,efl) = check_cases envs imp_param check_t expect_t kind pexps in ((Pat_aux(Pat_exp(pat',e),(l,cons_ef_annot t [cs] ef)))::pes,tl,cs::csl,union_effects efl ef) and check_lexp envs imp_param is_top (LEXP_aux(lexp,(l,annot))) @@ -1727,6 +1728,7 @@ let check_fundef envs (FD_aux(FD_function(recopt,tannotopt,effectopt,funcls),(l, let p_t = new_t () in let ef = new_e () in t,p_t,Base((ids,{t=Tfn(p_t,t,IP_none,ef)}),Emp_global,constraints,ef,nob),t_param_env in + let cond_kind = if (List.length funcls) = 1 then Solo else Switch in let check t_env tp_env imp_param = List.split (List.map (fun (FCL_aux((FCL_Funcl(id,pat,exp)),(l,_))) -> @@ -1738,7 +1740,7 @@ let check_fundef envs (FD_aux(FD_function(recopt,tannotopt,effectopt,funcls),(l, merge_bounds b_env b_env',tp_env)) imp_param ret_t exp in (*let _ = Printf.eprintf "checked function %s : %s -> %s\n" (id_to_string id) (t_to_string param_t) (t_to_string ret_t) in let _ = Printf.eprintf "constraints were pattern: %s\n expression: %s\n" (constraints_to_string cs_p) (constraints_to_string cs_e) in*) - let cs = [CondCons(Fun l,cs_p,cs_e)] in + let cs = [CondCons(Fun l,cond_kind,cs_p,cs_e)] in (FCL_aux((FCL_Funcl(id,pat',exp')),(l,(Base(([],ret_t),Emp_global,cs,ef,nob)))),(cs,ef))) funcls) in let update_pattern var (FCL_aux ((FCL_Funcl(id,(P_aux(pat,t)),exp)),annot)) = let pat' = match pat with diff --git a/src/type_internal.ml b/src/type_internal.ml index debd676e..7e4f2a2a 100644 --- a/src/type_internal.ml +++ b/src/type_internal.ml @@ -89,6 +89,7 @@ type constraint_origin = | Specc of Parse_ast.l type range_enforcement = Require | Guarantee +type cond_kind = Positive | Negative | Solo | Switch (* Constraints for nexps, plus the location which added the constraint *) type nexp_range = @@ -98,7 +99,7 @@ type nexp_range = | In of constraint_origin * string * int list | InS of constraint_origin * nexp * int list (* This holds the given value for string after a substitution *) | Predicate of constraint_origin * nexp_range (* This will treat the inner constraint as holding in positive condcons positions*) - | CondCons of constraint_origin * nexp_range list * nexp_range list + | CondCons of constraint_origin * cond_kind * nexp_range list * nexp_range list | BranchCons of constraint_origin * nexp_range list type variable_range = @@ -134,6 +135,13 @@ type def_envs = { default_o : order; } +type triple = Yes | No | Maybe +let triple_negate = function + | Yes -> No + | No -> Yes + | Maybe -> Maybe + + type exp = tannot Ast.exp (*Nexpression Makers (as built so often)*) @@ -249,6 +257,12 @@ let enforce_to_string = function | Require -> "require" | Guarantee -> "guarantee" +let cond_kind_to_string = function + | Positive -> "positive" + | Negative -> "negative" + | Solo -> "solo" + | Switch -> "switch" + let rec constraint_to_string = function | LtEq (co,enforce,nexp1,nexp2) -> "LtEq(" ^ co_to_string co ^ ", " ^ enforce_to_string enforce ^ ", " ^ n_to_string nexp1 ^ ", " ^ n_to_string nexp2 ^ ")" @@ -259,8 +273,9 @@ let rec constraint_to_string = function | In(co,var,ints) -> "In of " ^ var | InS(co,n,ints) -> "InS of " ^ n_to_string n | Predicate(co,cs) -> "Pred(" ^ co_to_string co ^ ", " ^ constraint_to_string cs ^ ")" - | CondCons(co,pats,exps) -> - "CondCons(" ^ co_to_string co ^ ", [" ^ constraints_to_string pats ^ "], [" ^ constraints_to_string exps ^ "])" + | CondCons(co,kind,pats,exps) -> + "CondCons(" ^ co_to_string co ^ ", " ^ cond_kind_to_string kind ^ + ", [" ^ constraints_to_string pats ^ "], [" ^ constraints_to_string exps ^ "])" | BranchCons(co,consts) -> "BranchCons(" ^ co_to_string co ^ ", [" ^ constraints_to_string consts ^ "])" and constraints_to_string l = string_of_list "; " constraint_to_string l @@ -454,7 +469,29 @@ let negate n = match n.nexp with | Nconst i -> mk_c (mult_int_big_int (-1) i) | _ -> mk_mult (mk_c_int (-1)) n -let rec normalize_nexp n = +let odd n = (n mod 2) = 1 + +(*Expects a normalized nexp*) +let rec nexp_negative n = + match n.nexp with + | Nconst i -> if lt_big_int i zero then Yes else No + | Nneg_inf -> Yes + | Npos_inf | N2n _ | Nvar _ | Nuvar _ -> No + | Nmult(n1,n2) -> (match nexp_negative n1, nexp_negative n2 with + | Yes,Yes | No, No -> No + | No, Yes | Yes, No -> Yes + | Maybe,_ | _, Maybe -> Maybe) + | Nadd(n1,n2) -> (match nexp_negative n1, nexp_negative n2 with + | Yes,Yes -> Yes + | No, No -> No + | _ -> Maybe) + | Npow(n1,i) -> + (match nexp_negative n1 with + | Yes -> if odd i then Yes else No + | No -> No + | Maybe -> if odd i then Maybe else No) + +let rec normalize_n_rec recur_ok n = (*let _ = Printf.eprintf "Working on normalizing %s\n" (n_to_string n) in *) match n.nexp with | Nconst _ | Nvar _ | Nuvar _ | Npos_inf | Nneg_inf | Ninexact -> n @@ -466,74 +503,92 @@ let rec normalize_nexp n = | Nneg n -> n,true,false | _ -> n,true,true) in if to_recur - then (let n' = normalize_nexp n' in + then (let n' = normalize_n_rec true n' in if add_neg then negate n' else n') else n' | Npow(n,i) -> - let n' = normalize_nexp n in + let n' = normalize_n_rec true n in (match n'.nexp with | Nconst n -> mk_c (pow_i i (int_of_big_int n)) | _ -> mk_pow n' i) | N2n(n', Some i) -> n (*Because there is a value for Some, we know this is normalized and n' is constant*) | N2n(n, None) -> - let n' = normalize_nexp n in + let n' = normalize_n_rec true n in (match n'.nexp with | Nconst i -> mk_2nc n' (two_pow (int_of_big_int i)) | _ -> mk_2n n') | Nadd(n1,n2) -> - let n1',n2' = normalize_nexp n1, normalize_nexp n2 in - (match n1'.nexp,n2'.nexp with - | Nneg_inf, Npos_inf | Npos_inf, Nneg_inf -> {nexp = Ninexact } - | Npos_inf, _ | _, Npos_inf -> { nexp = Npos_inf } - | Nneg_inf, _ | _, Nneg_inf -> { nexp = Nneg_inf } - | Nconst i1, Nconst i2 | Nconst i1, N2n(_,Some i2) | N2n(_,Some i2), Nconst i1 | N2n(_,Some i1),N2n(_,Some i2) + let n1',n2' = normalize_n_rec true n1, normalize_n_rec true n2 in + (match n1'.nexp,n2'.nexp, recur_ok with + | Nneg_inf, Npos_inf,_ | Npos_inf, Nneg_inf,_ -> {nexp = Ninexact } + | Npos_inf, _,_ | _, Npos_inf, _ -> { nexp = Npos_inf } + | Nneg_inf, _,_ | _, Nneg_inf, _ -> { nexp = Nneg_inf } + | Nconst i1, Nconst i2,_ | Nconst i1, N2n(_,Some i2),_ + | N2n(_,Some i2), Nconst i1,_ | N2n(_,Some i1),N2n(_,Some i2),_ -> mk_c (add_big_int i1 i2) - | Nadd(n11,n12), Nconst i -> - if (eq_big_int i zero) then n1' else mk_add n11 (normalize_nexp (mk_add n12 n2')) - | Nconst i, Nadd(n21,n22) -> - if (eq_big_int i zero) then n2' else mk_add n21 (normalize_nexp (mk_add n22 n1')) - | Nconst i, _ -> if (eq_big_int i zero) then n2' else mk_add n2' n1' - | _, Nconst i -> if (eq_big_int i zero) then n1' else mk_add n1' n2' - | Nvar _, Nuvar _ | Nvar _, N2n _ | Nuvar _, Npow _ | Nuvar _, N2n _ -> mk_add n2' n1' - | Nadd(n11,n12), Nadd(n21,n22) -> + | Nadd(n11,n12), Nconst i, true -> + if (eq_big_int i zero) then n1' + else normalize_n_rec false (mk_add n11 (normalize_n_rec false (mk_add n12 n2'))) + | Nadd(n11,n12), Nconst i, false -> + if (eq_big_int i zero) then n1' + else mk_add n11 (normalize_n_rec false (mk_add n12 n2')) + | Nconst i, Nadd(n21,n22), true -> + if (eq_big_int i zero) then n2' + else normalize_n_rec false (mk_add n21 (normalize_n_rec false (mk_add n22 n1'))) + | Nconst i, Nadd(n21,n22), false -> + if (eq_big_int i zero) then n2' + else mk_add n21 (normalize_n_rec false (mk_add n22 n1')) + | Nconst i, _,_ -> if (eq_big_int i zero) then n2' else mk_add n2' n1' + | _, Nconst i,_ -> if (eq_big_int i zero) then n1' else mk_add n1' n2' + | Nvar _, Nuvar _,_ | Nvar _, N2n _,_ | Nuvar _, Npow _,_ | Nuvar _, N2n _,_ -> mk_add n2' n1' + | Nadd(n11,n12), Nadd(n21,n22), true -> (match compare_nexps n11 n21 with - | -1 -> mk_add n11 (normalize_nexp (mk_add n12 n2')) + | -1 -> normalize_n_rec false (mk_add n11 (normalize_n_rec false (mk_add n12 n2'))) | 0 -> (match compare_nexps n12 n22 with - | -1 -> normalize_nexp (mk_add (mk_mult n_two n11) (mk_add n22 n12)) - | 0 -> normalize_nexp (mk_add (mk_mult n_two n11) (mk_mult n_two n12)) - | _ -> normalize_nexp (mk_add (mk_mult n_two n11) (mk_add n12 n22))) - | _ -> normalize_nexp (mk_add n21 (mk_add n22 n1'))) - | N2n(n11,_), N2n(n21,_) -> + | -1 -> normalize_n_rec true (mk_add (mk_mult n_two n11) (mk_add n22 n12)) + | 0 -> normalize_n_rec true (mk_add (mk_mult n_two n11) (mk_mult n_two n12)) + | _ -> normalize_n_rec true (mk_add (mk_mult n_two n11) (mk_add n12 n22))) + | _ -> normalize_n_rec false (mk_add n21 (normalize_n_rec false (mk_add n22 n1')))) + | Nadd(n11,n12), Nadd(n21,n22), false -> + (match compare_nexps n11 n21 with + | -1 -> mk_add n11 (normalize_n_rec false (mk_add n12 n2')) + | 0 -> + (match compare_nexps n12 n22 with + | -1 -> normalize_n_rec true (mk_add (mk_mult n_two n11) (mk_add n22 n12)) + | 0 -> normalize_n_rec true (mk_add (mk_mult n_two n11) (mk_mult n_two n12)) + | _ -> normalize_n_rec true (mk_add (mk_mult n_two n11) (mk_add n12 n22))) + | _ -> mk_add n21 (normalize_n_rec false (mk_add n22 n1'))) + | N2n(n11,_), N2n(n21,_),_ -> (match compare_nexps n11 n21 with | -1 -> mk_add n2' n1' - | 0 -> mk_2n (normalize_nexp (mk_add n11 n_one)) + | 0 -> mk_2n (normalize_n_rec true (mk_add n11 n_one)) | _ -> mk_add n1' n2') - | Npow(n11,i1), Npow (n21,i2) -> + | Npow(n11,i1), Npow (n21,i2),_ -> (match compare_nexps n11 n21, compare i1 i2 with | -1,-1 | 0,-1 -> mk_add n2' n1' | 0,0 -> mk_mult n_two n1' | _ -> mk_add n1' n2') - | N2n(n11,Some i),Nadd(n21,n22) -> - normalize_nexp (mk_add n21 (mk_add n22 (mk_c i))) - | Nadd(n11,n12), N2n(n21,Some i) -> - normalize_nexp (mk_add n11 (mk_add n12 (mk_c i))) - | N2n(n11,None),Nadd(n21,n22) -> + | N2n(n11,Some i),Nadd(n21,n22),_ -> + normalize_n_rec true (mk_add n21 (mk_add n22 (mk_c i))) + | Nadd(n11,n12), N2n(n21,Some i),_ -> + normalize_n_rec true (mk_add n11 (mk_add n12 (mk_c i))) + | N2n(n11,None),Nadd(n21,n22),_ -> (match n21.nexp with | N2n(n211,_) -> (match compare_nexps n11 n211 with | -1 -> mk_add n1' n2' - | 0 -> mk_add (mk_2n (normalize_nexp (mk_add n11 n_one))) n22 - | _ -> mk_add n21 (normalize_nexp (mk_add n11 n22))) + | 0 -> mk_add (mk_2n (normalize_n_rec true (mk_add n11 n_one))) n22 + | _ -> mk_add n21 (normalize_n_rec true (mk_add n11 n22))) | _ -> mk_add n1' n2') - | Nadd(n11,n12),N2n(n21,None) -> + | Nadd(n11,n12),N2n(n21,None),_ -> (match n11.nexp with | N2n(n111,_) -> (match compare_nexps n111 n21 with - | -1 -> mk_add n11 (normalize_nexp (mk_add n2' n12)) - | 0 -> mk_add (mk_2n (normalize_nexp (mk_add n111 n_one))) n12 + | -1 -> mk_add n11 (normalize_n_rec true (mk_add n2' n12)) + | 0 -> mk_add (mk_2n (normalize_n_rec true (mk_add n111 n_one))) n12 | _ -> mk_add n2' n1') | _ -> mk_add n2' n1') | _ -> @@ -549,22 +604,22 @@ let rec normalize_nexp n = | Nadd(n11',n12'), _ -> (match compare_nexps n11' n2' with | -1 -> mk_add n2' n1' - | 1 -> mk_add n11' (normalize_nexp (mk_add n12' n2')) + | 1 -> mk_add n11' (normalize_n_rec true (mk_add n12' n2')) | _ -> let _ = Printf.eprintf "Neither term has var but are the same? %s %s\n" (n_to_string n1') (n_to_string n2') in assert false) | (_, Nadd(n21',n22')) -> (match compare_nexps n1' n21' with - | -1 -> mk_add n21' (normalize_nexp (mk_add n1' n22')) + | -1 -> mk_add n21' (normalize_n_rec true (mk_add n1' n22')) | 1 -> mk_add n1' n2' | _ -> let _ = Printf.eprintf "pattern didn't match unexpextedly here %s %s\n" (n_to_string n1') (n_to_string n2') in assert false) | _ -> (match compare_nexps n1' n2' with | -1 -> mk_add n2' n1' - | 0 -> normalize_nexp (mk_mult n_two n1') + | 0 -> normalize_n_rec true (mk_mult n_two n1') | _ -> mk_add n1' n2')))) | Nsub(n1,n2) -> - let n1',n2' = normalize_nexp n1, normalize_nexp n2 in + let n1',n2' = normalize_n_rec true n1, normalize_n_rec true n2 in (*let _ = Printf.eprintf "Normalizing subtraction of %s - %s \n" (n_to_string n1') (n_to_string n2') in*) (match n1'.nexp,n2'.nexp with | Nneg_inf, Npos_inf | Npos_inf, Nneg_inf -> {nexp = Ninexact } @@ -575,41 +630,48 @@ let rec normalize_nexp n = mk_c (sub_big_int i1 i2) | Nconst i, _ -> if (eq_big_int i zero) - then normalize_nexp (negate n2') - else normalize_nexp (mk_add (negate n2') n1') + then normalize_n_rec true (negate n2') + else normalize_n_rec true (mk_add (negate n2') n1') | _, Nconst i -> if (eq_big_int i zero) then n1' - else normalize_nexp (mk_add n1' (mk_c (mult_int_big_int (-1) i))) + else normalize_n_rec true (mk_add n1' (mk_c (mult_int_big_int (-1) i))) | _,_ -> (match compare_nexps n1 n2 with | 0 -> n_zero | -1 -> mk_add (negate n2') n1' | _ -> mk_add n1' (negate n2'))) | Nmult(n1,n2) -> - let n1',n2' = normalize_nexp n1, normalize_nexp n2 in + let n1',n2' = normalize_n_rec true n1, normalize_n_rec true n2 in (match n1'.nexp,n2'.nexp with | Nneg_inf,Nneg_inf -> {nexp = Npos_inf} | Npos_inf, Nconst i | Nconst i, Npos_inf -> if eq_big_int i zero then n_zero else {nexp = Npos_inf} | Nneg_inf, Nconst i | Nconst i, Nneg_inf -> if eq_big_int i zero then n_zero else {nexp = Nneg_inf} - | Nneg_inf, _ | _, Nneg_inf -> {nexp = Nneg_inf} (*TODO write a is_negative predicate*) - | Npos_inf, _ | _, Npos_inf -> {nexp = Npos_inf} + | Nneg_inf, _ | _, Nneg_inf -> + (match nexp_negative n1, nexp_negative n2 with + | Yes, Yes -> {nexp = Npos_inf} + | _ -> {nexp = Nneg_inf}) + | Npos_inf, _ | _, Npos_inf -> + (match nexp_negative n1, nexp_negative n2 with + | Yes, Yes -> assert false (*One of them must be Npos_inf, so nexp_negative horribly broken*) + | No, Yes | Yes, No -> {nexp = Nneg_inf} + | _ -> {nexp = Npos_inf}) | Ninexact, _ | _, Ninexact -> {nexp = Ninexact} | Nconst i1, Nconst i2 -> mk_c (mult_big_int i1 i2) | Nconst i1, N2n(n,Some i2) | N2n(n,Some i2),Nconst i1 -> if eq_big_int i1 two - then mk_2nc (normalize_nexp (mk_add n n_one)) (mult_big_int i1 i2) + then mk_2nc (normalize_n_rec true (mk_add n n_one)) (mult_big_int i1 i2) else mk_c (mult_big_int i1 i2) | Nconst i1, N2n(n,None) | N2n(n,None),Nconst i1 -> if eq_big_int i1 two - then mk_2n (normalize_nexp (mk_add n n_one)) + then mk_2n (normalize_n_rec true (mk_add n n_one)) else mk_mult (mk_c i1) (mk_2n n) | (Nmult (_, _), (Nvar _|Npow (_, _)|Nuvar _)) -> mk_mult n1' n2' | Nvar _, Nuvar _ -> mk_mult n2' n1' - | N2n(n1,Some i1),N2n(n2,Some i2) -> mk_2nc (normalize_nexp (mk_add n1 n2)) (mult_big_int i1 i2) - | N2n(n1,_), N2n(n2,_) -> mk_2n (normalize_nexp (mk_add n1 n2)) + | N2n(n1,Some i1),N2n(n2,Some i2) -> mk_2nc (normalize_n_rec true (mk_add n1 n2)) (mult_big_int i1 i2) + | N2n(n1,_), N2n(n2,_) -> mk_2n (normalize_n_rec true (mk_add n1 n2)) | N2n _, Nvar _ | N2n _, Nuvar _ | N2n _, Nmult _ | Nuvar _, N2n _ -> mk_mult n2' n1' | Nuvar _, Nuvar _ | Nvar _, Nvar _ -> (match compare n1' n2' with @@ -623,10 +685,10 @@ let rec normalize_nexp n = | _ -> mk_mult n1' n2') | Nconst _, Nadd(n21,n22) | Nvar _,Nadd(n21,n22) | Nuvar _,Nadd(n21,n22) | N2n _, Nadd(n21,n22) | Npow _,Nadd(n21,n22) | Nmult _, Nadd(n21,n22) -> - normalize_nexp (mk_add (mk_mult n1' n21) (mk_mult n1' n21)) + normalize_n_rec true (mk_add (mk_mult n1' n21) (mk_mult n1' n21)) | Nadd(n11,n12),Nconst _ | Nadd(n11,n12),Nvar _ | Nadd(n11,n12), Nuvar _ | Nadd(n11,n12), N2n _ | Nadd(n11,n12),Npow _ | Nadd(n11,n12), Nmult _-> - normalize_nexp (mk_add (mk_mult n11 n2') (mk_mult n12 n2')) + normalize_n_rec true (mk_add (mk_mult n11 n2') (mk_mult n12 n2')) | Nmult(n11,n12), Nconst _ -> mk_mult (mk_mult n11 n2') (mk_mult n12 n2') | Nconst i1, _ -> if (eq_big_int i1 zero) then n1' @@ -637,9 +699,9 @@ let rec normalize_nexp n = else if (eq_big_int i1 one) then n1' else mk_mult n2' n1' | Nadd(n11,n12),Nadd(n21,n22) -> - normalize_nexp (mk_add (mk_mult n11 n21) + normalize_n_rec true (mk_add (mk_mult n11 n21) (mk_add (mk_mult n11 n22) - (mk_add (mk_mult n12 n21) (mk_mult n12 n22)))) + (mk_add (mk_mult n12 n21) (mk_mult n12 n22)))) | Nuvar _, Nvar _ | Nmult _, N2n _-> mk_mult n1' n2' | Nuvar _, Nmult(n1,n2) | Nvar _, Nmult(n1,n2) -> (match get_var n1, get_var n2 with @@ -648,8 +710,8 @@ let rec normalize_nexp n = | 0, Nuvar _ | 0, Nvar _ -> mk_mult n1 (mk_pow nv1 2) | 0, Npow(n2',i) -> mk_mult n1 (mk_pow n2' (i+1)) | -1, Nuvar _ | -1, Nvar _ -> mk_mult n2' n1' - | _,_ -> mk_mult (normalize_nexp (mk_mult n1 n1')) n2) - | _ -> mk_mult (normalize_nexp (mk_mult n1 n1')) n2) + | _,_ -> mk_mult (normalize_n_rec true (mk_mult n1 n1')) n2) + | _ -> mk_mult (normalize_n_rec true (mk_mult n1 n1')) n2) | (Npow (n1, i), (Nvar _ | Nuvar _)) -> (match compare_nexps n1 n2' with | 0 -> mk_pow n1 (i+1) @@ -662,9 +724,9 @@ let rec normalize_nexp n = (match compare_nexps nv1 nv2,n22.nexp with | 0, Nuvar _ | 0, Nvar _ -> mk_mult n21 (mk_pow n1 (i+1)) | 0, Npow(_,i2) -> mk_mult n21 (mk_pow n1 (i+i2)) - | 1,Npow _ -> mk_mult (normalize_nexp (mk_mult n21 n1')) n22 + | 1,Npow _ -> mk_mult (normalize_n_rec true (mk_mult n21 n1')) n22 | _ -> mk_mult n2' n1') - | _ -> mk_mult (normalize_nexp (mk_mult n1' n21)) n22) + | _ -> mk_mult (normalize_n_rec true (mk_mult n1' n21)) n22) | Nmult _ ,Nmult(n21,n22) -> mk_mult (mk_mult n21 n1') (mk_mult n22 n1') | Nsub _, _ | _, Nsub _ -> let _ = Printf.eprintf "nsub case still around %s\n" (n_to_string n) in assert false @@ -673,6 +735,8 @@ let rec normalize_nexp n = (* If things are normal, neg should be gone. *) ) +let normalize_nexp = normalize_n_rec true + let int_to_nexp = mk_c_int let v_count = ref 0 @@ -836,6 +900,62 @@ let nexp_one_more_than n1 n2 = if (int_of_big_int i) = 1 then nexp_eq n1' n2 else false | _ -> false + +let rec nexp_ge n1 n2 = + let n1,n2 = (normalize_nexp n1, normalize_nexp n2) in + if nexp_eq n1 n2 + then Yes + else + match n1.nexp,n2.nexp with + | Nconst i, Nconst j | N2n(_,Some i), N2n(_,Some j)-> if ge_big_int i j then Yes else No + | Npos_inf, _ | _, Nneg_inf | Nuvar _, Npos_inf | Nneg_inf, Nuvar _ -> Yes + | Nneg_inf, _ | _, Npos_inf -> No + | Ninexact, _ | _, Ninexact -> Maybe + | N2n(n1,_), N2n(n2,_) -> nexp_ge n1 n2 + | Nmult(n11,n12), Nmult(n21,n22) -> + if nexp_eq n12 n22 + then nexp_ge n11 n21 + else Maybe + | Nmult(n11,n12), _ -> + if nexp_eq n12 n2 + then triple_negate (nexp_negative n11) + else Maybe + | _, Nmult(n21,n22) -> + if nexp_eq n1 n22 + then nexp_negative n21 + else Maybe + | Nadd(n11,n12),Nadd(n21,n22) -> + (match (nexp_ge n11 n21, nexp_ge n12 n22, + (nexp_negative n11, nexp_negative n12, nexp_negative n21, nexp_negative n22)) with + | Yes, Yes, (No, No, No, No) -> Yes + | No, No, (No, No, No, No) -> No + | _ -> Maybe) + | Nadd(n11,n12), _ -> + if nexp_eq n11 n2 + then triple_negate (nexp_negative n12) + else if nexp_eq n12 n2 + then triple_negate (nexp_negative n11) + else Maybe + | _ , Nadd(n21,n22) -> + if nexp_eq n1 n21 + then nexp_negative n22 + else if nexp_eq n1 n22 + then nexp_negative n21 + else Maybe + | Npow(n11,i1), Npow(n21, i2) -> + if nexp_eq n11 n21 + then if i1 >= i2 then Yes else No + else Maybe + | Npow(n11,i1), _ -> + if nexp_eq n11 n2 + then if i1 = 0 then No else Yes + else Maybe + | _, Npow(n21,i2) -> + if nexp_eq n1 n21 + then if i2 = 0 then Yes else No + else Maybe + | _ -> Maybe + let equate_t (t_box : t) (t : t) : unit = let t = resolve_tsubst t in if t_box == t then () @@ -1014,12 +1134,12 @@ let rec contains_nuvar n cs = match cs with if (contains_nuvar_nexp n nl || contains_nuvar_nexp n nr) then co::(contains_nuvar n cs) else contains_nuvar n cs - | CondCons(so,conds,exps)::cs -> + | CondCons(so,kind,conds,exps)::cs -> let conds' = contains_nuvar n conds in let exps' = contains_nuvar n exps in (match conds',exps' with | [],[] -> contains_nuvar n cs - | _ -> CondCons(so,conds',exps')::contains_nuvar n cs) + | _ -> CondCons(so,kind,conds',exps')::contains_nuvar n cs) | BranchCons(so,b_cs)::cs -> (match contains_nuvar n b_cs with | [] -> contains_nuvar n cs @@ -1852,7 +1972,7 @@ let rec cs_subst t_env cs = InS(l,nexp,ns)::(cs_subst t_env cs) | InS(l,n,ns)::cs -> InS(l,n_subst t_env n,ns)::(cs_subst t_env cs) | Predicate(l, c)::cs -> Predicate(l, List.hd(cs_subst t_env [c]))::(cs_subst t_env cs) - | CondCons(l,cs_p,cs_e)::cs -> CondCons(l,cs_subst t_env cs_p,cs_subst t_env cs_e)::(cs_subst t_env cs) + | CondCons(l,kind,cs_p,cs_e)::cs -> CondCons(l,kind,cs_subst t_env cs_p,cs_subst t_env cs_e)::(cs_subst t_env cs) | BranchCons(l,bs)::cs -> BranchCons(l,cs_subst t_env bs)::(cs_subst t_env cs) let subst (k_env : (Envmap.k * kind) list) (leave_imp:bool) @@ -2257,7 +2377,7 @@ let rec type_consistent_internal co d_env enforce widen t1 cs1 t2 cs2 = | Tapp("range",[TA_nexp b1;TA_nexp r1;]),Tapp("range",[TA_nexp b2;TA_nexp r2;]) -> if (nexp_eq b1 b2)&&(nexp_eq r1 r2) then (t2,csp) - else (t1, csp@[LtEq(co,enforce,b1,b2);LtEq(co,enforce,r1,r2)]) + else (t1, csp@[GtEq(co,enforce,b1,b2);LtEq(co,enforce,r1,r2)]) | Tapp("atom",[TA_nexp a]),Tapp("range",[TA_nexp b1; TA_nexp r1]) -> (t1, csp@[GtEq(co,enforce,a,b1);LtEq(co,enforce,a,r1)]) | Tapp("range",[TA_nexp b1; TA_nexp r1]),Tapp("atom",[TA_nexp a]) -> @@ -2630,7 +2750,7 @@ let rec get_all_nuvars_cs cs = match cs with let n1s = get_nuvars n1 in let n2s = get_nuvars n2 in List.fold_right (fun n s -> Var_set.add n s) (n1s@n2s) s - | CondCons(_,pats,exps)::cs -> + | CondCons(_,_,pats,exps)::cs -> let s = get_all_nuvars_cs cs in let ps = get_all_nuvars_cs pats in let es = get_all_nuvars_cs exps in @@ -2661,12 +2781,16 @@ let rec equate_nuvars in_env cs = if (equate_n n1 n2) then equate cs else c::equate cs end else c::equate cs | _ -> c::equate cs) - | CondCons(co,pats,exps):: cs -> - let pats' = equate pats in - let exps' = equate exps in - (match pats',exps' with - | [],[] -> equate cs - | _ -> CondCons(co,pats',exps')::(equate cs)) + | (CondCons(co,kind,pats,exps) as c):: cs -> + (match kind with + | Solo | Positive | Negative (*Wrong, but I'd like to get things working again*)-> + (*let _ = Printf.eprintf "equate_nuvars: condcons: %s\n%!" (constraints_to_string [c]) in*) + let pats' = equate pats in + let exps' = equate exps in + (match pats',exps' with + | [],[] -> equate cs + | _ -> CondCons(co,kind,pats',exps')::(equate cs)) + | _ -> CondCons(co,kind,pats,exps)::(equate cs)) | BranchCons(co,branches)::cs -> let b' = equate branches in if [] = b' @@ -2763,14 +2887,18 @@ let rec simple_constraint_check in_env cs = if equate_n n2' n1' then check cs else (GtEq(co,enforce,n1',n2')::check cs) | Nneg_inf, _ -> eq_error (get_c_loc co) ("Type constraint mismatch: constraint arising from here requires negative infinity to be greater than or equal to " ^ (n_to_string n2')) | _,_ -> - let new_n = normalize_nexp (mk_sub n1' n2') in - (match new_n.nexp with - | Nconst i -> - if ge_big_int i zero - then (check cs) - else eq_error (get_c_loc co) ("Type constraint mismatch: constraint arising from here requires " - ^ n_to_string new_n ^ " to be greater than or equal to 0, not " ^ string_of_big_int i) - | _ -> GtEq(co,enforce,n1',n2')::(check cs))) + (match nexp_ge n1' n2' with + | Yes -> check cs + | No -> eq_error (get_c_loc co) ("Type constraint mismatch: constraint arising from here requires " ^ n_to_string n1 ^ " to be greather than or equal to " ^ (n_to_string n2)) + | Maybe -> + let new_n = normalize_nexp (mk_sub n1' n2') in + (match new_n.nexp with + | Nconst i -> + if ge_big_int i zero + then (check cs) + else eq_error (get_c_loc co) ("Type constraint mismatch: constraint arising from here requires " + ^ n_to_string new_n ^ " to be greater than or equal to 0, not " ^ string_of_big_int i) + | _ -> GtEq(co,enforce,n1',n2')::(check cs)))) | LtEq(co,enforce,n1,n2)::cs -> (*let _ = Printf.eprintf "<= check, about to normalize_nexp of %s, %s\n" (n_to_string n1) (n_to_string n2) in *) let n1',n2' = normalize_nexp n1,normalize_nexp n2 in @@ -2784,15 +2912,24 @@ let rec simple_constraint_check in_env cs = | _, Npos_inf | Nneg_inf, _ -> check cs (*| Npos_inf, Nconst _ -> eq_error (get_c_loc co) ("Type constraint mismatch: constraint arising from here requires infinity to be less than or equal to " ^ (n_to_string n2'))*) - | _,_ -> LtEq(co,enforce,n1',n2')::(check cs)) - | CondCons(co,pats,exps):: cs -> + | _,_ -> + (match nexp_ge n2' n1' with + | Yes -> check cs + | No -> eq_error (get_c_loc co) + ("Type constraint mismatch: constraint arising from here requires " ^ n_to_string n1 ^ + " to be less than or equal to " ^ (n_to_string n2)) + | Maybe -> LtEq(co,enforce,n1',n2')::(check cs))) + | CondCons(co,kind,pats,exps):: cs -> + (match kind with + | Solo | Positive | Negative (*Wrong, but let's get things working again*)-> (*let _ = Printf.eprintf "Condcons check length pats %i, length exps %i\n" (List.length pats) (List.length exps) in*) - let pats' = check pats in - let exps' = check exps in + let pats' = check pats in + let exps' = check exps in (*let _ = Printf.eprintf "Condcons after check length pats' %i, length exps' %i\n" (List.length pats') (List.length exps') in*) - (match pats',exps' with - | [],[] -> check cs - | _ -> CondCons(co,pats',exps')::(check cs)) + (match pats',exps' with + | [],[] -> check cs + | _ -> CondCons(co,kind,pats',exps')::(check cs)) + | _ -> CondCons(co,kind,pats,exps)::(check cs)) | BranchCons(co,branches)::cs -> (*let _ = Printf.eprintf "Branchcons check length branches %i\n" (List.length branches) in*) let b' = check branches in @@ -2825,7 +2962,7 @@ let rec constraint_size = function | [] -> 0 | c::cs -> (match c with - | CondCons(_,ps,es) -> constraint_size ps + constraint_size es + | CondCons(_,_,ps,es) -> constraint_size ps + constraint_size es | BranchCons(_,bs) -> constraint_size bs | _ -> 1) + constraint_size cs diff --git a/src/type_internal.mli b/src/type_internal.mli index af2b99e0..a5457546 100644 --- a/src/type_internal.mli +++ b/src/type_internal.mli @@ -89,6 +89,7 @@ type constraint_origin = | Specc of Parse_ast.l type range_enforcement = Require | Guarantee +type cond_kind = Positive | Negative | Solo | Switch (* Constraints for nexps, plus the location which added the constraint *) type nexp_range = @@ -98,7 +99,7 @@ type nexp_range = | In of constraint_origin * string * int list | InS of constraint_origin * nexp * int list (* This holds the given value for string after a substitution *) | Predicate of constraint_origin * nexp_range (* This will treat the inner constraint as holding in positive condcons positions : must be one of LtEq, Eq, or GtEq*) - | CondCons of constraint_origin * nexp_range list * nexp_range list (* Constraints from one path from a conditional (pattern or if) and the constraints from that conditional *) + | CondCons of constraint_origin * cond_kind * nexp_range list * nexp_range list (* Constraints from one path from a conditional (pattern or if) and the constraints from that conditional *) | BranchCons of constraint_origin * nexp_range list (* CondCons constraints from all branches of a conditional; list should be all CondCons *) val get_c_loc : constraint_origin -> Parse_ast.l |
