summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorBrian Campbell2018-01-23 15:53:46 +0000
committerBrian Campbell2018-01-25 18:59:07 +0000
commit131af111908e354814536eb49b09481a22267577 (patch)
tree4502c79b7404b44a45707fa5b2918e95739924f6 /src
parent215aaf33512dbe44a65589a7f7491d46df3b88e6 (diff)
Basic support for match x[5 .. 2] with case splits
Diffstat (limited to 'src')
-rw-r--r--src/monomorphise.ml99
1 files changed, 67 insertions, 32 deletions
diff --git a/src/monomorphise.ml b/src/monomorphise.ml
index 51c5d473..e72529b4 100644
--- a/src/monomorphise.ml
+++ b/src/monomorphise.ml
@@ -796,7 +796,7 @@ type split =
| NoSplit
| VarSplit of (tannot pat * (* pattern for this case *)
(id * tannot Ast.exp) list * (* substitutions for arguments *)
- (Parse_ast.l * Parse_ast.l) list) (* optional locations of case expressions to reduce *)
+ (Parse_ast.l * int) list) (* optional locations of case expressions to reduce *)
list
| ConstrSplit of (tannot pat * nexp KBindings.t) list
@@ -872,16 +872,19 @@ let rec remove_pat_bindings p =
(* Use the location pairs in choices to reduce case expressions at the first
location to the given case at the second. *)
+(* TODO bound variables! *)
let apply_pat_choices choices =
let rewrite_case (e,cases) =
match List.assoc (exp_loc e) choices with
| choice ->
- let rec find = function
- | (Pat_aux (Pat_exp (p,E_aux (e,_)),_))::_ when pat_loc p = choice -> e
- | _::t -> find t
- | _ -> raise (Reporting_basic.err_unreachable choice
- "Unable to find case I found earlier!")
- in find cases
+ (match List.nth cases choice with
+ | Pat_aux (Pat_exp (p,E_aux (e,_)),_) -> e
+ | Pat_aux (Pat_when _,(l,_)) ->
+ raise (Reporting_basic.err_unreachable l
+ "Pattern acquired a guard after analysis!")
+ | exception Not_found ->
+ raise (Reporting_basic.err_unreachable (exp_loc e)
+ "Unable to find case I found earlier!"))
| exception Not_found -> E_case (e,cases)
in
let open Rewriter in
@@ -1454,17 +1457,8 @@ let split_defs continue_anyway splits defs =
| None -> None
| Some None -> Some (split id l annot)
| Some (Some (pats,l)) ->
- Some (List.map (fun p ->
- let l' = pat_loc p in
- if l' = Parse_ast.Unknown then
- (Reporting_basic.print_error
- (Reporting_basic.Err_general
- (l', "No location for pattern: " ^ string_of_pat p));
- (* If we don't have a location then attempt to continue
- without specialising the original case expression *)
- P_aux (P_as (remove_pat_bindings p,id),(l,annot)),[],[])
- else
- P_aux (P_as (remove_pat_bindings p,id),(l,annot)),[],[l,l'])
+ Some (List.mapi (fun i p ->
+ P_aux (P_as (remove_pat_bindings p,id),(l,annot)),[],[l,i])
pats)
)
| P_app (id,ps) ->
@@ -1984,6 +1978,13 @@ let rewrite_size_parameters env (Defs defs) =
end
+let is_id env id =
+ let ids = Env.get_overloads (Id_aux (id,Parse_ast.Unknown)) env in
+ let ids = id :: List.map (fun (Id_aux (id,_)) -> id) ids in
+ fun (Id_aux (x,_)) -> List.mem x ids
+
+
+
module Analysis =
struct
@@ -2231,19 +2232,48 @@ let deps_of_uvar kid_deps arg_deps = function
| U_effect _ -> dempty
| U_typ typ -> deps_of_typ kid_deps arg_deps typ
+let mk_subrange_pattern vannot vstart vend =
+ let (_,len,ord,typ) = vector_typ_args_of (typ_of_annot vannot) in
+ match ord with
+ | Ord_aux (Ord_var _,_) -> None
+ | Ord_aux (ord',_) ->
+ let vstart,vend = if ord' = Ord_inc then vstart,vend else vend,vstart
+ in
+ let dummyl = Generated Unknown in
+ match len with
+ | Nexp_aux (Nexp_constant len,_) ->
+ Some (fun pat ->
+ let end_len = Big_int.pred (Big_int.sub len vend) in
+ (* Wrap pat in its type; in particular the type checker won't
+ manage P_wild in the middle of a P_vector_concat *)
+ let pat = P_aux (P_typ (pat_typ_of pat, pat),(Generated (pat_loc pat),None)) in
+ let pats = if Big_int.greater end_len Big_int.zero then
+ [pat;P_aux (P_typ (vector_typ (nconstant end_len) ord typ,
+ P_aux (P_wild,(dummyl,None))),(dummyl,None))]
+ else [pat]
+ in
+ let pats = if Big_int.greater vstart Big_int.zero then
+ (P_aux (P_typ (vector_typ (nconstant vstart) ord typ,
+ P_aux (P_wild,(dummyl,None))),(dummyl,None)))::pats
+ else pats
+ in
+ let pats = if ord' = Ord_inc then pats else List.rev pats
+ in
+ P_aux (P_vector_concat pats,(Generated (fst vannot),None)))
+ | _ -> None
+
(* If the expression matched on in a case expression is a function argument,
and has no other dependencies, we can try to use the pattern match directly
rather than doing a full case split. *)
let refine_dependency env (E_aux (e,(l,annot)) as exp) pexps =
- match e with
- | E_id id ->
- (match Bindings.find id env.var_deps with
- | Have (args,callargs,callkids) ->
- if CallerArgSet.is_empty callargs && CallerKidSet.is_empty callkids then
- match ArgSplits.bindings args with
- | [(id',loc),Total] when Id.compare id id' == 0 ->
+ let check_dep id ctx =
+ match Bindings.find id env.var_deps with
+ | Have (args,callargs,callkids) ->
+ if CallerArgSet.is_empty callargs && CallerKidSet.is_empty callkids then
+ match ArgSplits.bindings args with
+ | [(id',loc),Total] when Id.compare id id' == 0 ->
(match Util.map_all (function
- | Pat_aux (Pat_exp (pat,_),_) -> Some pat
+ | Pat_aux (Pat_exp (pat,_),_) -> Some (ctx pat)
| Pat_aux (Pat_when (_,_,_),_) -> None) pexps
with
| Some pats ->
@@ -2258,7 +2288,17 @@ let refine_dependency env (E_aux (e,(l,annot)) as exp) pexps =
| _ -> None
else None
| Unknown _ -> None
- | exception Not_found -> None)
+ | exception Not_found -> None
+ in
+ match e with
+ | E_id id -> check_dep id (fun x -> x)
+ | E_app (fn_id, [E_aux (E_id id,vannot);
+ E_aux (E_lit (L_aux (L_num vstart,_)),_);
+ E_aux (E_lit (L_aux (L_num vend,_)),_)])
+ when is_id (env_of exp) (Id "vector_subrange") fn_id ->
+ (match mk_subrange_pattern vannot vstart vend with
+ | Some mk_pat -> check_dep id mk_pat
+ | None -> None)
| _ -> None
(* Takes an environment of dependencies on vars, type vars, and flow control,
@@ -2726,11 +2766,6 @@ let is_constant_vec_typ env typ =
| _ -> false)
| _ -> false
-let is_id env id =
- let ids = Env.get_overloads (Id_aux (id,Parse_ast.Unknown)) env in
- let ids = id :: List.map (fun (Id_aux (id,_)) -> id) ids in
- fun (Id_aux (x,_)) -> List.mem x ids
-
(* We have to add casts in here with appropriate length information so that the
type checker knows the expected return types. *)