diff options
| -rw-r--r-- | src/pretty_print.ml | 2 | ||||
| -rw-r--r-- | src/type_internal.ml | 110 |
2 files changed, 76 insertions, 36 deletions
diff --git a/src/pretty_print.ml b/src/pretty_print.ml index 5f86b2ed..cfefd4d0 100644 --- a/src/pretty_print.ml +++ b/src/pretty_print.ml @@ -544,7 +544,7 @@ let rec pp_format_t t = | Ttup(tups) -> "(T_tup [" ^ (list_format "; " pp_format_t tups) ^ "])" | Tapp(i,args) -> "(T_app \"" ^ i ^ "\" (T_args [" ^ list_format "; " pp_format_targ args ^ "]))" | Tabbrev(ti,ta) -> "(T_abbrev " ^ (pp_format_t ti) ^ " " ^ (pp_format_t ta) ^ ")" - | Tuvar(_) -> assert false (*"(T_var (Kid_aux (Var \"fresh_v\") Unknown))"*) + | Tuvar(_) -> "(T_var (Kid_aux (Var \"fresh_v\") Unknown))" and pp_format_targ = function | TA_typ t -> "(T_arg_typ " ^ pp_format_t t ^ ")" | TA_nexp n -> "(T_arg_nexp " ^ pp_format_n n ^ ")" diff --git a/src/type_internal.ml b/src/type_internal.ml index 0a499260..1aa31473 100644 --- a/src/type_internal.ml +++ b/src/type_internal.ml @@ -507,7 +507,7 @@ let mk_range n1 n2 = {t=Tapp("range",[TA_nexp {nexp=n1};TA_nexp {nexp=n2}])} let mk_vector typ order start size = {t=Tapp("vector",[TA_nexp {nexp=start}; TA_nexp {nexp=size}; TA_ord {order}; TA_typ typ])} let mk_bitwise_op name symb arity = (* XXX should be Ovar "o" but will not work currently *) - let ovar = Oinc in + let ovar = (*Oinc*) Ovar "o" in let vec_typ = mk_vector bit_t ovar (Nconst 0) (Nvar "n") in let args = Array.to_list (Array.make arity vec_typ) in let arg = if ((List.length args) = 1) then List.hd args else {t= Ttup args} in @@ -698,7 +698,7 @@ let rec t_to_typ t = | Ttup ts -> Typ_aux(Typ_tup(List.map t_to_typ ts),Parse_ast.Unknown) | Tapp(i,args) -> Typ_aux(Typ_app(Id_aux((Id i), Parse_ast.Unknown),List.map targ_to_typ_arg args),Parse_ast.Unknown) | Tabbrev(t,_) -> t_to_typ t - | Tuvar _ -> assert false + | Tuvar _ -> Typ_aux(Typ_var (Kid_aux((Var "fresh"),Parse_ast.Unknown)),Parse_ast.Unknown) and targ_to_typ_arg targ = Typ_arg_aux( (match targ with @@ -728,7 +728,7 @@ and o_to_order o = | Ovar i -> Ord_var (Kid_aux((Var i),Parse_ast.Unknown)) | Oinc -> Ord_inc | Odec -> Ord_dec - | Ouvar _ -> assert false), Parse_ast.Unknown) + | Ouvar _ -> Ord_var (Kid_aux((Var "fresh"),Parse_ast.Unknown))), Parse_ast.Unknown) let rec get_abbrev d_env t = @@ -860,7 +860,7 @@ let rec type_consistent_internal co d_env t1 cs1 t2 cs2 = let t2' = {t=Tapp("range",[TA_nexp b2;TA_nexp r2])} in equate_t t2 t2'; (t2,csp@[GtEq(co,b,b2);LtEq(co,r,r2)]) (*This and above should maybe be In constraints when co is patt and tuvar is an in*) - | t,Tuvar _ -> equate_t t2 t1; (t2,csp) + | t,Tuvar _ -> equate_t t2 t1; (t1,csp) | _,_ -> eq_error l ("Type mismatch found " ^ (t_to_string t1) ^ " but expected a " ^ (t_to_string t2)) and type_arg_eq co d_env ta1 ta2 = @@ -912,8 +912,8 @@ let rec type_coerce_internal co d_env t1 cs1 e t2 cs2 = [TA_nexp b2;TA_nexp r2;TA_ord o2;TA_typ t2i] -> (match o1.order,o2.order with | Oinc,Oinc | Odec,Odec -> () - | Oinc,Ouvar _ | Odec,Ouvar _ -> o2.order <- o1.order - | Ouvar _,Oinc | Ouvar _, Oinc -> o1.order <- o2.order + | Oinc,Ouvar _ | Odec,Ouvar _ -> equate_o o2 o1; + | Ouvar _,Oinc | Ouvar _, Odec -> equate_o o1 o2; | _,_ -> equate_o o1 o2); let cs = csp@[Eq(co,r1,r2)] in let t',cs' = type_consistent co d_env t1i t2i in @@ -953,8 +953,9 @@ let rec type_coerce_internal co d_env t1 cs1 e t2 cs2 = | "register",_ -> (match args1 with | [TA_typ t] -> + (*TODO Should this be an internal cast? Probably, make sure it doesn't interfere with the other internal cast and get removed *) (*let _ = Printf.printf "Adding cast to remove register read\n" in*) - let new_e = E_aux(E_cast(t_to_typ t,e),(l,Some(([],t),External None,[],(add_effect (BE_aux(BE_rreg, l)) pure_e)))) in + let new_e = E_aux(E_cast(t_to_typ unit_t,e),(l,Some(([],t),External None,[],(add_effect (BE_aux(BE_rreg, l)) pure_e)))) in type_coerce co d_env t new_e t2 | _ -> raise (Reporting_basic.err_unreachable l "register is not properly kinded")) | _,_ -> @@ -1012,33 +1013,72 @@ let rec type_coerce_internal co d_env t1 cs1 e t2 cs2 = | None -> eq_error l ("Type mismatch: " ^ (t_to_string t1) ^ " , " ^ (t_to_string t2))) | _,_ -> let t',cs = type_consistent co d_env t1 t2 in (t',cs,e) -and type_coerce co d_env t1 e t2 = type_coerce_internal co d_env t1 [] e t2 [] +and type_coerce co d_env t1 e t2 = type_coerce_internal co d_env t1 [] e t2 [];; -let rec simple_constraint_check cs = +let rec in_constraint_env = function + | [] -> [] + | InS(co,nexp,vals)::cs -> + (nexp,(List.map (fun c -> {nexp = Nconst c}) vals))::(in_constraint_env cs) + | In(co,i,vals)::cs -> + ({nexp = Nvar i},(List.map (fun c -> {nexp = Nconst c}) vals))::(in_constraint_env cs) + | _::cs -> in_constraint_env cs + +let rec contains_var nu n = + match n.nexp with + | Nvar _ | Nuvar _ -> nexp_eq_check nu n + | Nconst _ -> false + | Nadd(n1,n2) | Nmult(n1,n2) -> contains_var nu n1 || contains_var nu n2 + | Nneg n | N2n n -> contains_var nu n + +let rec contains_in_vars in_env n = + match in_env with + | [] -> None + | (ne,vals)::in_env -> + (match contains_in_vars in_env n with + | None -> if contains_var ne n then Some [ne,vals] else None + | Some(e_env) -> if contains_var ne n then Some((ne,vals)::e_env) else Some(e_env)) + +let rec subst_nuvars nu nc n = + match n.nexp with + | Nconst _ | Nvar _ -> n + | Nuvar _ -> if nexp_eq_check nu n then nc else n + | Nmult(n1,n2) -> {nexp=Nmult(subst_nuvars nu nc n1,subst_nuvars nu nc n2)} + | Nadd(n1,n2) -> {nexp=Nadd(subst_nuvars nu nc n1,subst_nuvars nu nc n2)} + | Nneg n -> {nexp= Nneg (subst_nuvars nu nc n)} + | N2n n -> {nexp = N2n (subst_nuvars nu nc n)} + + +let rec simple_constraint_check in_env cs = + let check = simple_constraint_check in_env in (* let _ = Printf.printf "simple_constraint_check\n" in *) match cs with | [] -> [] | Eq(co,n1,n2)::cs -> - (*let _ = Printf.printf "eq check, about to eval_nexp of %s, %s\n" (n_to_string n1) (n_to_string n2) in *) - let n1',n2' = eval_nexp n1,eval_nexp n2 in - (*let _ = Printf.printf "finished evaled to %s, %s\n" (n_to_string n1') (n_to_string n2') in *) - (match n1'.nexp,n2'.nexp with - | Nconst i1, Nconst i2 -> - if i1==i2 - then simple_constraint_check cs - else eq_error (get_c_loc co) ("Type constraint mismatch: constraint arising from here requires " - ^ string_of_int i1 ^ " to equal " ^ string_of_int i2) - | Nconst i, Nuvar u -> - if u.nin - then Eq(co,n1',n2')::(simple_constraint_check cs) - else begin equate_n n2' n1'; (simple_constraint_check cs) end - | Nuvar u, Nconst i -> - if u.nin - then Eq(co,n1',n2')::(simple_constraint_check cs) - else begin equate_n n1' n2'; (simple_constraint_check cs) end - | Nuvar u1, Nuvar u2 -> - resolve_nsubst n1; resolve_nsubst n2; equate_n n1' n2'; (simple_constraint_check cs) - | _,_ -> Eq(co,n1',n2')::(simple_constraint_check cs)) + let check_eq ok_to_set n1 n2 = + (*let _ = Printf.printf "eq check, about to eval_nexp of %s, %s\n" (n_to_string n1) (n_to_string n2) in *) + let n1',n2' = eval_nexp n1,eval_nexp n2 in + (*let _ = Printf.printf "finished evaled to %s, %s\n" (n_to_string n1') (n_to_string n2') in *) + (match n1'.nexp,n2'.nexp with + | Nconst i1, Nconst i2 -> + if i1==i2 then None + else eq_error (get_c_loc co) ("Type constraint mismatch: constraint arising from here requires " + ^ string_of_int i1 ^ " to equal " ^ string_of_int i2) + | Nconst i, Nuvar u -> + if not(u.nin) && ok_to_set + then begin equate_n n2' n1'; None end + else Some (Eq(co,n1',n2')) + | Nuvar u, Nconst i -> + if not(u.nin) && ok_to_set + then begin equate_n n1' n2'; None end + else Some (Eq(co,n1',n2')) + | Nuvar u1, Nuvar u2 -> + if ok_to_set + then begin resolve_nsubst n1; resolve_nsubst n2; equate_n n1' n2'; None end + else Some(Eq(co,n1',n2')) + | _,_ -> Some(Eq(co,n1',n2'))) in + (match check_eq true n1 n2 with + | None -> (check cs) + | Some(c) -> c::(check cs)) | GtEq(co,n1,n2)::cs -> (* let _ = Printf.printf ">= check, about to eval_nexp of %s, %s\n" (n_to_string n1) (n_to_string n2) in *) let n1',n2' = eval_nexp n1,eval_nexp n2 in @@ -1046,10 +1086,10 @@ let rec simple_constraint_check cs = (match n1'.nexp,n2'.nexp with | Nconst i1, Nconst i2 -> if i1>=i2 - then simple_constraint_check cs + then check cs else eq_error (get_c_loc co) ("Type constraint mismatch: constraint arising from here requires " ^ string_of_int i1 ^ " to be greater than or equal to " ^ string_of_int i2) - | _,_ -> GtEq(co,n1',n2')::(simple_constraint_check cs)) + | _,_ -> GtEq(co,n1',n2')::(check cs)) | LtEq(co,n1,n2)::cs -> (* let _ = Printf.printf "<= check, about to eval_nexp of %s, %s\n" (n_to_string n1) (n_to_string n2) in *) let n1',n2' = eval_nexp n1,eval_nexp n2 in @@ -1057,11 +1097,11 @@ let rec simple_constraint_check cs = (match n1'.nexp,n2'.nexp with | Nconst i1, Nconst i2 -> if i1<=i2 - then simple_constraint_check cs + then check cs else eq_error (get_c_loc co) ("Type constraint mismatch: constraint arising from here requires " ^ string_of_int i1 ^ " to be less than or equal to " ^ string_of_int i2) - | _,_ -> LtEq(co,n1',n2')::(simple_constraint_check cs)) - | x::cs -> x::(simple_constraint_check cs) + | _,_ -> LtEq(co,n1',n2')::(check cs)) + | x::cs -> x::(check cs) let rec resolve_in_constraints cs = cs @@ -1073,7 +1113,7 @@ let resolve_constraints cs = else let rec fix len cs = (* let _ = Printf.printf "Calling simple constraint check, fix check point is %i\n" len in *) - let cs' = simple_constraint_check cs in + let cs' = simple_constraint_check (in_constraint_env cs) cs in if len > (List.length cs') then fix (List.length cs') cs' else cs' in let complex_constraints = fix (List.length cs) cs in |
