aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPierre-Marie Pédrot2018-10-25 13:31:53 +0200
committerPierre-Marie Pédrot2018-10-29 17:47:15 +0100
commitd5e762723bca7cb9297183e4332e0a9c7c0932f0 (patch)
tree9eb7b8bf161c210a9c5b676e97fd09ff742daf0d
parent9ead21a38feae29fdde11344de326de86bfe8ad9 (diff)
Do not compare the type arguments in pattern-match branches.
We know that the two are living in a common type, so that it is useless to perform the comparison check. Note that we only use this fast-path when the branches are only made of lambda-abstractions, but this covers all actual cases.
-rw-r--r--clib/cArray.ml12
-rw-r--r--clib/cArray.mli2
-rw-r--r--kernel/reduction.ml26
3 files changed, 37 insertions, 3 deletions
diff --git a/clib/cArray.ml b/clib/cArray.ml
index 9644834381..c3a693ff16 100644
--- a/clib/cArray.ml
+++ b/clib/cArray.ml
@@ -35,6 +35,8 @@ sig
val fold_left_i : (int -> 'a -> 'b -> 'a) -> 'a -> 'b array -> 'a
val fold_right2 :
('a -> 'b -> 'c -> 'c) -> 'a array -> 'b array -> 'c -> 'c
+ val fold_right3 :
+ ('a -> 'b -> 'c -> 'd -> 'd) -> 'a array -> 'b array -> 'c array -> 'd -> 'd
val fold_left2 :
('a -> 'b -> 'c -> 'a) -> 'a -> 'b array -> 'c array -> 'a
val fold_left3 :
@@ -252,6 +254,16 @@ let fold_left2_i f a v1 v2 =
if Array.length v2 <> lv1 then invalid_arg "Array.fold_left2_i";
fold a 0
+let fold_right3 f v1 v2 v3 a =
+ let lv1 = Array.length v1 in
+ let rec fold a n =
+ if n=0 then a
+ else
+ let k = n-1 in
+ fold (f (uget v1 k) (uget v2 k) (uget v3 k) a) k in
+ if Array.length v2 <> lv1 || Array.length v3 <> lv1 then invalid_arg "Array.fold_right3";
+ fold a lv1
+
let fold_left3 f a v1 v2 v3 =
let lv1 = Array.length v1 in
let rec fold a n =
diff --git a/clib/cArray.mli b/clib/cArray.mli
index e65a56d15e..21479d2b45 100644
--- a/clib/cArray.mli
+++ b/clib/cArray.mli
@@ -58,6 +58,8 @@ sig
val fold_left_i : (int -> 'a -> 'b -> 'a) -> 'a -> 'b array -> 'a
val fold_right2 :
('a -> 'b -> 'c -> 'c) -> 'a array -> 'b array -> 'c -> 'c
+ val fold_right3 :
+ ('a -> 'b -> 'c -> 'd -> 'd) -> 'a array -> 'b array -> 'c array -> 'd -> 'd
val fold_left2 :
('a -> 'b -> 'c -> 'a) -> 'a -> 'b array -> 'c array -> 'a
val fold_left3 :
diff --git a/kernel/reduction.ml b/kernel/reduction.ml
index 7c8b1193ab..5515ff9767 100644
--- a/kernel/reduction.ml
+++ b/kernel/reduction.ml
@@ -288,6 +288,14 @@ let conv_table_key infos k1 k2 cuniv =
| RelKey n, RelKey n' when Int.equal n n' -> cuniv
| _ -> raise NotConvertible
+exception IrregularPatternShape
+
+let rec skip_pattern n c =
+ if Int.equal n 0 then c
+ else match kind c with
+ | Lambda (_, _, c) -> skip_pattern (pred n) c
+ | _ -> raise IrregularPatternShape
+
type conv_tab = {
cnv_inf : clos_infos;
lft_tab : clos_tab;
@@ -624,9 +632,21 @@ and convert_vect l2r infos lft1 lft2 v1 v2 cuniv =
fold 0 cuniv
else raise NotConvertible
-and convert_branches l2r infos _ci e1 e2 lft1 lft2 br1 br2 cuniv =
- let fold c1 c2 cuniv = ccnv CONV l2r infos lft1 lft2 (mk_clos e1 c1) (mk_clos e2 c2) cuniv in
- Array.fold_right2 fold br1 br2 cuniv
+and convert_branches l2r infos ci e1 e2 lft1 lft2 br1 br2 cuniv =
+ (** Skip comparison of the pattern types. We know that the two terms are
+ living in a common type, thus this check is useless. *)
+ let fold n c1 c2 cuniv = match skip_pattern n c1, skip_pattern n c2 with
+ | (c1, c2) ->
+ let lft1 = el_liftn n lft1 in
+ let lft2 = el_liftn n lft2 in
+ let e1 = subs_liftn n e1 in
+ let e2 = subs_liftn n e2 in
+ ccnv CONV l2r infos lft1 lft2 (mk_clos e1 c1) (mk_clos e2 c2) cuniv
+ | exception IrregularPatternShape ->
+ (** Might happen due to a shape invariant that is not enforced *)
+ ccnv CONV l2r infos lft1 lft2 (mk_clos e1 c1) (mk_clos e2 c2) cuniv
+ in
+ Array.fold_right3 fold ci.ci_cstr_nargs br1 br2 cuniv
let clos_gen_conv trans cv_pb l2r evars env univs t1 t2 =
let reds = CClosure.RedFlags.red_add_transparent betaiotazeta trans in