summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/pretty_print.ml2
-rw-r--r--src/type_internal.ml110
2 files changed, 76 insertions, 36 deletions
diff --git a/src/pretty_print.ml b/src/pretty_print.ml
index 5f86b2ed..cfefd4d0 100644
--- a/src/pretty_print.ml
+++ b/src/pretty_print.ml
@@ -544,7 +544,7 @@ let rec pp_format_t t =
| Ttup(tups) -> "(T_tup [" ^ (list_format "; " pp_format_t tups) ^ "])"
| Tapp(i,args) -> "(T_app \"" ^ i ^ "\" (T_args [" ^ list_format "; " pp_format_targ args ^ "]))"
| Tabbrev(ti,ta) -> "(T_abbrev " ^ (pp_format_t ti) ^ " " ^ (pp_format_t ta) ^ ")"
- | Tuvar(_) -> assert false (*"(T_var (Kid_aux (Var \"fresh_v\") Unknown))"*)
+ | Tuvar(_) -> "(T_var (Kid_aux (Var \"fresh_v\") Unknown))"
and pp_format_targ = function
| TA_typ t -> "(T_arg_typ " ^ pp_format_t t ^ ")"
| TA_nexp n -> "(T_arg_nexp " ^ pp_format_n n ^ ")"
diff --git a/src/type_internal.ml b/src/type_internal.ml
index 0a499260..1aa31473 100644
--- a/src/type_internal.ml
+++ b/src/type_internal.ml
@@ -507,7 +507,7 @@ let mk_range n1 n2 = {t=Tapp("range",[TA_nexp {nexp=n1};TA_nexp {nexp=n2}])}
let mk_vector typ order start size = {t=Tapp("vector",[TA_nexp {nexp=start}; TA_nexp {nexp=size}; TA_ord {order}; TA_typ typ])}
let mk_bitwise_op name symb arity =
(* XXX should be Ovar "o" but will not work currently *)
- let ovar = Oinc in
+ let ovar = (*Oinc*) Ovar "o" in
let vec_typ = mk_vector bit_t ovar (Nconst 0) (Nvar "n") in
let args = Array.to_list (Array.make arity vec_typ) in
let arg = if ((List.length args) = 1) then List.hd args else {t= Ttup args} in
@@ -698,7 +698,7 @@ let rec t_to_typ t =
| Ttup ts -> Typ_aux(Typ_tup(List.map t_to_typ ts),Parse_ast.Unknown)
| Tapp(i,args) -> Typ_aux(Typ_app(Id_aux((Id i), Parse_ast.Unknown),List.map targ_to_typ_arg args),Parse_ast.Unknown)
| Tabbrev(t,_) -> t_to_typ t
- | Tuvar _ -> assert false
+ | Tuvar _ -> Typ_aux(Typ_var (Kid_aux((Var "fresh"),Parse_ast.Unknown)),Parse_ast.Unknown)
and targ_to_typ_arg targ =
Typ_arg_aux(
(match targ with
@@ -728,7 +728,7 @@ and o_to_order o =
| Ovar i -> Ord_var (Kid_aux((Var i),Parse_ast.Unknown))
| Oinc -> Ord_inc
| Odec -> Ord_dec
- | Ouvar _ -> assert false), Parse_ast.Unknown)
+ | Ouvar _ -> Ord_var (Kid_aux((Var "fresh"),Parse_ast.Unknown))), Parse_ast.Unknown)
let rec get_abbrev d_env t =
@@ -860,7 +860,7 @@ let rec type_consistent_internal co d_env t1 cs1 t2 cs2 =
let t2' = {t=Tapp("range",[TA_nexp b2;TA_nexp r2])} in
equate_t t2 t2';
(t2,csp@[GtEq(co,b,b2);LtEq(co,r,r2)]) (*This and above should maybe be In constraints when co is patt and tuvar is an in*)
- | t,Tuvar _ -> equate_t t2 t1; (t2,csp)
+ | t,Tuvar _ -> equate_t t2 t1; (t1,csp)
| _,_ -> eq_error l ("Type mismatch found " ^ (t_to_string t1) ^ " but expected a " ^ (t_to_string t2))
and type_arg_eq co d_env ta1 ta2 =
@@ -912,8 +912,8 @@ let rec type_coerce_internal co d_env t1 cs1 e t2 cs2 =
[TA_nexp b2;TA_nexp r2;TA_ord o2;TA_typ t2i] ->
(match o1.order,o2.order with
| Oinc,Oinc | Odec,Odec -> ()
- | Oinc,Ouvar _ | Odec,Ouvar _ -> o2.order <- o1.order
- | Ouvar _,Oinc | Ouvar _, Oinc -> o1.order <- o2.order
+ | Oinc,Ouvar _ | Odec,Ouvar _ -> equate_o o2 o1;
+ | Ouvar _,Oinc | Ouvar _, Odec -> equate_o o1 o2;
| _,_ -> equate_o o1 o2);
let cs = csp@[Eq(co,r1,r2)] in
let t',cs' = type_consistent co d_env t1i t2i in
@@ -953,8 +953,9 @@ let rec type_coerce_internal co d_env t1 cs1 e t2 cs2 =
| "register",_ ->
(match args1 with
| [TA_typ t] ->
+ (*TODO Should this be an internal cast? Probably, make sure it doesn't interfere with the other internal cast and get removed *)
(*let _ = Printf.printf "Adding cast to remove register read\n" in*)
- let new_e = E_aux(E_cast(t_to_typ t,e),(l,Some(([],t),External None,[],(add_effect (BE_aux(BE_rreg, l)) pure_e)))) in
+ let new_e = E_aux(E_cast(t_to_typ unit_t,e),(l,Some(([],t),External None,[],(add_effect (BE_aux(BE_rreg, l)) pure_e)))) in
type_coerce co d_env t new_e t2
| _ -> raise (Reporting_basic.err_unreachable l "register is not properly kinded"))
| _,_ ->
@@ -1012,33 +1013,72 @@ let rec type_coerce_internal co d_env t1 cs1 e t2 cs2 =
| None -> eq_error l ("Type mismatch: " ^ (t_to_string t1) ^ " , " ^ (t_to_string t2)))
| _,_ -> let t',cs = type_consistent co d_env t1 t2 in (t',cs,e)
-and type_coerce co d_env t1 e t2 = type_coerce_internal co d_env t1 [] e t2 []
+and type_coerce co d_env t1 e t2 = type_coerce_internal co d_env t1 [] e t2 [];;
-let rec simple_constraint_check cs =
+let rec in_constraint_env = function
+ | [] -> []
+ | InS(co,nexp,vals)::cs ->
+ (nexp,(List.map (fun c -> {nexp = Nconst c}) vals))::(in_constraint_env cs)
+ | In(co,i,vals)::cs ->
+ ({nexp = Nvar i},(List.map (fun c -> {nexp = Nconst c}) vals))::(in_constraint_env cs)
+ | _::cs -> in_constraint_env cs
+
+let rec contains_var nu n =
+ match n.nexp with
+ | Nvar _ | Nuvar _ -> nexp_eq_check nu n
+ | Nconst _ -> false
+ | Nadd(n1,n2) | Nmult(n1,n2) -> contains_var nu n1 || contains_var nu n2
+ | Nneg n | N2n n -> contains_var nu n
+
+let rec contains_in_vars in_env n =
+ match in_env with
+ | [] -> None
+ | (ne,vals)::in_env ->
+ (match contains_in_vars in_env n with
+ | None -> if contains_var ne n then Some [ne,vals] else None
+ | Some(e_env) -> if contains_var ne n then Some((ne,vals)::e_env) else Some(e_env))
+
+let rec subst_nuvars nu nc n =
+ match n.nexp with
+ | Nconst _ | Nvar _ -> n
+ | Nuvar _ -> if nexp_eq_check nu n then nc else n
+ | Nmult(n1,n2) -> {nexp=Nmult(subst_nuvars nu nc n1,subst_nuvars nu nc n2)}
+ | Nadd(n1,n2) -> {nexp=Nadd(subst_nuvars nu nc n1,subst_nuvars nu nc n2)}
+ | Nneg n -> {nexp= Nneg (subst_nuvars nu nc n)}
+ | N2n n -> {nexp = N2n (subst_nuvars nu nc n)}
+
+
+let rec simple_constraint_check in_env cs =
+ let check = simple_constraint_check in_env in
(* let _ = Printf.printf "simple_constraint_check\n" in *)
match cs with
| [] -> []
| Eq(co,n1,n2)::cs ->
- (*let _ = Printf.printf "eq check, about to eval_nexp of %s, %s\n" (n_to_string n1) (n_to_string n2) in *)
- let n1',n2' = eval_nexp n1,eval_nexp n2 in
- (*let _ = Printf.printf "finished evaled to %s, %s\n" (n_to_string n1') (n_to_string n2') in *)
- (match n1'.nexp,n2'.nexp with
- | Nconst i1, Nconst i2 ->
- if i1==i2
- then simple_constraint_check cs
- else eq_error (get_c_loc co) ("Type constraint mismatch: constraint arising from here requires "
- ^ string_of_int i1 ^ " to equal " ^ string_of_int i2)
- | Nconst i, Nuvar u ->
- if u.nin
- then Eq(co,n1',n2')::(simple_constraint_check cs)
- else begin equate_n n2' n1'; (simple_constraint_check cs) end
- | Nuvar u, Nconst i ->
- if u.nin
- then Eq(co,n1',n2')::(simple_constraint_check cs)
- else begin equate_n n1' n2'; (simple_constraint_check cs) end
- | Nuvar u1, Nuvar u2 ->
- resolve_nsubst n1; resolve_nsubst n2; equate_n n1' n2'; (simple_constraint_check cs)
- | _,_ -> Eq(co,n1',n2')::(simple_constraint_check cs))
+ let check_eq ok_to_set n1 n2 =
+ (*let _ = Printf.printf "eq check, about to eval_nexp of %s, %s\n" (n_to_string n1) (n_to_string n2) in *)
+ let n1',n2' = eval_nexp n1,eval_nexp n2 in
+ (*let _ = Printf.printf "finished evaled to %s, %s\n" (n_to_string n1') (n_to_string n2') in *)
+ (match n1'.nexp,n2'.nexp with
+ | Nconst i1, Nconst i2 ->
+ if i1==i2 then None
+ else eq_error (get_c_loc co) ("Type constraint mismatch: constraint arising from here requires "
+ ^ string_of_int i1 ^ " to equal " ^ string_of_int i2)
+ | Nconst i, Nuvar u ->
+ if not(u.nin) && ok_to_set
+ then begin equate_n n2' n1'; None end
+ else Some (Eq(co,n1',n2'))
+ | Nuvar u, Nconst i ->
+ if not(u.nin) && ok_to_set
+ then begin equate_n n1' n2'; None end
+ else Some (Eq(co,n1',n2'))
+ | Nuvar u1, Nuvar u2 ->
+ if ok_to_set
+ then begin resolve_nsubst n1; resolve_nsubst n2; equate_n n1' n2'; None end
+ else Some(Eq(co,n1',n2'))
+ | _,_ -> Some(Eq(co,n1',n2'))) in
+ (match check_eq true n1 n2 with
+ | None -> (check cs)
+ | Some(c) -> c::(check cs))
| GtEq(co,n1,n2)::cs ->
(* let _ = Printf.printf ">= check, about to eval_nexp of %s, %s\n" (n_to_string n1) (n_to_string n2) in *)
let n1',n2' = eval_nexp n1,eval_nexp n2 in
@@ -1046,10 +1086,10 @@ let rec simple_constraint_check cs =
(match n1'.nexp,n2'.nexp with
| Nconst i1, Nconst i2 ->
if i1>=i2
- then simple_constraint_check cs
+ then check cs
else eq_error (get_c_loc co) ("Type constraint mismatch: constraint arising from here requires "
^ string_of_int i1 ^ " to be greater than or equal to " ^ string_of_int i2)
- | _,_ -> GtEq(co,n1',n2')::(simple_constraint_check cs))
+ | _,_ -> GtEq(co,n1',n2')::(check cs))
| LtEq(co,n1,n2)::cs ->
(* let _ = Printf.printf "<= check, about to eval_nexp of %s, %s\n" (n_to_string n1) (n_to_string n2) in *)
let n1',n2' = eval_nexp n1,eval_nexp n2 in
@@ -1057,11 +1097,11 @@ let rec simple_constraint_check cs =
(match n1'.nexp,n2'.nexp with
| Nconst i1, Nconst i2 ->
if i1<=i2
- then simple_constraint_check cs
+ then check cs
else eq_error (get_c_loc co) ("Type constraint mismatch: constraint arising from here requires "
^ string_of_int i1 ^ " to be less than or equal to " ^ string_of_int i2)
- | _,_ -> LtEq(co,n1',n2')::(simple_constraint_check cs))
- | x::cs -> x::(simple_constraint_check cs)
+ | _,_ -> LtEq(co,n1',n2')::(check cs))
+ | x::cs -> x::(check cs)
let rec resolve_in_constraints cs = cs
@@ -1073,7 +1113,7 @@ let resolve_constraints cs =
else
let rec fix len cs =
(* let _ = Printf.printf "Calling simple constraint check, fix check point is %i\n" len in *)
- let cs' = simple_constraint_check cs in
+ let cs' = simple_constraint_check (in_constraint_env cs) cs in
if len > (List.length cs') then fix (List.length cs') cs'
else cs' in
let complex_constraints = fix (List.length cs) cs in