summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorKathy Gray2015-08-14 10:30:30 +0100
committerKathy Gray2015-08-14 10:30:30 +0100
commitd698593f14334811f3230d385737cc2bc96b5a63 (patch)
treee78280127bc9e1736c3625168063ef3f3a5d8bd4 /src
parentd4d2e262f96a8eef543c017c8df08c25f2715118 (diff)
Steps towards making constraint solver smarter
Diffstat (limited to 'src')
-rw-r--r--src/pretty_print.ml6
-rw-r--r--src/type_check.ml22
-rw-r--r--src/type_internal.ml313
-rw-r--r--src/type_internal.mli3
4 files changed, 242 insertions, 102 deletions
diff --git a/src/pretty_print.ml b/src/pretty_print.ml
index 266e8141..182c1929 100644
--- a/src/pretty_print.ml
+++ b/src/pretty_print.ml
@@ -65,14 +65,14 @@ let pp_format_var (Kid_aux(Var v,_)) = v
let rec pp_format_l_lem = function
| Parse_ast.Unknown -> "Unknown"
-(* | _ -> "Unknown"*)
- | Parse_ast.Int(s,None) -> "(Int \"" ^ s ^ "\" Nothing)"
+ | _ -> "Unknown"
+(* | Parse_ast.Int(s,None) -> "(Int \"" ^ s ^ "\" Nothing)"
| Parse_ast.Int(s,(Some l)) -> "(Int \"" ^ s ^ "\" (Just " ^ (pp_format_l_lem l) ^ "))"
| Parse_ast.Range(p1,p2) -> "(Range \"" ^ p1.Lexing.pos_fname ^ "\" " ^
(string_of_int p1.Lexing.pos_lnum) ^ " " ^
(string_of_int (p1.Lexing.pos_cnum - p1.Lexing.pos_bol)) ^ " " ^
(string_of_int p2.Lexing.pos_lnum) ^ " " ^
- (string_of_int (p2.Lexing.pos_cnum - p2.Lexing.pos_bol)) ^ ")"
+ (string_of_int (p2.Lexing.pos_cnum - p2.Lexing.pos_bol)) ^ ")"*)
let pp_lem_l ppf l = base ppf (pp_format_l_lem l)
diff --git a/src/type_check.ml b/src/type_check.ml
index 8ee47e10..6b3ebbb1 100644
--- a/src/type_check.ml
+++ b/src/type_check.ml
@@ -818,8 +818,8 @@ let rec check_exp envs (imp_param:nexp option) (expect_t:t) (E_aux(e,(l,annot)):
(*TOTHINK Possibly I should first consistency check else and then, with Guarantee, then check against expect_t with Require*)
let then_t',then_c' = type_consistent (Expr l) d_env Require true then_t expect_t in
let else_t',else_c' = type_consistent (Expr l) d_env Require true else_t expect_t in
- let t_cs = CondCons((Expr l),c1,then_c@then_c') in
- let e_cs = CondCons((Expr l),[],else_c@else_c') in
+ let t_cs = CondCons((Expr l),Positive,c1,then_c@then_c') in
+ let e_cs = CondCons((Expr l),Negative,[],else_c@else_c') in
(E_aux(E_if(cond',then',else'),(l,simple_annot expect_t)),
expect_t,Envmap.intersect_merge (tannot_merge (Expr l) d_env true) then_env else_env,[t_cs;e_cs],
merge_bounds then_bs else_bs, (*TODO Should be an intersecting merge*)
@@ -827,8 +827,8 @@ let rec check_exp envs (imp_param:nexp option) (expect_t:t) (E_aux(e,(l,annot)):
| _ ->
let then',then_t,then_env,then_c,then_bs,then_ef = check_exp envs imp_param expect_t then_ in
let else',else_t,else_env,else_c,else_bs,else_ef = check_exp envs imp_param expect_t else_ in
- let t_cs = CondCons((Expr l),c1,then_c) in
- let e_cs = CondCons((Expr l),[],else_c) in
+ let t_cs = CondCons((Expr l),Positive,c1,then_c) in
+ let e_cs = CondCons((Expr l),Negative,[],else_c) in
(E_aux(E_if(cond',then',else'),(l,simple_annot expect_t)),
expect_t,Envmap.intersect_merge (tannot_merge (Expr l) d_env true) then_env else_env,[t_cs;e_cs],
merge_bounds then_bs else_bs,
@@ -1225,7 +1225,8 @@ let rec check_exp envs (imp_param:nexp option) (expect_t:t) (E_aux(e,(l,annot)):
| Tapp("register",[TA_typ t]) -> t
| _ -> t' in
(*let _ = Printf.eprintf "Type of pattern after register check %s\n" (t_to_string t') in*)
- let (pexps',t,cs',ef') = check_cases envs imp_param t' expect_t pexps in
+ let (pexps',t,cs',ef') =
+ check_cases envs imp_param t' expect_t (if (List.length pexps) = 1 then Solo else Switch) pexps in
(E_aux(E_case(e',pexps'),(l,simple_annot t)),t,t_env,cs@cs',nob,union_effects ef ef')
| E_let(lbind,body) ->
let (lb',t_env',cs,b_env',ef) = (check_lbind envs imp_param false Emp_local lbind) in
@@ -1260,7 +1261,7 @@ and check_block envs imp_param expect_t exps:((tannot exp) list * tannot * nexp_
merge_bounds b_env' b_env, tp_env)) imp_param expect_t exps in
((e'::exps'),annot',sc@sc',t,union_effects ef ef')
-and check_cases envs imp_param check_t expect_t pexps : ((tannot pexp) list * typ * nexp_range list * effect) =
+and check_cases envs imp_param check_t expect_t kind pexps : ((tannot pexp) list * typ * nexp_range list * effect) =
let (Env(d_env,t_env,b_env,tp_env)) = envs in
match pexps with
| [] -> raise (Reporting_basic.err_unreachable Parse_ast.Unknown "switch with no cases found")
@@ -1270,7 +1271,7 @@ and check_cases envs imp_param check_t expect_t pexps : ((tannot pexp) list * ty
check_exp (Env(d_env,
Envmap.union_merge (tannot_merge (Expr l) d_env true) t_env env,
merge_bounds b_env bounds, tp_env)) imp_param expect_t exp in
- let cs = [CondCons(Expr l, cs_p, cs_e)] in
+ let cs = [CondCons(Expr l,kind, cs_p, cs_e)] in
[Pat_aux(Pat_exp(pat',e),(l,cons_ef_annot t cs ef))],t,cs,ef
| ((Pat_aux(Pat_exp(pat,exp),(l,annot)))::pexps) ->
let pat',env,cs_p,bounds,u = check_pattern envs Emp_local check_t pat in
@@ -1278,8 +1279,8 @@ and check_cases envs imp_param check_t expect_t pexps : ((tannot pexp) list * ty
check_exp (Env(d_env,
Envmap.union_merge (tannot_merge (Expr l) d_env true) t_env env,
merge_bounds b_env bounds, tp_env)) imp_param expect_t exp in
- let cs = CondCons(Expr l,cs_p,cs_e) in
- let (pes,tl,csl,efl) = check_cases envs imp_param check_t expect_t pexps in
+ let cs = CondCons(Expr l,kind,cs_p,cs_e) in
+ let (pes,tl,csl,efl) = check_cases envs imp_param check_t expect_t kind pexps in
((Pat_aux(Pat_exp(pat',e),(l,cons_ef_annot t [cs] ef)))::pes,tl,cs::csl,union_effects efl ef)
and check_lexp envs imp_param is_top (LEXP_aux(lexp,(l,annot)))
@@ -1727,6 +1728,7 @@ let check_fundef envs (FD_aux(FD_function(recopt,tannotopt,effectopt,funcls),(l,
let p_t = new_t () in
let ef = new_e () in
t,p_t,Base((ids,{t=Tfn(p_t,t,IP_none,ef)}),Emp_global,constraints,ef,nob),t_param_env in
+ let cond_kind = if (List.length funcls) = 1 then Solo else Switch in
let check t_env tp_env imp_param =
List.split
(List.map (fun (FCL_aux((FCL_Funcl(id,pat,exp)),(l,_))) ->
@@ -1738,7 +1740,7 @@ let check_fundef envs (FD_aux(FD_function(recopt,tannotopt,effectopt,funcls),(l,
merge_bounds b_env b_env',tp_env)) imp_param ret_t exp in
(*let _ = Printf.eprintf "checked function %s : %s -> %s\n" (id_to_string id) (t_to_string param_t) (t_to_string ret_t) in
let _ = Printf.eprintf "constraints were pattern: %s\n expression: %s\n" (constraints_to_string cs_p) (constraints_to_string cs_e) in*)
- let cs = [CondCons(Fun l,cs_p,cs_e)] in
+ let cs = [CondCons(Fun l,cond_kind,cs_p,cs_e)] in
(FCL_aux((FCL_Funcl(id,pat',exp')),(l,(Base(([],ret_t),Emp_global,cs,ef,nob)))),(cs,ef))) funcls) in
let update_pattern var (FCL_aux ((FCL_Funcl(id,(P_aux(pat,t)),exp)),annot)) =
let pat' = match pat with
diff --git a/src/type_internal.ml b/src/type_internal.ml
index debd676e..7e4f2a2a 100644
--- a/src/type_internal.ml
+++ b/src/type_internal.ml
@@ -89,6 +89,7 @@ type constraint_origin =
| Specc of Parse_ast.l
type range_enforcement = Require | Guarantee
+type cond_kind = Positive | Negative | Solo | Switch
(* Constraints for nexps, plus the location which added the constraint *)
type nexp_range =
@@ -98,7 +99,7 @@ type nexp_range =
| In of constraint_origin * string * int list
| InS of constraint_origin * nexp * int list (* This holds the given value for string after a substitution *)
| Predicate of constraint_origin * nexp_range (* This will treat the inner constraint as holding in positive condcons positions*)
- | CondCons of constraint_origin * nexp_range list * nexp_range list
+ | CondCons of constraint_origin * cond_kind * nexp_range list * nexp_range list
| BranchCons of constraint_origin * nexp_range list
type variable_range =
@@ -134,6 +135,13 @@ type def_envs = {
default_o : order;
}
+type triple = Yes | No | Maybe
+let triple_negate = function
+ | Yes -> No
+ | No -> Yes
+ | Maybe -> Maybe
+
+
type exp = tannot Ast.exp
(*Nexpression Makers (as built so often)*)
@@ -249,6 +257,12 @@ let enforce_to_string = function
| Require -> "require"
| Guarantee -> "guarantee"
+let cond_kind_to_string = function
+ | Positive -> "positive"
+ | Negative -> "negative"
+ | Solo -> "solo"
+ | Switch -> "switch"
+
let rec constraint_to_string = function
| LtEq (co,enforce,nexp1,nexp2) ->
"LtEq(" ^ co_to_string co ^ ", " ^ enforce_to_string enforce ^ ", " ^ n_to_string nexp1 ^ ", " ^ n_to_string nexp2 ^ ")"
@@ -259,8 +273,9 @@ let rec constraint_to_string = function
| In(co,var,ints) -> "In of " ^ var
| InS(co,n,ints) -> "InS of " ^ n_to_string n
| Predicate(co,cs) -> "Pred(" ^ co_to_string co ^ ", " ^ constraint_to_string cs ^ ")"
- | CondCons(co,pats,exps) ->
- "CondCons(" ^ co_to_string co ^ ", [" ^ constraints_to_string pats ^ "], [" ^ constraints_to_string exps ^ "])"
+ | CondCons(co,kind,pats,exps) ->
+ "CondCons(" ^ co_to_string co ^ ", " ^ cond_kind_to_string kind ^
+ ", [" ^ constraints_to_string pats ^ "], [" ^ constraints_to_string exps ^ "])"
| BranchCons(co,consts) ->
"BranchCons(" ^ co_to_string co ^ ", [" ^ constraints_to_string consts ^ "])"
and constraints_to_string l = string_of_list "; " constraint_to_string l
@@ -454,7 +469,29 @@ let negate n = match n.nexp with
| Nconst i -> mk_c (mult_int_big_int (-1) i)
| _ -> mk_mult (mk_c_int (-1)) n
-let rec normalize_nexp n =
+let odd n = (n mod 2) = 1
+
+(*Expects a normalized nexp*)
+let rec nexp_negative n =
+ match n.nexp with
+ | Nconst i -> if lt_big_int i zero then Yes else No
+ | Nneg_inf -> Yes
+ | Npos_inf | N2n _ | Nvar _ | Nuvar _ -> No
+ | Nmult(n1,n2) -> (match nexp_negative n1, nexp_negative n2 with
+ | Yes,Yes | No, No -> No
+ | No, Yes | Yes, No -> Yes
+ | Maybe,_ | _, Maybe -> Maybe)
+ | Nadd(n1,n2) -> (match nexp_negative n1, nexp_negative n2 with
+ | Yes,Yes -> Yes
+ | No, No -> No
+ | _ -> Maybe)
+ | Npow(n1,i) ->
+ (match nexp_negative n1 with
+ | Yes -> if odd i then Yes else No
+ | No -> No
+ | Maybe -> if odd i then Maybe else No)
+
+let rec normalize_n_rec recur_ok n =
(*let _ = Printf.eprintf "Working on normalizing %s\n" (n_to_string n) in *)
match n.nexp with
| Nconst _ | Nvar _ | Nuvar _ | Npos_inf | Nneg_inf | Ninexact -> n
@@ -466,74 +503,92 @@ let rec normalize_nexp n =
| Nneg n -> n,true,false
| _ -> n,true,true) in
if to_recur
- then (let n' = normalize_nexp n' in
+ then (let n' = normalize_n_rec true n' in
if add_neg
then negate n'
else n')
else n'
| Npow(n,i) ->
- let n' = normalize_nexp n in
+ let n' = normalize_n_rec true n in
(match n'.nexp with
| Nconst n -> mk_c (pow_i i (int_of_big_int n))
| _ -> mk_pow n' i)
| N2n(n', Some i) -> n (*Because there is a value for Some, we know this is normalized and n' is constant*)
| N2n(n, None) ->
- let n' = normalize_nexp n in
+ let n' = normalize_n_rec true n in
(match n'.nexp with
| Nconst i -> mk_2nc n' (two_pow (int_of_big_int i))
| _ -> mk_2n n')
| Nadd(n1,n2) ->
- let n1',n2' = normalize_nexp n1, normalize_nexp n2 in
- (match n1'.nexp,n2'.nexp with
- | Nneg_inf, Npos_inf | Npos_inf, Nneg_inf -> {nexp = Ninexact }
- | Npos_inf, _ | _, Npos_inf -> { nexp = Npos_inf }
- | Nneg_inf, _ | _, Nneg_inf -> { nexp = Nneg_inf }
- | Nconst i1, Nconst i2 | Nconst i1, N2n(_,Some i2) | N2n(_,Some i2), Nconst i1 | N2n(_,Some i1),N2n(_,Some i2)
+ let n1',n2' = normalize_n_rec true n1, normalize_n_rec true n2 in
+ (match n1'.nexp,n2'.nexp, recur_ok with
+ | Nneg_inf, Npos_inf,_ | Npos_inf, Nneg_inf,_ -> {nexp = Ninexact }
+ | Npos_inf, _,_ | _, Npos_inf, _ -> { nexp = Npos_inf }
+ | Nneg_inf, _,_ | _, Nneg_inf, _ -> { nexp = Nneg_inf }
+ | Nconst i1, Nconst i2,_ | Nconst i1, N2n(_,Some i2),_
+ | N2n(_,Some i2), Nconst i1,_ | N2n(_,Some i1),N2n(_,Some i2),_
-> mk_c (add_big_int i1 i2)
- | Nadd(n11,n12), Nconst i ->
- if (eq_big_int i zero) then n1' else mk_add n11 (normalize_nexp (mk_add n12 n2'))
- | Nconst i, Nadd(n21,n22) ->
- if (eq_big_int i zero) then n2' else mk_add n21 (normalize_nexp (mk_add n22 n1'))
- | Nconst i, _ -> if (eq_big_int i zero) then n2' else mk_add n2' n1'
- | _, Nconst i -> if (eq_big_int i zero) then n1' else mk_add n1' n2'
- | Nvar _, Nuvar _ | Nvar _, N2n _ | Nuvar _, Npow _ | Nuvar _, N2n _ -> mk_add n2' n1'
- | Nadd(n11,n12), Nadd(n21,n22) ->
+ | Nadd(n11,n12), Nconst i, true ->
+ if (eq_big_int i zero) then n1'
+ else normalize_n_rec false (mk_add n11 (normalize_n_rec false (mk_add n12 n2')))
+ | Nadd(n11,n12), Nconst i, false ->
+ if (eq_big_int i zero) then n1'
+ else mk_add n11 (normalize_n_rec false (mk_add n12 n2'))
+ | Nconst i, Nadd(n21,n22), true ->
+ if (eq_big_int i zero) then n2'
+ else normalize_n_rec false (mk_add n21 (normalize_n_rec false (mk_add n22 n1')))
+ | Nconst i, Nadd(n21,n22), false ->
+ if (eq_big_int i zero) then n2'
+ else mk_add n21 (normalize_n_rec false (mk_add n22 n1'))
+ | Nconst i, _,_ -> if (eq_big_int i zero) then n2' else mk_add n2' n1'
+ | _, Nconst i,_ -> if (eq_big_int i zero) then n1' else mk_add n1' n2'
+ | Nvar _, Nuvar _,_ | Nvar _, N2n _,_ | Nuvar _, Npow _,_ | Nuvar _, N2n _,_ -> mk_add n2' n1'
+ | Nadd(n11,n12), Nadd(n21,n22), true ->
(match compare_nexps n11 n21 with
- | -1 -> mk_add n11 (normalize_nexp (mk_add n12 n2'))
+ | -1 -> normalize_n_rec false (mk_add n11 (normalize_n_rec false (mk_add n12 n2')))
| 0 ->
(match compare_nexps n12 n22 with
- | -1 -> normalize_nexp (mk_add (mk_mult n_two n11) (mk_add n22 n12))
- | 0 -> normalize_nexp (mk_add (mk_mult n_two n11) (mk_mult n_two n12))
- | _ -> normalize_nexp (mk_add (mk_mult n_two n11) (mk_add n12 n22)))
- | _ -> normalize_nexp (mk_add n21 (mk_add n22 n1')))
- | N2n(n11,_), N2n(n21,_) ->
+ | -1 -> normalize_n_rec true (mk_add (mk_mult n_two n11) (mk_add n22 n12))
+ | 0 -> normalize_n_rec true (mk_add (mk_mult n_two n11) (mk_mult n_two n12))
+ | _ -> normalize_n_rec true (mk_add (mk_mult n_two n11) (mk_add n12 n22)))
+ | _ -> normalize_n_rec false (mk_add n21 (normalize_n_rec false (mk_add n22 n1'))))
+ | Nadd(n11,n12), Nadd(n21,n22), false ->
+ (match compare_nexps n11 n21 with
+ | -1 -> mk_add n11 (normalize_n_rec false (mk_add n12 n2'))
+ | 0 ->
+ (match compare_nexps n12 n22 with
+ | -1 -> normalize_n_rec true (mk_add (mk_mult n_two n11) (mk_add n22 n12))
+ | 0 -> normalize_n_rec true (mk_add (mk_mult n_two n11) (mk_mult n_two n12))
+ | _ -> normalize_n_rec true (mk_add (mk_mult n_two n11) (mk_add n12 n22)))
+ | _ -> mk_add n21 (normalize_n_rec false (mk_add n22 n1')))
+ | N2n(n11,_), N2n(n21,_),_ ->
(match compare_nexps n11 n21 with
| -1 -> mk_add n2' n1'
- | 0 -> mk_2n (normalize_nexp (mk_add n11 n_one))
+ | 0 -> mk_2n (normalize_n_rec true (mk_add n11 n_one))
| _ -> mk_add n1' n2')
- | Npow(n11,i1), Npow (n21,i2) ->
+ | Npow(n11,i1), Npow (n21,i2),_ ->
(match compare_nexps n11 n21, compare i1 i2 with
| -1,-1 | 0,-1 -> mk_add n2' n1'
| 0,0 -> mk_mult n_two n1'
| _ -> mk_add n1' n2')
- | N2n(n11,Some i),Nadd(n21,n22) ->
- normalize_nexp (mk_add n21 (mk_add n22 (mk_c i)))
- | Nadd(n11,n12), N2n(n21,Some i) ->
- normalize_nexp (mk_add n11 (mk_add n12 (mk_c i)))
- | N2n(n11,None),Nadd(n21,n22) ->
+ | N2n(n11,Some i),Nadd(n21,n22),_ ->
+ normalize_n_rec true (mk_add n21 (mk_add n22 (mk_c i)))
+ | Nadd(n11,n12), N2n(n21,Some i),_ ->
+ normalize_n_rec true (mk_add n11 (mk_add n12 (mk_c i)))
+ | N2n(n11,None),Nadd(n21,n22),_ ->
(match n21.nexp with
| N2n(n211,_) ->
(match compare_nexps n11 n211 with
| -1 -> mk_add n1' n2'
- | 0 -> mk_add (mk_2n (normalize_nexp (mk_add n11 n_one))) n22
- | _ -> mk_add n21 (normalize_nexp (mk_add n11 n22)))
+ | 0 -> mk_add (mk_2n (normalize_n_rec true (mk_add n11 n_one))) n22
+ | _ -> mk_add n21 (normalize_n_rec true (mk_add n11 n22)))
| _ -> mk_add n1' n2')
- | Nadd(n11,n12),N2n(n21,None) ->
+ | Nadd(n11,n12),N2n(n21,None),_ ->
(match n11.nexp with
| N2n(n111,_) ->
(match compare_nexps n111 n21 with
- | -1 -> mk_add n11 (normalize_nexp (mk_add n2' n12))
- | 0 -> mk_add (mk_2n (normalize_nexp (mk_add n111 n_one))) n12
+ | -1 -> mk_add n11 (normalize_n_rec true (mk_add n2' n12))
+ | 0 -> mk_add (mk_2n (normalize_n_rec true (mk_add n111 n_one))) n12
| _ -> mk_add n2' n1')
| _ -> mk_add n2' n1')
| _ ->
@@ -549,22 +604,22 @@ let rec normalize_nexp n =
| Nadd(n11',n12'), _ ->
(match compare_nexps n11' n2' with
| -1 -> mk_add n2' n1'
- | 1 -> mk_add n11' (normalize_nexp (mk_add n12' n2'))
+ | 1 -> mk_add n11' (normalize_n_rec true (mk_add n12' n2'))
| _ -> let _ = Printf.eprintf "Neither term has var but are the same? %s %s\n"
(n_to_string n1') (n_to_string n2') in assert false)
| (_, Nadd(n21',n22')) ->
(match compare_nexps n1' n21' with
- | -1 -> mk_add n21' (normalize_nexp (mk_add n1' n22'))
+ | -1 -> mk_add n21' (normalize_n_rec true (mk_add n1' n22'))
| 1 -> mk_add n1' n2'
| _ -> let _ = Printf.eprintf "pattern didn't match unexpextedly here %s %s\n"
(n_to_string n1') (n_to_string n2') in assert false)
| _ ->
(match compare_nexps n1' n2' with
| -1 -> mk_add n2' n1'
- | 0 -> normalize_nexp (mk_mult n_two n1')
+ | 0 -> normalize_n_rec true (mk_mult n_two n1')
| _ -> mk_add n1' n2'))))
| Nsub(n1,n2) ->
- let n1',n2' = normalize_nexp n1, normalize_nexp n2 in
+ let n1',n2' = normalize_n_rec true n1, normalize_n_rec true n2 in
(*let _ = Printf.eprintf "Normalizing subtraction of %s - %s \n" (n_to_string n1') (n_to_string n2') in*)
(match n1'.nexp,n2'.nexp with
| Nneg_inf, Npos_inf | Npos_inf, Nneg_inf -> {nexp = Ninexact }
@@ -575,41 +630,48 @@ let rec normalize_nexp n =
mk_c (sub_big_int i1 i2)
| Nconst i, _ ->
if (eq_big_int i zero)
- then normalize_nexp (negate n2')
- else normalize_nexp (mk_add (negate n2') n1')
+ then normalize_n_rec true (negate n2')
+ else normalize_n_rec true (mk_add (negate n2') n1')
| _, Nconst i ->
if (eq_big_int i zero)
then n1'
- else normalize_nexp (mk_add n1' (mk_c (mult_int_big_int (-1) i)))
+ else normalize_n_rec true (mk_add n1' (mk_c (mult_int_big_int (-1) i)))
| _,_ ->
(match compare_nexps n1 n2 with
| 0 -> n_zero
| -1 -> mk_add (negate n2') n1'
| _ -> mk_add n1' (negate n2')))
| Nmult(n1,n2) ->
- let n1',n2' = normalize_nexp n1, normalize_nexp n2 in
+ let n1',n2' = normalize_n_rec true n1, normalize_n_rec true n2 in
(match n1'.nexp,n2'.nexp with
| Nneg_inf,Nneg_inf -> {nexp = Npos_inf}
| Npos_inf, Nconst i | Nconst i, Npos_inf ->
if eq_big_int i zero then n_zero else {nexp = Npos_inf}
| Nneg_inf, Nconst i | Nconst i, Nneg_inf ->
if eq_big_int i zero then n_zero else {nexp = Nneg_inf}
- | Nneg_inf, _ | _, Nneg_inf -> {nexp = Nneg_inf} (*TODO write a is_negative predicate*)
- | Npos_inf, _ | _, Npos_inf -> {nexp = Npos_inf}
+ | Nneg_inf, _ | _, Nneg_inf ->
+ (match nexp_negative n1, nexp_negative n2 with
+ | Yes, Yes -> {nexp = Npos_inf}
+ | _ -> {nexp = Nneg_inf})
+ | Npos_inf, _ | _, Npos_inf ->
+ (match nexp_negative n1, nexp_negative n2 with
+ | Yes, Yes -> assert false (*One of them must be Npos_inf, so nexp_negative horribly broken*)
+ | No, Yes | Yes, No -> {nexp = Nneg_inf}
+ | _ -> {nexp = Npos_inf})
| Ninexact, _ | _, Ninexact -> {nexp = Ninexact}
| Nconst i1, Nconst i2 -> mk_c (mult_big_int i1 i2)
| Nconst i1, N2n(n,Some i2) | N2n(n,Some i2),Nconst i1 ->
if eq_big_int i1 two
- then mk_2nc (normalize_nexp (mk_add n n_one)) (mult_big_int i1 i2)
+ then mk_2nc (normalize_n_rec true (mk_add n n_one)) (mult_big_int i1 i2)
else mk_c (mult_big_int i1 i2)
| Nconst i1, N2n(n,None) | N2n(n,None),Nconst i1 ->
if eq_big_int i1 two
- then mk_2n (normalize_nexp (mk_add n n_one))
+ then mk_2n (normalize_n_rec true (mk_add n n_one))
else mk_mult (mk_c i1) (mk_2n n)
| (Nmult (_, _), (Nvar _|Npow (_, _)|Nuvar _)) -> mk_mult n1' n2'
| Nvar _, Nuvar _ -> mk_mult n2' n1'
- | N2n(n1,Some i1),N2n(n2,Some i2) -> mk_2nc (normalize_nexp (mk_add n1 n2)) (mult_big_int i1 i2)
- | N2n(n1,_), N2n(n2,_) -> mk_2n (normalize_nexp (mk_add n1 n2))
+ | N2n(n1,Some i1),N2n(n2,Some i2) -> mk_2nc (normalize_n_rec true (mk_add n1 n2)) (mult_big_int i1 i2)
+ | N2n(n1,_), N2n(n2,_) -> mk_2n (normalize_n_rec true (mk_add n1 n2))
| N2n _, Nvar _ | N2n _, Nuvar _ | N2n _, Nmult _ | Nuvar _, N2n _ -> mk_mult n2' n1'
| Nuvar _, Nuvar _ | Nvar _, Nvar _ ->
(match compare n1' n2' with
@@ -623,10 +685,10 @@ let rec normalize_nexp n =
| _ -> mk_mult n1' n2')
| Nconst _, Nadd(n21,n22) | Nvar _,Nadd(n21,n22) | Nuvar _,Nadd(n21,n22) | N2n _, Nadd(n21,n22)
| Npow _,Nadd(n21,n22) | Nmult _, Nadd(n21,n22) ->
- normalize_nexp (mk_add (mk_mult n1' n21) (mk_mult n1' n21))
+ normalize_n_rec true (mk_add (mk_mult n1' n21) (mk_mult n1' n21))
| Nadd(n11,n12),Nconst _ | Nadd(n11,n12),Nvar _ | Nadd(n11,n12), Nuvar _ | Nadd(n11,n12), N2n _
| Nadd(n11,n12),Npow _ | Nadd(n11,n12), Nmult _->
- normalize_nexp (mk_add (mk_mult n11 n2') (mk_mult n12 n2'))
+ normalize_n_rec true (mk_add (mk_mult n11 n2') (mk_mult n12 n2'))
| Nmult(n11,n12), Nconst _ -> mk_mult (mk_mult n11 n2') (mk_mult n12 n2')
| Nconst i1, _ ->
if (eq_big_int i1 zero) then n1'
@@ -637,9 +699,9 @@ let rec normalize_nexp n =
else if (eq_big_int i1 one) then n1'
else mk_mult n2' n1'
| Nadd(n11,n12),Nadd(n21,n22) ->
- normalize_nexp (mk_add (mk_mult n11 n21)
+ normalize_n_rec true (mk_add (mk_mult n11 n21)
(mk_add (mk_mult n11 n22)
- (mk_add (mk_mult n12 n21) (mk_mult n12 n22))))
+ (mk_add (mk_mult n12 n21) (mk_mult n12 n22))))
| Nuvar _, Nvar _ | Nmult _, N2n _-> mk_mult n1' n2'
| Nuvar _, Nmult(n1,n2) | Nvar _, Nmult(n1,n2) ->
(match get_var n1, get_var n2 with
@@ -648,8 +710,8 @@ let rec normalize_nexp n =
| 0, Nuvar _ | 0, Nvar _ -> mk_mult n1 (mk_pow nv1 2)
| 0, Npow(n2',i) -> mk_mult n1 (mk_pow n2' (i+1))
| -1, Nuvar _ | -1, Nvar _ -> mk_mult n2' n1'
- | _,_ -> mk_mult (normalize_nexp (mk_mult n1 n1')) n2)
- | _ -> mk_mult (normalize_nexp (mk_mult n1 n1')) n2)
+ | _,_ -> mk_mult (normalize_n_rec true (mk_mult n1 n1')) n2)
+ | _ -> mk_mult (normalize_n_rec true (mk_mult n1 n1')) n2)
| (Npow (n1, i), (Nvar _ | Nuvar _)) ->
(match compare_nexps n1 n2' with
| 0 -> mk_pow n1 (i+1)
@@ -662,9 +724,9 @@ let rec normalize_nexp n =
(match compare_nexps nv1 nv2,n22.nexp with
| 0, Nuvar _ | 0, Nvar _ -> mk_mult n21 (mk_pow n1 (i+1))
| 0, Npow(_,i2) -> mk_mult n21 (mk_pow n1 (i+i2))
- | 1,Npow _ -> mk_mult (normalize_nexp (mk_mult n21 n1')) n22
+ | 1,Npow _ -> mk_mult (normalize_n_rec true (mk_mult n21 n1')) n22
| _ -> mk_mult n2' n1')
- | _ -> mk_mult (normalize_nexp (mk_mult n1' n21)) n22)
+ | _ -> mk_mult (normalize_n_rec true (mk_mult n1' n21)) n22)
| Nmult _ ,Nmult(n21,n22) -> mk_mult (mk_mult n21 n1') (mk_mult n22 n1')
| Nsub _, _ | _, Nsub _ ->
let _ = Printf.eprintf "nsub case still around %s\n" (n_to_string n) in assert false
@@ -673,6 +735,8 @@ let rec normalize_nexp n =
(* If things are normal, neg should be gone. *)
)
+let normalize_nexp = normalize_n_rec true
+
let int_to_nexp = mk_c_int
let v_count = ref 0
@@ -836,6 +900,62 @@ let nexp_one_more_than n1 n2 =
if (int_of_big_int i) = 1 then nexp_eq n1' n2 else false
| _ -> false
+
+let rec nexp_ge n1 n2 =
+ let n1,n2 = (normalize_nexp n1, normalize_nexp n2) in
+ if nexp_eq n1 n2
+ then Yes
+ else
+ match n1.nexp,n2.nexp with
+ | Nconst i, Nconst j | N2n(_,Some i), N2n(_,Some j)-> if ge_big_int i j then Yes else No
+ | Npos_inf, _ | _, Nneg_inf | Nuvar _, Npos_inf | Nneg_inf, Nuvar _ -> Yes
+ | Nneg_inf, _ | _, Npos_inf -> No
+ | Ninexact, _ | _, Ninexact -> Maybe
+ | N2n(n1,_), N2n(n2,_) -> nexp_ge n1 n2
+ | Nmult(n11,n12), Nmult(n21,n22) ->
+ if nexp_eq n12 n22
+ then nexp_ge n11 n21
+ else Maybe
+ | Nmult(n11,n12), _ ->
+ if nexp_eq n12 n2
+ then triple_negate (nexp_negative n11)
+ else Maybe
+ | _, Nmult(n21,n22) ->
+ if nexp_eq n1 n22
+ then nexp_negative n21
+ else Maybe
+ | Nadd(n11,n12),Nadd(n21,n22) ->
+ (match (nexp_ge n11 n21, nexp_ge n12 n22,
+ (nexp_negative n11, nexp_negative n12, nexp_negative n21, nexp_negative n22)) with
+ | Yes, Yes, (No, No, No, No) -> Yes
+ | No, No, (No, No, No, No) -> No
+ | _ -> Maybe)
+ | Nadd(n11,n12), _ ->
+ if nexp_eq n11 n2
+ then triple_negate (nexp_negative n12)
+ else if nexp_eq n12 n2
+ then triple_negate (nexp_negative n11)
+ else Maybe
+ | _ , Nadd(n21,n22) ->
+ if nexp_eq n1 n21
+ then nexp_negative n22
+ else if nexp_eq n1 n22
+ then nexp_negative n21
+ else Maybe
+ | Npow(n11,i1), Npow(n21, i2) ->
+ if nexp_eq n11 n21
+ then if i1 >= i2 then Yes else No
+ else Maybe
+ | Npow(n11,i1), _ ->
+ if nexp_eq n11 n2
+ then if i1 = 0 then No else Yes
+ else Maybe
+ | _, Npow(n21,i2) ->
+ if nexp_eq n1 n21
+ then if i2 = 0 then Yes else No
+ else Maybe
+ | _ -> Maybe
+
let equate_t (t_box : t) (t : t) : unit =
let t = resolve_tsubst t in
if t_box == t then ()
@@ -1014,12 +1134,12 @@ let rec contains_nuvar n cs = match cs with
if (contains_nuvar_nexp n nl || contains_nuvar_nexp n nr)
then co::(contains_nuvar n cs)
else contains_nuvar n cs
- | CondCons(so,conds,exps)::cs ->
+ | CondCons(so,kind,conds,exps)::cs ->
let conds' = contains_nuvar n conds in
let exps' = contains_nuvar n exps in
(match conds',exps' with
| [],[] -> contains_nuvar n cs
- | _ -> CondCons(so,conds',exps')::contains_nuvar n cs)
+ | _ -> CondCons(so,kind,conds',exps')::contains_nuvar n cs)
| BranchCons(so,b_cs)::cs ->
(match contains_nuvar n b_cs with
| [] -> contains_nuvar n cs
@@ -1852,7 +1972,7 @@ let rec cs_subst t_env cs =
InS(l,nexp,ns)::(cs_subst t_env cs)
| InS(l,n,ns)::cs -> InS(l,n_subst t_env n,ns)::(cs_subst t_env cs)
| Predicate(l, c)::cs -> Predicate(l, List.hd(cs_subst t_env [c]))::(cs_subst t_env cs)
- | CondCons(l,cs_p,cs_e)::cs -> CondCons(l,cs_subst t_env cs_p,cs_subst t_env cs_e)::(cs_subst t_env cs)
+ | CondCons(l,kind,cs_p,cs_e)::cs -> CondCons(l,kind,cs_subst t_env cs_p,cs_subst t_env cs_e)::(cs_subst t_env cs)
| BranchCons(l,bs)::cs -> BranchCons(l,cs_subst t_env bs)::(cs_subst t_env cs)
let subst (k_env : (Envmap.k * kind) list) (leave_imp:bool)
@@ -2257,7 +2377,7 @@ let rec type_consistent_internal co d_env enforce widen t1 cs1 t2 cs2 =
| Tapp("range",[TA_nexp b1;TA_nexp r1;]),Tapp("range",[TA_nexp b2;TA_nexp r2;]) ->
if (nexp_eq b1 b2)&&(nexp_eq r1 r2)
then (t2,csp)
- else (t1, csp@[LtEq(co,enforce,b1,b2);LtEq(co,enforce,r1,r2)])
+ else (t1, csp@[GtEq(co,enforce,b1,b2);LtEq(co,enforce,r1,r2)])
| Tapp("atom",[TA_nexp a]),Tapp("range",[TA_nexp b1; TA_nexp r1]) ->
(t1, csp@[GtEq(co,enforce,a,b1);LtEq(co,enforce,a,r1)])
| Tapp("range",[TA_nexp b1; TA_nexp r1]),Tapp("atom",[TA_nexp a]) ->
@@ -2630,7 +2750,7 @@ let rec get_all_nuvars_cs cs = match cs with
let n1s = get_nuvars n1 in
let n2s = get_nuvars n2 in
List.fold_right (fun n s -> Var_set.add n s) (n1s@n2s) s
- | CondCons(_,pats,exps)::cs ->
+ | CondCons(_,_,pats,exps)::cs ->
let s = get_all_nuvars_cs cs in
let ps = get_all_nuvars_cs pats in
let es = get_all_nuvars_cs exps in
@@ -2661,12 +2781,16 @@ let rec equate_nuvars in_env cs =
if (equate_n n1 n2) then equate cs else c::equate cs end
else c::equate cs
| _ -> c::equate cs)
- | CondCons(co,pats,exps):: cs ->
- let pats' = equate pats in
- let exps' = equate exps in
- (match pats',exps' with
- | [],[] -> equate cs
- | _ -> CondCons(co,pats',exps')::(equate cs))
+ | (CondCons(co,kind,pats,exps) as c):: cs ->
+ (match kind with
+ | Solo | Positive | Negative (*Wrong, but I'd like to get things working again*)->
+ (*let _ = Printf.eprintf "equate_nuvars: condcons: %s\n%!" (constraints_to_string [c]) in*)
+ let pats' = equate pats in
+ let exps' = equate exps in
+ (match pats',exps' with
+ | [],[] -> equate cs
+ | _ -> CondCons(co,kind,pats',exps')::(equate cs))
+ | _ -> CondCons(co,kind,pats,exps)::(equate cs))
| BranchCons(co,branches)::cs ->
let b' = equate branches in
if [] = b'
@@ -2763,14 +2887,18 @@ let rec simple_constraint_check in_env cs =
if equate_n n2' n1' then check cs else (GtEq(co,enforce,n1',n2')::check cs)
| Nneg_inf, _ -> eq_error (get_c_loc co) ("Type constraint mismatch: constraint arising from here requires negative infinity to be greater than or equal to " ^ (n_to_string n2'))
| _,_ ->
- let new_n = normalize_nexp (mk_sub n1' n2') in
- (match new_n.nexp with
- | Nconst i ->
- if ge_big_int i zero
- then (check cs)
- else eq_error (get_c_loc co) ("Type constraint mismatch: constraint arising from here requires "
- ^ n_to_string new_n ^ " to be greater than or equal to 0, not " ^ string_of_big_int i)
- | _ -> GtEq(co,enforce,n1',n2')::(check cs)))
+ (match nexp_ge n1' n2' with
+ | Yes -> check cs
+ | No -> eq_error (get_c_loc co) ("Type constraint mismatch: constraint arising from here requires " ^ n_to_string n1 ^ " to be greather than or equal to " ^ (n_to_string n2))
+ | Maybe ->
+ let new_n = normalize_nexp (mk_sub n1' n2') in
+ (match new_n.nexp with
+ | Nconst i ->
+ if ge_big_int i zero
+ then (check cs)
+ else eq_error (get_c_loc co) ("Type constraint mismatch: constraint arising from here requires "
+ ^ n_to_string new_n ^ " to be greater than or equal to 0, not " ^ string_of_big_int i)
+ | _ -> GtEq(co,enforce,n1',n2')::(check cs))))
| LtEq(co,enforce,n1,n2)::cs ->
(*let _ = Printf.eprintf "<= check, about to normalize_nexp of %s, %s\n" (n_to_string n1) (n_to_string n2) in *)
let n1',n2' = normalize_nexp n1,normalize_nexp n2 in
@@ -2784,15 +2912,24 @@ let rec simple_constraint_check in_env cs =
| _, Npos_inf | Nneg_inf, _ -> check cs
(*| Npos_inf, Nconst _ -> eq_error (get_c_loc co) ("Type constraint mismatch: constraint arising from here requires infinity to be less than or equal to "
^ (n_to_string n2'))*)
- | _,_ -> LtEq(co,enforce,n1',n2')::(check cs))
- | CondCons(co,pats,exps):: cs ->
+ | _,_ ->
+ (match nexp_ge n2' n1' with
+ | Yes -> check cs
+ | No -> eq_error (get_c_loc co)
+ ("Type constraint mismatch: constraint arising from here requires " ^ n_to_string n1 ^
+ " to be less than or equal to " ^ (n_to_string n2))
+ | Maybe -> LtEq(co,enforce,n1',n2')::(check cs)))
+ | CondCons(co,kind,pats,exps):: cs ->
+ (match kind with
+ | Solo | Positive | Negative (*Wrong, but let's get things working again*)->
(*let _ = Printf.eprintf "Condcons check length pats %i, length exps %i\n" (List.length pats) (List.length exps) in*)
- let pats' = check pats in
- let exps' = check exps in
+ let pats' = check pats in
+ let exps' = check exps in
(*let _ = Printf.eprintf "Condcons after check length pats' %i, length exps' %i\n" (List.length pats') (List.length exps') in*)
- (match pats',exps' with
- | [],[] -> check cs
- | _ -> CondCons(co,pats',exps')::(check cs))
+ (match pats',exps' with
+ | [],[] -> check cs
+ | _ -> CondCons(co,kind,pats',exps')::(check cs))
+ | _ -> CondCons(co,kind,pats,exps)::(check cs))
| BranchCons(co,branches)::cs ->
(*let _ = Printf.eprintf "Branchcons check length branches %i\n" (List.length branches) in*)
let b' = check branches in
@@ -2825,7 +2962,7 @@ let rec constraint_size = function
| [] -> 0
| c::cs ->
(match c with
- | CondCons(_,ps,es) -> constraint_size ps + constraint_size es
+ | CondCons(_,_,ps,es) -> constraint_size ps + constraint_size es
| BranchCons(_,bs) -> constraint_size bs
| _ -> 1) + constraint_size cs
diff --git a/src/type_internal.mli b/src/type_internal.mli
index af2b99e0..a5457546 100644
--- a/src/type_internal.mli
+++ b/src/type_internal.mli
@@ -89,6 +89,7 @@ type constraint_origin =
| Specc of Parse_ast.l
type range_enforcement = Require | Guarantee
+type cond_kind = Positive | Negative | Solo | Switch
(* Constraints for nexps, plus the location which added the constraint *)
type nexp_range =
@@ -98,7 +99,7 @@ type nexp_range =
| In of constraint_origin * string * int list
| InS of constraint_origin * nexp * int list (* This holds the given value for string after a substitution *)
| Predicate of constraint_origin * nexp_range (* This will treat the inner constraint as holding in positive condcons positions : must be one of LtEq, Eq, or GtEq*)
- | CondCons of constraint_origin * nexp_range list * nexp_range list (* Constraints from one path from a conditional (pattern or if) and the constraints from that conditional *)
+ | CondCons of constraint_origin * cond_kind * nexp_range list * nexp_range list (* Constraints from one path from a conditional (pattern or if) and the constraints from that conditional *)
| BranchCons of constraint_origin * nexp_range list (* CondCons constraints from all branches of a conditional; list should be all CondCons *)
val get_c_loc : constraint_origin -> Parse_ast.l