diff options
| author | Kathy Gray | 2015-05-01 19:22:32 +0100 |
|---|---|---|
| committer | Kathy Gray | 2015-05-01 19:22:32 +0100 |
| commit | cd5d04d2b5a44ed6c77cde0f4d69282938077465 (patch) | |
| tree | f20c35e5b7514e71afb6266e4a2a06c54eeb95a3 /src | |
| parent | cc3a823ed1c31d17609945219ad62296da1d76af (diff) | |
Fix pattern match bug with enumerated values
Diffstat (limited to 'src')
| -rw-r--r-- | src/lem_interp/interp.lem | 107 | ||||
| -rw-r--r-- | src/rewriter.ml | 13 | ||||
| -rw-r--r-- | src/type_check.ml | 7 |
3 files changed, 74 insertions, 53 deletions
diff --git a/src/lem_interp/interp.lem b/src/lem_interp/interp.lem index ae773ad1..a3bd294a 100644 --- a/src/lem_interp/interp.lem +++ b/src/lem_interp/interp.lem @@ -881,18 +881,22 @@ let env_to_let mode (LEnv _ env) (E_aux e annot) taint_env = end (* match_pattern returns a tuple of (pattern_matches? , pattern_passed_due_to_unknown?, env_of_pattern *) -val match_pattern : pat tannot -> value -> bool * bool * lenv -let rec match_pattern (P_aux p _) value_whole = +val match_pattern : top_level -> pat tannot -> value -> bool * bool * lenv +let rec match_pattern t_level (P_aux p (_, annot)) value_whole = + let (Env fdefs instrs default_dir lets regs ctors subregs aliases) = t_level in + let (t,tag,cs) = match annot with + | Just(t,tag,cs,e) -> (t,tag,cs) + | Nothing -> (T_var "fresh",Tag_empty,[]) end in let value = detaint value_whole in let taint_pat v = binary_taint (fun v _ -> v) v value_whole in match p with | P_lit(lit) -> if is_lit_vector lit then - let (V_vector n inc bits) = litV_to_vec lit IInc in + let (V_vector n inc bits) = litV_to_vec lit default_dir in match value with | V_lit litv -> if is_lit_vector litv then - let (V_vector n' inc' bits') = litV_to_vec litv IInc in + let (V_vector n' inc' bits') = litV_to_vec litv default_dir in if n=n' && inc = inc' then (foldr2 (fun l r rest -> (l = r) && rest) true bits bits',false, eenv) else (false,false,eenv) else (false,false,eenv) @@ -910,11 +914,11 @@ let rec match_pattern (P_aux p _) value_whole = end | P_wild -> (true,false,eenv) | P_as pat id -> - let (matched_p,used_unknown,bounds) = match_pattern pat value in + let (matched_p,used_unknown,bounds) = match_pattern t_level pat value in if matched_p then (matched_p,used_unknown,(add_to_env (id,value_whole) bounds)) else (false,false,eenv) - | P_typ typ pat -> match_pattern pat value_whole + | P_typ typ pat -> match_pattern t_level pat value_whole | P_id id -> (true, false, (LEnv 0 (Map.fromList [((get_id id),value_whole)]))) | P_app (Id_aux id _) pats -> match value with @@ -923,7 +927,7 @@ let rec match_pattern (P_aux p _) value_whole = then foldr2 (fun pat value (matched_p,used_unknown,bounds) -> if matched_p then - let (matched_p,used_unknown',new_bounds) = match_pattern pat (taint_pat value) in + let (matched_p,used_unknown',new_bounds) = match_pattern t_level pat (taint_pat value) in (matched_p, (used_unknown || used_unknown'), (union_env new_bounds bounds)) else (false,false,eenv)) (true,false,eenv) pats vals else (false,false,eenv) @@ -932,7 +936,7 @@ let rec match_pattern (P_aux p _) value_whole = then foldr2 (fun pat value (matched_p,used_unknown,bounds) -> if matched_p then - let (matched_p,used_unknown',new_bounds) = match_pattern pat (taint value r) in + let (matched_p,used_unknown',new_bounds) = match_pattern t_level pat (taint value r) in (matched_p, (used_unknown || used_unknown'), (union_env new_bounds bounds)) else (false,false,eenv)) (true,false,eenv) pats vals else (false,false,eenv) @@ -941,9 +945,18 @@ let rec match_pattern (P_aux p _) value_whole = then (match (pats,detaint v) with | ([],(V_lit (L_aux L_unit _))) -> (true,true,eenv) | ([P_aux (P_lit (L_aux L_unit _)) _],(V_lit (L_aux L_unit _))) -> (true,true,eenv) - | ([p],_) -> match_pattern p v + | ([p],_) -> match_pattern t_level p v | _ -> (false,false,eenv) end) else (false,false,eenv) + | V_lit (L_aux (L_num i) _) -> + match tag with + | Tag_enum -> + match Map.lookup (get_id (Id_aux id Unknown)) lets with + | Just(V_ctor _ t (C_Enum j) _) -> + if i = (integerFromNat j) then (true,false,eenv) + else (false,false,eenv) + | _ -> (false,false,eenv) end + | _ -> (false,false,eenv) end | V_unknown -> (true,true,eenv) | _ -> (false,false,eenv) end | P_record fpats _ -> @@ -955,7 +968,7 @@ let rec match_pattern (P_aux p _) value_whole = if matched_p then let (matched_p,used_unknown',new_bounds) = match in_env fvals_env (get_id id) with | Nothing -> (false,false,eenv) - | Just v -> match_pattern pat v end in + | Just v -> match_pattern t_level pat v end in (matched_p, (used_unknown || used_unknown'), (union_env new_bounds bounds)) else (false,false,eenv)) (true,false,eenv) fpats | V_unknown -> (true,true,eenv) @@ -968,7 +981,7 @@ let rec match_pattern (P_aux p _) value_whole = then foldr2 (fun pat value (matched_p,used_unknown,bounds) -> if matched_p then - let (matched_p,used_unknown',new_bounds) = match_pattern pat (taint_pat value) in + let (matched_p,used_unknown',new_bounds) = match_pattern t_level pat (taint_pat value) in (matched_p, (used_unknown||used_unknown'), (union_env new_bounds bounds)) else (false,false,eenv)) (true,false,eenv) (if is_inc(dir) then pats else List.reverse pats) vals @@ -980,7 +993,7 @@ let rec match_pattern (P_aux p _) value_whole = (fun pat (i,matched_p,used_unknown,bounds) -> if matched_p then let (matched_p,used_unknown',new_bounds) = - match_pattern pat (match List.lookup i vals with + match_pattern t_level pat (match List.lookup i vals with | Nothing -> d | Just v -> (taint_pat v) end) in ((if is_inc(dir) then i+1 else i-1), @@ -1000,7 +1013,7 @@ let rec match_pattern (P_aux p _) value_whole = let i = natFromInteger i in if matched_p && i < v_len then let (matched_p,used_unknown',new_bounds) = - match_pattern pat (taint_pat (list_nth vals (if is_inc(dir) then i+n else i-n))) in + match_pattern t_level pat (taint_pat (list_nth vals (if is_inc(dir) then i+n else i-n))) in (matched_p,used_unknown||used_unknown',(union_env new_bounds bounds)) else (false,false,eenv)) (true,false,eenv) ipats @@ -1010,7 +1023,7 @@ let rec match_pattern (P_aux p _) value_whole = let i = natFromInteger i in if matched_p && i < m then let (matched_p,used_unknown',new_bounds) = - match_pattern pat (match List.lookup i vals with | Nothing -> d | Just v -> (taint_pat v) end) in + match_pattern t_level pat (match List.lookup i vals with | Nothing -> d | Just v -> (taint_pat v) end) in (matched_p,used_unknown||used_unknown',(union_env new_bounds bounds)) else (false,false,eenv)) (true,false,eenv) ipats @@ -1023,7 +1036,7 @@ let rec match_pattern (P_aux p _) value_whole = let (matched_p,used_unknown,bounds,remaining_vals) = List.foldl (fun (matched_p,used_unknown,bounds,r_vals) (P_aux pat (l,Just(t,_,_,_))) -> - let (matched_p,used_unknown',bounds',matcheds,r_vals) = vec_concat_match_plev pat r_vals inc l t in + let (matched_p,used_unknown',bounds',matcheds,r_vals) = vec_concat_match_plev t_level pat r_vals inc l t in (matched_p,(used_unknown || used_unknown'),(union_env bounds' bounds),r_vals)) (true,false,eenv,vals) pats in if matched_p && ([] = remaining_vals) then (matched_p,used_unknown,bounds) else (false,false,eenv) | V_unknown -> (true,true,eenv) @@ -1035,7 +1048,7 @@ let rec match_pattern (P_aux p _) value_whole = if ((List.length pats)= (List.length vals)) then foldr2 (fun pat v (matched_p,used_unknown,bounds) -> if matched_p then - let (matched_p,used_unknown',new_bounds) = match_pattern pat (taint_pat v) in + let (matched_p,used_unknown',new_bounds) = match_pattern t_level pat (taint_pat v) in (matched_p,used_unknown ||used_unknown', (union_env new_bounds bounds)) else (false,false,eenv)) (true,false,eenv) pats vals @@ -1049,7 +1062,7 @@ let rec match_pattern (P_aux p _) value_whole = if ((List.length pats)= (List.length vals)) then foldr2 (fun pat v (matched_p,used_unknown,bounds) -> if matched_p then - let (matched_p,used_unknown',new_bounds) = match_pattern pat (taint_pat v) in + let (matched_p,used_unknown',new_bounds) = match_pattern t_level pat (taint_pat v) in (matched_p,used_unknown|| used_unknown', (union_env new_bounds bounds)) else (false,false,eenv)) (true,false,eenv) pats vals @@ -1058,19 +1071,19 @@ let rec match_pattern (P_aux p _) value_whole = | _ -> (false,false,eenv) end end -and vec_concat_match_plev pat r_vals dir l t = +and vec_concat_match_plev t_level pat r_vals dir l t = match pat with | P_lit (L_aux (L_bin bin_string) l') -> let bin_chars = toCharList bin_string in let binpats = List.map (fun b -> P_aux (match b with | #'0' -> P_lit (L_aux L_zero l') | #'1' -> P_lit (L_aux L_one l')end) (l',Nothing)) bin_chars in - vec_concat_match binpats r_vals - | P_vector pats -> vec_concat_match pats r_vals + vec_concat_match t_level binpats r_vals + | P_vector pats -> vec_concat_match t_level pats r_vals | P_id id -> (match t with | T_app "vector" (T_args [T_arg_nexp _;T_arg_nexp (Ne_const i);_;_]) -> let wilds = List.genlist (fun _ -> P_aux P_wild (l,Nothing)) (natFromInteger i) in - let (matched_p,used_unknown,bounds,matcheds,r_vals) = vec_concat_match wilds r_vals in + let (matched_p,used_unknown,bounds,matcheds,r_vals) = vec_concat_match t_level wilds r_vals in if matched_p then (matched_p, used_unknown, (add_to_env (id,(V_vector (if is_inc(dir) then 0 else (List.length matcheds)) dir matcheds)) @@ -1083,59 +1096,59 @@ and vec_concat_match_plev pat r_vals dir l t = | P_wild -> (match t with | T_app "vector" (T_args [T_arg_nexp _;T_arg_nexp (Ne_const i);_;_]) -> let wilds = List.genlist (fun _ -> P_aux P_wild (l,Nothing)) (natFromInteger i) in - vec_concat_match wilds r_vals + vec_concat_match t_level wilds r_vals | T_app "vector" (T_args [T_arg_nexp _;T_arg_nexp nc;_;_]) -> (false,false,eenv,[],[]) (*TODO see if can have some constraint bounds here*) | _ -> (false,false,eenv,[],[]) end) | P_as (P_aux pat (l',Just(t,_,_,_))) id -> - let (matched_p, used_unknown, bounds,matcheds,r_vals) = vec_concat_match_plev pat r_vals dir l t in + let (matched_p, used_unknown, bounds,matcheds,r_vals) = vec_concat_match_plev t_level pat r_vals dir l t in if matched_p then (matched_p, used_unknown, (add_to_env (id,V_vector (if is_inc(dir) then 0 else (List.length matcheds)) dir matcheds) bounds), matcheds,r_vals) else (false,false,eenv,[],[]) - | P_typ _ (P_aux p (l',Just(t',_,_,_))) -> vec_concat_match_plev p r_vals dir l t + | P_typ _ (P_aux p (l',Just(t',_,_,_))) -> vec_concat_match_plev t_level p r_vals dir l t | _ -> (false,false,eenv,[],[]) end (*TODO Need to support indexed here, skipping intermediate numbers but consumming r_vals, and as *) -and vec_concat_match pats r_vals = +and vec_concat_match t_level pats r_vals = match pats with | [] -> (true,false,eenv,[],r_vals) | pat::pats -> match r_vals with | [] -> (false,false,eenv,[],[]) | r::r_vals -> - let (matched_p,used_unknown,new_bounds) = match_pattern pat r in + let (matched_p,used_unknown,new_bounds) = match_pattern t_level pat r in if matched_p then - let (matched_p,used_unknown',bounds,matcheds,r_vals) = vec_concat_match pats r_vals in + let (matched_p,used_unknown',bounds,matcheds,r_vals) = vec_concat_match t_level pats r_vals in (matched_p, used_unknown||used_unknown',(union_env new_bounds bounds),r :: matcheds,r_vals) else (false,false,eenv,[],[]) end end (* Returns all matches using Unknown until either there are no more matches or a pattern matches with no Unknowns used *) -val find_funcl : list (funcl tannot) -> value -> list (lenv * bool * (exp tannot)) -let rec find_funcl funcls value = +val find_funcl : top_level -> list (funcl tannot) -> value -> list (lenv * bool * (exp tannot)) +let rec find_funcl t_level funcls value = match funcls with | [] -> [] | (FCL_aux (FCL_Funcl id pat exp) _)::funcls -> - let (is_matching,used_unknown,env) = match_pattern pat value in + let (is_matching,used_unknown,env) = match_pattern t_level pat value in if (is_matching && used_unknown) - then (env,used_unknown,exp)::(find_funcl funcls value) + then (env,used_unknown,exp)::(find_funcl t_level funcls value) else if is_matching then [(env,used_unknown,exp)] - else find_funcl funcls value + else find_funcl t_level funcls value end (*see above comment*) -val find_case : list (pexp tannot) -> value -> list (lenv * bool * (exp tannot)) -let rec find_case pexps value = +val find_case : top_level -> list (pexp tannot) -> value -> list (lenv * bool * (exp tannot)) +let rec find_case t_level pexps value = match pexps with | [] -> [] | (Pat_aux (Pat_exp p e) _)::pexps -> - let (is_matching,used_unknown,env) = match_pattern p value in + let (is_matching,used_unknown,env) = match_pattern t_level p value in if (is_matching && used_unknown) - then (env,used_unknown,e)::find_case pexps value + then (env,used_unknown,e)::find_case t_level pexps value else if is_matching then [(env,used_unknown,e)] - else find_case pexps value + else find_case t_level pexps value end val interp_main : interp_mode -> top_level -> lenv -> lmem -> (exp tannot) -> (outcome * lmem * lenv) @@ -1349,7 +1362,7 @@ and interp_main mode t_level l_env l_mem (E_aux exp (l,annot)) = resolve_outcome (interp_main mode t_level l_env l_mem exp) (fun v lm le -> - match find_case pats v with + match find_case t_level pats v with | [] -> (Error l ("No matching patterns in case for value " ^ (string_of_value v)),lm,le) | [(env,_,exp)] -> if mode.eager_eval @@ -1740,7 +1753,7 @@ and interp_main mode t_level l_env l_mem (E_aux exp (l,annot)) = | Tag_global -> (match Map.lookup name fdefs with | Just(funcls) -> - (match find_funcl funcls v with + (match find_funcl t_level funcls v with | [] -> (Error l ("No matching pattern for function " ^ name ^ " on value " ^ (string_of_value v)),l_mem,l_env) @@ -1765,7 +1778,7 @@ and interp_main mode t_level l_env l_mem (E_aux exp (l,annot)) = | Tag_empty -> (match Map.lookup name fdefs with | Just(funcls) -> - (match find_funcl funcls v with + (match find_funcl t_level funcls v with | [] -> (Error l (String.stringAppend "No matching pattern for function " name ),l_mem,l_env) | [(env,used_unknown,exp)] -> @@ -1784,7 +1797,7 @@ and interp_main mode t_level l_env l_mem (E_aux exp (l,annot)) = | Tag_spec -> (match Map.lookup name fdefs with | Just(funcls) -> - (match find_funcl funcls v with + (match find_funcl t_level funcls v with | [] -> (Error l (String.stringAppend "No matching pattern for function " name ),l_mem,l_env) | [(env,used_unknown,exp)] -> @@ -1851,7 +1864,7 @@ and interp_main mode t_level l_env l_mem (E_aux exp (l,annot)) = (match Map.lookup name fdefs with | Nothing -> (Error l ("Internal error, no function def for " ^ name),lm,le) | Just (funcls) -> - (match find_funcl funcls (V_tuple [lv;rv]) with + (match find_funcl t_level funcls (V_tuple [lv;rv]) with | [] -> (Error l ("No matching pattern for function " ^ name),lm,l_env) | [(env,used_unknown,exp)] -> resolve_outcome @@ -1869,7 +1882,7 @@ and interp_main mode t_level l_env l_mem (E_aux exp (l,annot)) = (match Map.lookup name fdefs with | Nothing -> (Error l ("Internal error, no function def for " ^ name),lm,le) | Just (funcls) -> - (match find_funcl funcls (V_tuple [lv;rv]) with + (match find_funcl t_level funcls (V_tuple [lv;rv]) with | [] -> (Error l ("No matching pattern for function " ^ name),lm,l_env) | [(env,used_unknown,exp)] -> resolve_outcome @@ -1885,7 +1898,7 @@ and interp_main mode t_level l_env l_mem (E_aux exp (l,annot)) = (match Map.lookup name fdefs with | Nothing -> (Error l ("No function definition found for " ^ name),lm,le) | Just (funcls) -> - (match find_funcl funcls (V_tuple [lv;rv]) with + (match find_funcl t_level funcls (V_tuple [lv;rv]) with | [] -> (Error l ("No matching pattern for function " ^ name),lm,l_env) | [(env,used_unknown,exp)] -> resolve_outcome @@ -2112,7 +2125,7 @@ and create_write_message_or_update mode t_level value l_env l_mem is_top_level ( | V_tuple vs -> V_tuple (vs ++ [value]) | V_lit (L_aux L_unit _) -> value | v -> V_tuple [v;value] end in - (match find_funcl funcls new_vals with + (match find_funcl t_level funcls new_vals with | [] -> ((Error l ("No matching pattern for function " ^ name ^ " on value " ^ (string_of_value new_vals)),l_mem,l_env),Nothing) | [(env,used_unknown,exp)] -> @@ -2386,7 +2399,7 @@ and interp_letbind mode t_level l_env l_mem (LB_aux lbind (l,annot)) = | LB_val_explicit t pat exp -> match (interp_main mode t_level l_env l_mem exp) with | (Value v,lm,le) -> - (match match_pattern pat v with + (match match_pattern t_level pat v with | (true,used_unknown,env) -> ((Value (V_lit (L_aux L_unit l)), lm, (union_env env l_env)),Nothing) | _ -> ((Error l "Pattern in letbind did not match value",lm,le),Nothing) end) | (Action a s,lm,le) -> ((Action a s,lm,le),(Just (fun e ->(LB_aux (LB_val_explicit t pat e) (l,annot))))) @@ -2394,7 +2407,7 @@ and interp_letbind mode t_level l_env l_mem (LB_aux lbind (l,annot)) = | LB_val_implicit pat exp -> match (interp_main mode t_level l_env l_mem exp) with | (Value v,lm,le) -> - (match match_pattern pat v with + (match match_pattern t_level pat v with | (true,used_unknown,env) -> ((Value (V_lit (L_aux L_unit l)), lm, (union_env env l_env)),Nothing) | _ -> ((Error l "Pattern in letbind did not match value",lm,le),Nothing) end) | (Action a s,lm,le) -> ((Action a s,lm,le),(Just (fun e -> (LB_aux (LB_val_implicit pat e) (l,annot))))) diff --git a/src/rewriter.ml b/src/rewriter.ml index 79f92e04..4fad78e7 100644 --- a/src/rewriter.ml +++ b/src/rewriter.ml @@ -117,13 +117,14 @@ let rec rewrite_exp (E_aux (exp,(l,annot))) = (*TODO should pass d_env into here so that I can look at the abbreviations if there are any here*) | Tapp("vector",[TA_nexp n1;TA_nexp nw1;TA_ord o1;_]), Tapp("vector",[TA_nexp n2;TA_nexp nw2;TA_ord o2;_]) -> - else (match n1.nexp with + (match n1.nexp with | Nconst i1 -> if nexp_eq n1 n2 then new_exp else rewrap (E_cast (t_to_typ t,new_exp)) - | Nadd _ | Nsub -> (match o1.order with - | O_inc -> new_exp - | O_dec -> if nexp_one_more_than nw1 n1 - then rewrap (E_cast (Typ_var (Kid_aux (Var "length") Unknown), new_exp)) - else new_exp) + | Nadd _ | Nsub _ -> (match o1.order with + | Oinc -> new_exp + | Odec -> + if nexp_one_more_than nw1 n1 + then rewrap (E_cast (Typ_aux (Typ_var (Kid_aux((Var "length"), Unknown)), Unknown), new_exp)) + else new_exp) | _ -> new_exp) | _ -> new_exp)) | E_internal_exp (l,impl) -> diff --git a/src/type_check.ml b/src/type_check.ml index 36cd4fff..8523ce85 100644 --- a/src/type_check.ml +++ b/src/type_check.ml @@ -265,6 +265,13 @@ let rec check_pattern envs emp_tag expect_t (P_aux(p,(l,annot))) : ((tannot pat) then typ_error l ("Constructor " ^ i ^ " expects arguments of type " ^ (t_to_string t) ^ ", found none") else default | _ -> raise (Reporting_basic.err_unreachable l "Constructor tannot does not have function type")) + | Some(Base((params,t),Enum,cs,ef,bounds)) -> + let t,cs,ef,_ = subst params false t cs ef in + if conforms_to_t d_env false false t expect_t + then + let tp,cp = type_consistent (Expr l) d_env Guarantee false t expect_t in + (P_aux(P_app(id,[]),(l,tag_annot t Enum)),Envmap.empty,cs@cp,bounds,tp) + else default | _ -> default) | P_app(id,pats) -> let i = id_to_string id in |
