aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPierre-Marie Pédrot2020-01-28 15:32:18 +0100
committerPierre-Marie Pédrot2020-01-28 15:32:18 +0100
commit60dd89f3a9b946875cc38678f50bd50431bafa3d (patch)
tree5da4640c79a1059da97185884e3440f64d27ee2e
parentb105077dd42e34f19d0849620fec2837e84b4887 (diff)
parent41b844befc7e7a720510358389e7e84e239404db (diff)
Merge PR #11376: Fix fold order in CArray.fold_right(2)_map
Reviewed-by: ppedrot
-rw-r--r--clib/cArray.ml32
-rw-r--r--clib/cArray.mli2
2 files changed, 23 insertions, 11 deletions
diff --git a/clib/cArray.ml b/clib/cArray.ml
index be59ae57d0..0f57204cc1 100644
--- a/clib/cArray.ml
+++ b/clib/cArray.ml
@@ -392,18 +392,30 @@ let iter2_i f v1 v2 =
let () = if not (Int.equal len2 len1) then invalid_arg "Array.iter2" in
for i = 0 to len1 - 1 do f i (uget v1 i) (uget v2 i) done
-let pure_functional = false
+let map_right f a =
+ let l = length a in
+ if l = 0 then [||] else begin
+ let r = Array.make l (f (unsafe_get a (l-1))) in
+ for i = l-2 downto 0 do
+ unsafe_set r i (f (unsafe_get a i))
+ done;
+ r
+ end
+
+let map2_right f a b =
+ let l = length a in
+ if l <> length b then invalid_arg "CArray.map2_right: length mismatch";
+ if l = 0 then [||] else begin
+ let r = Array.make l (f (unsafe_get a (l-1)) (unsafe_get b (l-1))) in
+ for i = l-2 downto 0 do
+ unsafe_set r i (f (unsafe_get a i) (unsafe_get b i))
+ done;
+ r
+ end
let fold_right_map f v e =
-if pure_functional then
- let (l,e) =
- Array.fold_right
- (fun x (l,e) -> let (y,e) = f x e in (y::l,e))
- v ([],e) in
- (Array.of_list l,e)
-else
let e' = ref e in
- let v' = Array.map (fun x -> let (y,e) = f x !e' in e' := e; y) v in
+ let v' = map_right (fun x -> let (y,e) = f x !e' in e' := e; y) v in
(v',!e')
let fold_left_map f e v =
@@ -414,7 +426,7 @@ let fold_left_map f e v =
let fold_right2_map f v1 v2 e =
let e' = ref e in
let v' =
- map2 (fun x1 x2 -> let (y,e) = f x1 x2 !e' in e' := e; y) v1 v2
+ map2_right (fun x1 x2 -> let (y,e) = f x1 x2 !e' in e' := e; y) v1 v2
in
(v',!e')
diff --git a/clib/cArray.mli b/clib/cArray.mli
index f94af26515..94390a369f 100644
--- a/clib/cArray.mli
+++ b/clib/cArray.mli
@@ -107,7 +107,7 @@ sig
(** Same than [fold_left2_map] but passing the index of the array *)
val fold_right2_map : ('a -> 'b -> 'c -> 'd * 'c) -> 'a array -> 'b array -> 'c -> 'd array * 'c
- (** Same with two arrays, folding on the left *)
+ (** Same with two arrays, folding on the right *)
val distinct : 'a array -> bool
(** Return [true] if every element of the array is unique (for default