diff options
| author | Brian Campbell | 2018-01-23 15:53:46 +0000 |
|---|---|---|
| committer | Brian Campbell | 2018-01-25 18:59:07 +0000 |
| commit | 131af111908e354814536eb49b09481a22267577 (patch) | |
| tree | 4502c79b7404b44a45707fa5b2918e95739924f6 /src | |
| parent | 215aaf33512dbe44a65589a7f7491d46df3b88e6 (diff) | |
Basic support for match x[5 .. 2] with case splits
Diffstat (limited to 'src')
| -rw-r--r-- | src/monomorphise.ml | 99 |
1 files changed, 67 insertions, 32 deletions
diff --git a/src/monomorphise.ml b/src/monomorphise.ml index 51c5d473..e72529b4 100644 --- a/src/monomorphise.ml +++ b/src/monomorphise.ml @@ -796,7 +796,7 @@ type split = | NoSplit | VarSplit of (tannot pat * (* pattern for this case *) (id * tannot Ast.exp) list * (* substitutions for arguments *) - (Parse_ast.l * Parse_ast.l) list) (* optional locations of case expressions to reduce *) + (Parse_ast.l * int) list) (* optional locations of case expressions to reduce *) list | ConstrSplit of (tannot pat * nexp KBindings.t) list @@ -872,16 +872,19 @@ let rec remove_pat_bindings p = (* Use the location pairs in choices to reduce case expressions at the first location to the given case at the second. *) +(* TODO bound variables! *) let apply_pat_choices choices = let rewrite_case (e,cases) = match List.assoc (exp_loc e) choices with | choice -> - let rec find = function - | (Pat_aux (Pat_exp (p,E_aux (e,_)),_))::_ when pat_loc p = choice -> e - | _::t -> find t - | _ -> raise (Reporting_basic.err_unreachable choice - "Unable to find case I found earlier!") - in find cases + (match List.nth cases choice with + | Pat_aux (Pat_exp (p,E_aux (e,_)),_) -> e + | Pat_aux (Pat_when _,(l,_)) -> + raise (Reporting_basic.err_unreachable l + "Pattern acquired a guard after analysis!") + | exception Not_found -> + raise (Reporting_basic.err_unreachable (exp_loc e) + "Unable to find case I found earlier!")) | exception Not_found -> E_case (e,cases) in let open Rewriter in @@ -1454,17 +1457,8 @@ let split_defs continue_anyway splits defs = | None -> None | Some None -> Some (split id l annot) | Some (Some (pats,l)) -> - Some (List.map (fun p -> - let l' = pat_loc p in - if l' = Parse_ast.Unknown then - (Reporting_basic.print_error - (Reporting_basic.Err_general - (l', "No location for pattern: " ^ string_of_pat p)); - (* If we don't have a location then attempt to continue - without specialising the original case expression *) - P_aux (P_as (remove_pat_bindings p,id),(l,annot)),[],[]) - else - P_aux (P_as (remove_pat_bindings p,id),(l,annot)),[],[l,l']) + Some (List.mapi (fun i p -> + P_aux (P_as (remove_pat_bindings p,id),(l,annot)),[],[l,i]) pats) ) | P_app (id,ps) -> @@ -1984,6 +1978,13 @@ let rewrite_size_parameters env (Defs defs) = end +let is_id env id = + let ids = Env.get_overloads (Id_aux (id,Parse_ast.Unknown)) env in + let ids = id :: List.map (fun (Id_aux (id,_)) -> id) ids in + fun (Id_aux (x,_)) -> List.mem x ids + + + module Analysis = struct @@ -2231,19 +2232,48 @@ let deps_of_uvar kid_deps arg_deps = function | U_effect _ -> dempty | U_typ typ -> deps_of_typ kid_deps arg_deps typ +let mk_subrange_pattern vannot vstart vend = + let (_,len,ord,typ) = vector_typ_args_of (typ_of_annot vannot) in + match ord with + | Ord_aux (Ord_var _,_) -> None + | Ord_aux (ord',_) -> + let vstart,vend = if ord' = Ord_inc then vstart,vend else vend,vstart + in + let dummyl = Generated Unknown in + match len with + | Nexp_aux (Nexp_constant len,_) -> + Some (fun pat -> + let end_len = Big_int.pred (Big_int.sub len vend) in + (* Wrap pat in its type; in particular the type checker won't + manage P_wild in the middle of a P_vector_concat *) + let pat = P_aux (P_typ (pat_typ_of pat, pat),(Generated (pat_loc pat),None)) in + let pats = if Big_int.greater end_len Big_int.zero then + [pat;P_aux (P_typ (vector_typ (nconstant end_len) ord typ, + P_aux (P_wild,(dummyl,None))),(dummyl,None))] + else [pat] + in + let pats = if Big_int.greater vstart Big_int.zero then + (P_aux (P_typ (vector_typ (nconstant vstart) ord typ, + P_aux (P_wild,(dummyl,None))),(dummyl,None)))::pats + else pats + in + let pats = if ord' = Ord_inc then pats else List.rev pats + in + P_aux (P_vector_concat pats,(Generated (fst vannot),None))) + | _ -> None + (* If the expression matched on in a case expression is a function argument, and has no other dependencies, we can try to use the pattern match directly rather than doing a full case split. *) let refine_dependency env (E_aux (e,(l,annot)) as exp) pexps = - match e with - | E_id id -> - (match Bindings.find id env.var_deps with - | Have (args,callargs,callkids) -> - if CallerArgSet.is_empty callargs && CallerKidSet.is_empty callkids then - match ArgSplits.bindings args with - | [(id',loc),Total] when Id.compare id id' == 0 -> + let check_dep id ctx = + match Bindings.find id env.var_deps with + | Have (args,callargs,callkids) -> + if CallerArgSet.is_empty callargs && CallerKidSet.is_empty callkids then + match ArgSplits.bindings args with + | [(id',loc),Total] when Id.compare id id' == 0 -> (match Util.map_all (function - | Pat_aux (Pat_exp (pat,_),_) -> Some pat + | Pat_aux (Pat_exp (pat,_),_) -> Some (ctx pat) | Pat_aux (Pat_when (_,_,_),_) -> None) pexps with | Some pats -> @@ -2258,7 +2288,17 @@ let refine_dependency env (E_aux (e,(l,annot)) as exp) pexps = | _ -> None else None | Unknown _ -> None - | exception Not_found -> None) + | exception Not_found -> None + in + match e with + | E_id id -> check_dep id (fun x -> x) + | E_app (fn_id, [E_aux (E_id id,vannot); + E_aux (E_lit (L_aux (L_num vstart,_)),_); + E_aux (E_lit (L_aux (L_num vend,_)),_)]) + when is_id (env_of exp) (Id "vector_subrange") fn_id -> + (match mk_subrange_pattern vannot vstart vend with + | Some mk_pat -> check_dep id mk_pat + | None -> None) | _ -> None (* Takes an environment of dependencies on vars, type vars, and flow control, @@ -2726,11 +2766,6 @@ let is_constant_vec_typ env typ = | _ -> false) | _ -> false -let is_id env id = - let ids = Env.get_overloads (Id_aux (id,Parse_ast.Unknown)) env in - let ids = id :: List.map (fun (Id_aux (id,_)) -> id) ids in - fun (Id_aux (x,_)) -> List.mem x ids - (* We have to add casts in here with appropriate length information so that the type checker knows the expected return types. *) |
