Skip to content

Commit

Permalink
Make mapv_into_any() work for ArcArray, resolves rust-ndarray#1280
Browse files Browse the repository at this point in the history
  • Loading branch information
benkay86 committed May 3, 2023
1 parent 0740695 commit ce37b94
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 7 deletions.
22 changes: 17 additions & 5 deletions src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2586,15 +2586,27 @@ where
/// map is performed as in [`mapv`].
///
/// Elements are visited in arbitrary order.
///
///
/// Note that the compiler will need some hint about the return type, which
/// is generic over [`DataOwned`], and can thus be an [`Array`] or
/// [`ArcArray`]. Example:
///
/// ```rust
/// # use ndarray::{array, Array};
/// let a = array![[1., 2., 3.]];
/// let a_plus_one: Array<_, _> = a.mapv_into_any(|a| a + 1.);
/// ```
///
/// [`mapv_into`]: ArrayBase::mapv_into
/// [`mapv`]: ArrayBase::mapv
pub fn mapv_into_any<B, F>(self, mut f: F) -> Array<B, D>
pub fn mapv_into_any<B, F, T>(self, mut f: F) -> ArrayBase<T, D>
where
S: DataMut,
F: FnMut(A) -> B,
A: Clone + 'static,
B: 'static,
T: DataOwned<Elem = B>,
ArrayBase<T, D>: From<Array<B, D>>,
{
if core::any::TypeId::of::<A>() == core::any::TypeId::of::<B>() {
// A and B are the same type.
Expand All @@ -2606,14 +2618,14 @@ where
};
// Delegate to mapv_into() using the wrapped closure.
// Convert output to a uniquely owned array of type Array<A, D>.
let output = self.mapv_into(f).into_owned();
let output = self.mapv_into(f);
// Change the return type from Array<A, D> to Array<B, D>.
// Again, safe because A and B are the same type.
unsafe { unlimited_transmute::<Array<A, D>, Array<B, D>>(output) }
unsafe { unlimited_transmute::<ArrayBase<S, D>, ArrayBase<T, D>>(output) }
} else {
// A and B are not the same type.
// Fallback to mapv().
self.mapv(f)
self.mapv(f).into()
}
}

Expand Down
22 changes: 20 additions & 2 deletions tests/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -995,14 +995,32 @@ fn map1() {
fn mapv_into_any_same_type() {
let a: Array<f64, _> = array![[1., 2., 3.], [4., 5., 6.]];
let a_plus_one: Array<f64, _> = array![[2., 3., 4.], [5., 6., 7.]];
assert_eq!(a.mapv_into_any(|a| a + 1.), a_plus_one);
let b: Array<_, _> = a.mapv_into_any(|a| a + 1.);
assert_eq!(b, a_plus_one);
}

#[test]
fn mapv_into_any_diff_types() {
let a: Array<f64, _> = array![[1., 2., 3.], [4., 5., 6.]];
let a_even: Array<bool, _> = array![[false, true, false], [true, false, true]];
assert_eq!(a.mapv_into_any(|a| a.round() as i32 % 2 == 0), a_even);
let b: Array<_, _> = a.mapv_into_any(|a| a.round() as i32 % 2 == 0);
assert_eq!(b, a_even);
}

#[test]
fn mapv_into_any_arcarray_same_type() {
let a: ArcArray<f64, _> = array![[1., 2., 3.], [4., 5., 6.]].into_shared();
let a_plus_one: Array<f64, _> = array![[2., 3., 4.], [5., 6., 7.]];
let b: ArcArray<_, _> = a.mapv_into(|a| a + 1.);
assert_eq!(b, a_plus_one);
}

#[test]
fn mapv_into_any_arcarray_diff_types() {
let a: ArcArray<f64, _> = array![[1., 2., 3.], [4., 5., 6.]].into_shared();
let a_even: Array<bool, _> = array![[false, true, false], [true, false, true]];
let b: ArcArray<_, _> = a.mapv_into_any(|a| a.round() as i32 % 2 == 0);
assert_eq!(b, a_even);
}

#[test]
Expand Down

0 comments on commit ce37b94

Please sign in to comment.