diff options
| -rw-r--r-- | src/pattern_completeness.ml | 27 | ||||
| -rw-r--r-- | test/typecheck/pass/single_union.sail | 10 |
2 files changed, 25 insertions, 12 deletions
diff --git a/src/pattern_completeness.ml b/src/pattern_completeness.ml index afda3e1a..381b3923 100644 --- a/src/pattern_completeness.ml +++ b/src/pattern_completeness.ml @@ -87,6 +87,18 @@ let is_wild = function | GP_wild -> true | _ -> false +let check_ctors_complete ctx ctors = + if Bindings.for_all (fun _ gpat -> is_wild gpat) ctors then + let ids = IdSet.of_list (List.map fst (Bindings.bindings ctors)) in + let enums = List.map snd (Bindings.bindings ctx.enums) in + let variants = List.map snd (Bindings.bindings ctx.variants) in + if List.exists (fun ids' -> IdSet.equal ids ids') (enums @ variants) then + GP_wild + else + GP_app ctors + else + GP_app ctors + let rec generalize ctx (P_aux (p_aux, (l, _)) as pat) = match p_aux with | P_lit (L_aux (L_unit, _)) -> @@ -105,7 +117,7 @@ let rec generalize ctx (P_aux (p_aux, (l, _)) as pat) = | Local (Immutable, _) -> GP_wild | Register _ | Local (Mutable, _) -> Reporting.warn "Matching on register or mutable variable at " l ""; GP_wild - | Enum _ -> GP_app (Bindings.singleton id GP_wild) + | Enum _ -> check_ctors_complete ctx (Bindings.singleton id GP_wild) end | P_var (pat, _) -> generalize ctx pat | P_vector pats -> @@ -130,7 +142,7 @@ let rec generalize ctx (P_aux (p_aux, (l, _)) as pat) = | P_app (f, pats) -> let gpats = List.map (generalize ctx) pats in if List.for_all is_wild gpats then - GP_app (Bindings.singleton f GP_wild) + check_ctors_complete ctx (Bindings.singleton f GP_wild) else GP_app (Bindings.singleton f (GP_tup gpats)) @@ -237,16 +249,7 @@ let rec join ctx gpat1 gpat2 = | Some args1, Some args2 -> Some (join ctx args1 args2) in let ctors = Bindings.merge ctor_merge ctors1 ctors2 in - if Bindings.for_all (fun _ gpat -> is_wild gpat) ctors then - let ids = IdSet.of_list (List.map fst (Bindings.bindings ctors)) in - let enums = List.map snd (Bindings.bindings ctx.enums) in - let variants = List.map snd (Bindings.bindings ctx.variants) in - if List.exists (fun ids' -> IdSet.equal ids ids') (enums @ variants) then - GP_wild - else - GP_app ctors - else - GP_app ctors + check_ctors_complete ctx ctors | GP_or (gpat1, gpat2), gpat3 -> join ctx (join ctx gpat1 gpat2) gpat3 | gpat1, GP_or (gpat2, gpat3) -> join ctx gpat1 (join ctx gpat2 gpat3) diff --git a/test/typecheck/pass/single_union.sail b/test/typecheck/pass/single_union.sail new file mode 100644 index 00000000..73eeee50 --- /dev/null +++ b/test/typecheck/pass/single_union.sail @@ -0,0 +1,10 @@ +union foo = { BAR : int } + +val f : foo -> int +function f BAR(x) = x + +val g : foo -> int +function g x = match x { BAR(y) => y } + +val h : int -> foo +function h x = BAR(x) |
