summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorKathy Gray2015-05-01 19:22:32 +0100
committerKathy Gray2015-05-01 19:22:32 +0100
commitcd5d04d2b5a44ed6c77cde0f4d69282938077465 (patch)
treef20c35e5b7514e71afb6266e4a2a06c54eeb95a3 /src
parentcc3a823ed1c31d17609945219ad62296da1d76af (diff)
Fix pattern match bug with enumerated values
Diffstat (limited to 'src')
-rw-r--r--src/lem_interp/interp.lem107
-rw-r--r--src/rewriter.ml13
-rw-r--r--src/type_check.ml7
3 files changed, 74 insertions, 53 deletions
diff --git a/src/lem_interp/interp.lem b/src/lem_interp/interp.lem
index ae773ad1..a3bd294a 100644
--- a/src/lem_interp/interp.lem
+++ b/src/lem_interp/interp.lem
@@ -881,18 +881,22 @@ let env_to_let mode (LEnv _ env) (E_aux e annot) taint_env =
end
(* match_pattern returns a tuple of (pattern_matches? , pattern_passed_due_to_unknown?, env_of_pattern *)
-val match_pattern : pat tannot -> value -> bool * bool * lenv
-let rec match_pattern (P_aux p _) value_whole =
+val match_pattern : top_level -> pat tannot -> value -> bool * bool * lenv
+let rec match_pattern t_level (P_aux p (_, annot)) value_whole =
+ let (Env fdefs instrs default_dir lets regs ctors subregs aliases) = t_level in
+ let (t,tag,cs) = match annot with
+ | Just(t,tag,cs,e) -> (t,tag,cs)
+ | Nothing -> (T_var "fresh",Tag_empty,[]) end in
let value = detaint value_whole in
let taint_pat v = binary_taint (fun v _ -> v) v value_whole in
match p with
| P_lit(lit) ->
if is_lit_vector lit then
- let (V_vector n inc bits) = litV_to_vec lit IInc in
+ let (V_vector n inc bits) = litV_to_vec lit default_dir in
match value with
| V_lit litv ->
if is_lit_vector litv then
- let (V_vector n' inc' bits') = litV_to_vec litv IInc in
+ let (V_vector n' inc' bits') = litV_to_vec litv default_dir in
if n=n' && inc = inc' then (foldr2 (fun l r rest -> (l = r) && rest) true bits bits',false, eenv)
else (false,false,eenv)
else (false,false,eenv)
@@ -910,11 +914,11 @@ let rec match_pattern (P_aux p _) value_whole =
end
| P_wild -> (true,false,eenv)
| P_as pat id ->
- let (matched_p,used_unknown,bounds) = match_pattern pat value in
+ let (matched_p,used_unknown,bounds) = match_pattern t_level pat value in
if matched_p then
(matched_p,used_unknown,(add_to_env (id,value_whole) bounds))
else (false,false,eenv)
- | P_typ typ pat -> match_pattern pat value_whole
+ | P_typ typ pat -> match_pattern t_level pat value_whole
| P_id id -> (true, false, (LEnv 0 (Map.fromList [((get_id id),value_whole)])))
| P_app (Id_aux id _) pats ->
match value with
@@ -923,7 +927,7 @@ let rec match_pattern (P_aux p _) value_whole =
then foldr2
(fun pat value (matched_p,used_unknown,bounds) ->
if matched_p then
- let (matched_p,used_unknown',new_bounds) = match_pattern pat (taint_pat value) in
+ let (matched_p,used_unknown',new_bounds) = match_pattern t_level pat (taint_pat value) in
(matched_p, (used_unknown || used_unknown'), (union_env new_bounds bounds))
else (false,false,eenv)) (true,false,eenv) pats vals
else (false,false,eenv)
@@ -932,7 +936,7 @@ let rec match_pattern (P_aux p _) value_whole =
then foldr2
(fun pat value (matched_p,used_unknown,bounds) ->
if matched_p then
- let (matched_p,used_unknown',new_bounds) = match_pattern pat (taint value r) in
+ let (matched_p,used_unknown',new_bounds) = match_pattern t_level pat (taint value r) in
(matched_p, (used_unknown || used_unknown'), (union_env new_bounds bounds))
else (false,false,eenv)) (true,false,eenv) pats vals
else (false,false,eenv)
@@ -941,9 +945,18 @@ let rec match_pattern (P_aux p _) value_whole =
then (match (pats,detaint v) with
| ([],(V_lit (L_aux L_unit _))) -> (true,true,eenv)
| ([P_aux (P_lit (L_aux L_unit _)) _],(V_lit (L_aux L_unit _))) -> (true,true,eenv)
- | ([p],_) -> match_pattern p v
+ | ([p],_) -> match_pattern t_level p v
| _ -> (false,false,eenv) end)
else (false,false,eenv)
+ | V_lit (L_aux (L_num i) _) ->
+ match tag with
+ | Tag_enum ->
+ match Map.lookup (get_id (Id_aux id Unknown)) lets with
+ | Just(V_ctor _ t (C_Enum j) _) ->
+ if i = (integerFromNat j) then (true,false,eenv)
+ else (false,false,eenv)
+ | _ -> (false,false,eenv) end
+ | _ -> (false,false,eenv) end
| V_unknown -> (true,true,eenv)
| _ -> (false,false,eenv) end
| P_record fpats _ ->
@@ -955,7 +968,7 @@ let rec match_pattern (P_aux p _) value_whole =
if matched_p then
let (matched_p,used_unknown',new_bounds) = match in_env fvals_env (get_id id) with
| Nothing -> (false,false,eenv)
- | Just v -> match_pattern pat v end in
+ | Just v -> match_pattern t_level pat v end in
(matched_p, (used_unknown || used_unknown'), (union_env new_bounds bounds))
else (false,false,eenv)) (true,false,eenv) fpats
| V_unknown -> (true,true,eenv)
@@ -968,7 +981,7 @@ let rec match_pattern (P_aux p _) value_whole =
then foldr2
(fun pat value (matched_p,used_unknown,bounds) ->
if matched_p then
- let (matched_p,used_unknown',new_bounds) = match_pattern pat (taint_pat value) in
+ let (matched_p,used_unknown',new_bounds) = match_pattern t_level pat (taint_pat value) in
(matched_p, (used_unknown||used_unknown'), (union_env new_bounds bounds))
else (false,false,eenv))
(true,false,eenv) (if is_inc(dir) then pats else List.reverse pats) vals
@@ -980,7 +993,7 @@ let rec match_pattern (P_aux p _) value_whole =
(fun pat (i,matched_p,used_unknown,bounds) ->
if matched_p
then let (matched_p,used_unknown',new_bounds) =
- match_pattern pat (match List.lookup i vals with
+ match_pattern t_level pat (match List.lookup i vals with
| Nothing -> d
| Just v -> (taint_pat v) end) in
((if is_inc(dir) then i+1 else i-1),
@@ -1000,7 +1013,7 @@ let rec match_pattern (P_aux p _) value_whole =
let i = natFromInteger i in
if matched_p && i < v_len then
let (matched_p,used_unknown',new_bounds) =
- match_pattern pat (taint_pat (list_nth vals (if is_inc(dir) then i+n else i-n))) in
+ match_pattern t_level pat (taint_pat (list_nth vals (if is_inc(dir) then i+n else i-n))) in
(matched_p,used_unknown||used_unknown',(union_env new_bounds bounds))
else (false,false,eenv))
(true,false,eenv) ipats
@@ -1010,7 +1023,7 @@ let rec match_pattern (P_aux p _) value_whole =
let i = natFromInteger i in
if matched_p && i < m then
let (matched_p,used_unknown',new_bounds) =
- match_pattern pat (match List.lookup i vals with | Nothing -> d | Just v -> (taint_pat v) end) in
+ match_pattern t_level pat (match List.lookup i vals with | Nothing -> d | Just v -> (taint_pat v) end) in
(matched_p,used_unknown||used_unknown',(union_env new_bounds bounds))
else (false,false,eenv))
(true,false,eenv) ipats
@@ -1023,7 +1036,7 @@ let rec match_pattern (P_aux p _) value_whole =
let (matched_p,used_unknown,bounds,remaining_vals) =
List.foldl
(fun (matched_p,used_unknown,bounds,r_vals) (P_aux pat (l,Just(t,_,_,_))) ->
- let (matched_p,used_unknown',bounds',matcheds,r_vals) = vec_concat_match_plev pat r_vals inc l t in
+ let (matched_p,used_unknown',bounds',matcheds,r_vals) = vec_concat_match_plev t_level pat r_vals inc l t in
(matched_p,(used_unknown || used_unknown'),(union_env bounds' bounds),r_vals)) (true,false,eenv,vals) pats in
if matched_p && ([] = remaining_vals) then (matched_p,used_unknown,bounds) else (false,false,eenv)
| V_unknown -> (true,true,eenv)
@@ -1035,7 +1048,7 @@ let rec match_pattern (P_aux p _) value_whole =
if ((List.length pats)= (List.length vals))
then foldr2
(fun pat v (matched_p,used_unknown,bounds) -> if matched_p then
- let (matched_p,used_unknown',new_bounds) = match_pattern pat (taint_pat v) in
+ let (matched_p,used_unknown',new_bounds) = match_pattern t_level pat (taint_pat v) in
(matched_p,used_unknown ||used_unknown', (union_env new_bounds bounds))
else (false,false,eenv))
(true,false,eenv) pats vals
@@ -1049,7 +1062,7 @@ let rec match_pattern (P_aux p _) value_whole =
if ((List.length pats)= (List.length vals))
then foldr2
(fun pat v (matched_p,used_unknown,bounds) -> if matched_p then
- let (matched_p,used_unknown',new_bounds) = match_pattern pat (taint_pat v) in
+ let (matched_p,used_unknown',new_bounds) = match_pattern t_level pat (taint_pat v) in
(matched_p,used_unknown|| used_unknown', (union_env new_bounds bounds))
else (false,false,eenv))
(true,false,eenv) pats vals
@@ -1058,19 +1071,19 @@ let rec match_pattern (P_aux p _) value_whole =
| _ -> (false,false,eenv) end
end
-and vec_concat_match_plev pat r_vals dir l t =
+and vec_concat_match_plev t_level pat r_vals dir l t =
match pat with
| P_lit (L_aux (L_bin bin_string) l') ->
let bin_chars = toCharList bin_string in
let binpats = List.map
(fun b -> P_aux (match b with
| #'0' -> P_lit (L_aux L_zero l') | #'1' -> P_lit (L_aux L_one l')end) (l',Nothing)) bin_chars in
- vec_concat_match binpats r_vals
- | P_vector pats -> vec_concat_match pats r_vals
+ vec_concat_match t_level binpats r_vals
+ | P_vector pats -> vec_concat_match t_level pats r_vals
| P_id id -> (match t with
| T_app "vector" (T_args [T_arg_nexp _;T_arg_nexp (Ne_const i);_;_]) ->
let wilds = List.genlist (fun _ -> P_aux P_wild (l,Nothing)) (natFromInteger i) in
- let (matched_p,used_unknown,bounds,matcheds,r_vals) = vec_concat_match wilds r_vals in
+ let (matched_p,used_unknown,bounds,matcheds,r_vals) = vec_concat_match t_level wilds r_vals in
if matched_p
then (matched_p, used_unknown,
(add_to_env (id,(V_vector (if is_inc(dir) then 0 else (List.length matcheds)) dir matcheds))
@@ -1083,59 +1096,59 @@ and vec_concat_match_plev pat r_vals dir l t =
| P_wild -> (match t with
| T_app "vector" (T_args [T_arg_nexp _;T_arg_nexp (Ne_const i);_;_]) ->
let wilds = List.genlist (fun _ -> P_aux P_wild (l,Nothing)) (natFromInteger i) in
- vec_concat_match wilds r_vals
+ vec_concat_match t_level wilds r_vals
| T_app "vector" (T_args [T_arg_nexp _;T_arg_nexp nc;_;_]) ->
(false,false,eenv,[],[]) (*TODO see if can have some constraint bounds here*)
| _ -> (false,false,eenv,[],[]) end)
| P_as (P_aux pat (l',Just(t,_,_,_))) id ->
- let (matched_p, used_unknown, bounds,matcheds,r_vals) = vec_concat_match_plev pat r_vals dir l t in
+ let (matched_p, used_unknown, bounds,matcheds,r_vals) = vec_concat_match_plev t_level pat r_vals dir l t in
if matched_p
then (matched_p, used_unknown,
(add_to_env (id,V_vector (if is_inc(dir) then 0 else (List.length matcheds)) dir matcheds) bounds),
matcheds,r_vals)
else (false,false,eenv,[],[])
- | P_typ _ (P_aux p (l',Just(t',_,_,_))) -> vec_concat_match_plev p r_vals dir l t
+ | P_typ _ (P_aux p (l',Just(t',_,_,_))) -> vec_concat_match_plev t_level p r_vals dir l t
| _ -> (false,false,eenv,[],[]) end
(*TODO Need to support indexed here, skipping intermediate numbers but consumming r_vals, and as *)
-and vec_concat_match pats r_vals =
+and vec_concat_match t_level pats r_vals =
match pats with
| [] -> (true,false,eenv,[],r_vals)
| pat::pats -> match r_vals with
| [] -> (false,false,eenv,[],[])
| r::r_vals ->
- let (matched_p,used_unknown,new_bounds) = match_pattern pat r in
+ let (matched_p,used_unknown,new_bounds) = match_pattern t_level pat r in
if matched_p then
- let (matched_p,used_unknown',bounds,matcheds,r_vals) = vec_concat_match pats r_vals in
+ let (matched_p,used_unknown',bounds,matcheds,r_vals) = vec_concat_match t_level pats r_vals in
(matched_p, used_unknown||used_unknown',(union_env new_bounds bounds),r :: matcheds,r_vals)
else (false,false,eenv,[],[]) end
end
(* Returns all matches using Unknown until either there are no more matches or a pattern matches with no Unknowns used *)
-val find_funcl : list (funcl tannot) -> value -> list (lenv * bool * (exp tannot))
-let rec find_funcl funcls value =
+val find_funcl : top_level -> list (funcl tannot) -> value -> list (lenv * bool * (exp tannot))
+let rec find_funcl t_level funcls value =
match funcls with
| [] -> []
| (FCL_aux (FCL_Funcl id pat exp) _)::funcls ->
- let (is_matching,used_unknown,env) = match_pattern pat value in
+ let (is_matching,used_unknown,env) = match_pattern t_level pat value in
if (is_matching && used_unknown)
- then (env,used_unknown,exp)::(find_funcl funcls value)
+ then (env,used_unknown,exp)::(find_funcl t_level funcls value)
else if is_matching then [(env,used_unknown,exp)]
- else find_funcl funcls value
+ else find_funcl t_level funcls value
end
(*see above comment*)
-val find_case : list (pexp tannot) -> value -> list (lenv * bool * (exp tannot))
-let rec find_case pexps value =
+val find_case : top_level -> list (pexp tannot) -> value -> list (lenv * bool * (exp tannot))
+let rec find_case t_level pexps value =
match pexps with
| [] -> []
| (Pat_aux (Pat_exp p e) _)::pexps ->
- let (is_matching,used_unknown,env) = match_pattern p value in
+ let (is_matching,used_unknown,env) = match_pattern t_level p value in
if (is_matching && used_unknown)
- then (env,used_unknown,e)::find_case pexps value
+ then (env,used_unknown,e)::find_case t_level pexps value
else if is_matching then [(env,used_unknown,e)]
- else find_case pexps value
+ else find_case t_level pexps value
end
val interp_main : interp_mode -> top_level -> lenv -> lmem -> (exp tannot) -> (outcome * lmem * lenv)
@@ -1349,7 +1362,7 @@ and interp_main mode t_level l_env l_mem (E_aux exp (l,annot)) =
resolve_outcome
(interp_main mode t_level l_env l_mem exp)
(fun v lm le ->
- match find_case pats v with
+ match find_case t_level pats v with
| [] -> (Error l ("No matching patterns in case for value " ^ (string_of_value v)),lm,le)
| [(env,_,exp)] ->
if mode.eager_eval
@@ -1740,7 +1753,7 @@ and interp_main mode t_level l_env l_mem (E_aux exp (l,annot)) =
| Tag_global ->
(match Map.lookup name fdefs with
| Just(funcls) ->
- (match find_funcl funcls v with
+ (match find_funcl t_level funcls v with
| [] ->
(Error l ("No matching pattern for function " ^ name ^
" on value " ^ (string_of_value v)),l_mem,l_env)
@@ -1765,7 +1778,7 @@ and interp_main mode t_level l_env l_mem (E_aux exp (l,annot)) =
| Tag_empty ->
(match Map.lookup name fdefs with
| Just(funcls) ->
- (match find_funcl funcls v with
+ (match find_funcl t_level funcls v with
| [] ->
(Error l (String.stringAppend "No matching pattern for function " name ),l_mem,l_env)
| [(env,used_unknown,exp)] ->
@@ -1784,7 +1797,7 @@ and interp_main mode t_level l_env l_mem (E_aux exp (l,annot)) =
| Tag_spec ->
(match Map.lookup name fdefs with
| Just(funcls) ->
- (match find_funcl funcls v with
+ (match find_funcl t_level funcls v with
| [] ->
(Error l (String.stringAppend "No matching pattern for function " name ),l_mem,l_env)
| [(env,used_unknown,exp)] ->
@@ -1851,7 +1864,7 @@ and interp_main mode t_level l_env l_mem (E_aux exp (l,annot)) =
(match Map.lookup name fdefs with
| Nothing -> (Error l ("Internal error, no function def for " ^ name),lm,le)
| Just (funcls) ->
- (match find_funcl funcls (V_tuple [lv;rv]) with
+ (match find_funcl t_level funcls (V_tuple [lv;rv]) with
| [] -> (Error l ("No matching pattern for function " ^ name),lm,l_env)
| [(env,used_unknown,exp)] ->
resolve_outcome
@@ -1869,7 +1882,7 @@ and interp_main mode t_level l_env l_mem (E_aux exp (l,annot)) =
(match Map.lookup name fdefs with
| Nothing -> (Error l ("Internal error, no function def for " ^ name),lm,le)
| Just (funcls) ->
- (match find_funcl funcls (V_tuple [lv;rv]) with
+ (match find_funcl t_level funcls (V_tuple [lv;rv]) with
| [] -> (Error l ("No matching pattern for function " ^ name),lm,l_env)
| [(env,used_unknown,exp)] ->
resolve_outcome
@@ -1885,7 +1898,7 @@ and interp_main mode t_level l_env l_mem (E_aux exp (l,annot)) =
(match Map.lookup name fdefs with
| Nothing -> (Error l ("No function definition found for " ^ name),lm,le)
| Just (funcls) ->
- (match find_funcl funcls (V_tuple [lv;rv]) with
+ (match find_funcl t_level funcls (V_tuple [lv;rv]) with
| [] -> (Error l ("No matching pattern for function " ^ name),lm,l_env)
| [(env,used_unknown,exp)] ->
resolve_outcome
@@ -2112,7 +2125,7 @@ and create_write_message_or_update mode t_level value l_env l_mem is_top_level (
| V_tuple vs -> V_tuple (vs ++ [value])
| V_lit (L_aux L_unit _) -> value
| v -> V_tuple [v;value] end in
- (match find_funcl funcls new_vals with
+ (match find_funcl t_level funcls new_vals with
| [] -> ((Error l ("No matching pattern for function " ^ name ^
" on value " ^ (string_of_value new_vals)),l_mem,l_env),Nothing)
| [(env,used_unknown,exp)] ->
@@ -2386,7 +2399,7 @@ and interp_letbind mode t_level l_env l_mem (LB_aux lbind (l,annot)) =
| LB_val_explicit t pat exp ->
match (interp_main mode t_level l_env l_mem exp) with
| (Value v,lm,le) ->
- (match match_pattern pat v with
+ (match match_pattern t_level pat v with
| (true,used_unknown,env) -> ((Value (V_lit (L_aux L_unit l)), lm, (union_env env l_env)),Nothing)
| _ -> ((Error l "Pattern in letbind did not match value",lm,le),Nothing) end)
| (Action a s,lm,le) -> ((Action a s,lm,le),(Just (fun e ->(LB_aux (LB_val_explicit t pat e) (l,annot)))))
@@ -2394,7 +2407,7 @@ and interp_letbind mode t_level l_env l_mem (LB_aux lbind (l,annot)) =
| LB_val_implicit pat exp ->
match (interp_main mode t_level l_env l_mem exp) with
| (Value v,lm,le) ->
- (match match_pattern pat v with
+ (match match_pattern t_level pat v with
| (true,used_unknown,env) -> ((Value (V_lit (L_aux L_unit l)), lm, (union_env env l_env)),Nothing)
| _ -> ((Error l "Pattern in letbind did not match value",lm,le),Nothing) end)
| (Action a s,lm,le) -> ((Action a s,lm,le),(Just (fun e -> (LB_aux (LB_val_implicit pat e) (l,annot)))))
diff --git a/src/rewriter.ml b/src/rewriter.ml
index 79f92e04..4fad78e7 100644
--- a/src/rewriter.ml
+++ b/src/rewriter.ml
@@ -117,13 +117,14 @@ let rec rewrite_exp (E_aux (exp,(l,annot))) =
(*TODO should pass d_env into here so that I can look at the abbreviations if there are any here*)
| Tapp("vector",[TA_nexp n1;TA_nexp nw1;TA_ord o1;_]),
Tapp("vector",[TA_nexp n2;TA_nexp nw2;TA_ord o2;_]) ->
- else (match n1.nexp with
+ (match n1.nexp with
| Nconst i1 -> if nexp_eq n1 n2 then new_exp else rewrap (E_cast (t_to_typ t,new_exp))
- | Nadd _ | Nsub -> (match o1.order with
- | O_inc -> new_exp
- | O_dec -> if nexp_one_more_than nw1 n1
- then rewrap (E_cast (Typ_var (Kid_aux (Var "length") Unknown), new_exp))
- else new_exp)
+ | Nadd _ | Nsub _ -> (match o1.order with
+ | Oinc -> new_exp
+ | Odec ->
+ if nexp_one_more_than nw1 n1
+ then rewrap (E_cast (Typ_aux (Typ_var (Kid_aux((Var "length"), Unknown)), Unknown), new_exp))
+ else new_exp)
| _ -> new_exp)
| _ -> new_exp))
| E_internal_exp (l,impl) ->
diff --git a/src/type_check.ml b/src/type_check.ml
index 36cd4fff..8523ce85 100644
--- a/src/type_check.ml
+++ b/src/type_check.ml
@@ -265,6 +265,13 @@ let rec check_pattern envs emp_tag expect_t (P_aux(p,(l,annot))) : ((tannot pat)
then typ_error l ("Constructor " ^ i ^ " expects arguments of type " ^ (t_to_string t) ^ ", found none")
else default
| _ -> raise (Reporting_basic.err_unreachable l "Constructor tannot does not have function type"))
+ | Some(Base((params,t),Enum,cs,ef,bounds)) ->
+ let t,cs,ef,_ = subst params false t cs ef in
+ if conforms_to_t d_env false false t expect_t
+ then
+ let tp,cp = type_consistent (Expr l) d_env Guarantee false t expect_t in
+ (P_aux(P_app(id,[]),(l,tag_annot t Enum)),Envmap.empty,cs@cp,bounds,tp)
+ else default
| _ -> default)
| P_app(id,pats) ->
let i = id_to_string id in