Skip to content

Commit

Permalink
Make mapv_into_any() work for ArcArray, resolves rust-ndarray#1280
Browse files Browse the repository at this point in the history
  • Loading branch information
benkay86 committed Oct 30, 2024
1 parent 492b274 commit 10ba041
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 12 deletions.
69 changes: 59 additions & 10 deletions src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use std::mem::{size_of, ManuallyDrop};
use crate::imp_prelude::*;

use crate::argument_traits::AssignElem;
use crate::data_traits::RawDataSubst;
use crate::dimension;
use crate::dimension::broadcast::co_broadcast;
use crate::dimension::reshape_dim;
Expand Down Expand Up @@ -2814,15 +2815,59 @@ where
/// map is performed as in [`mapv`].
///
/// Elements are visited in arbitrary order.
///
///
/// Example:
///
/// ```rust
/// # use ndarray::{array, Array};
/// let a: Array<f32, _> = array![[1., 2., 3.]];
/// let b = a.clone();
/// // Same type, no new memory allocation.
/// let a_plus_one = a.mapv_into_any(|a| a + 1.);
/// // Different types, allocates new memory.
/// let rounded = b.mapv_into_any(|a| a.round() as i32);
/// ```
///
/// Note that this method works on arrays with different memory
/// representations (e.g. [`OwnedRepr`](crate::OwnedRepr) vs
/// [`OwnedArcRepr`](crate::OwnedArcRepr)) but it does *not* convert between
/// different memory representations.
///
/// This compiles:
/// ```rust
/// # use ndarray::{array, ArcArray};
/// let a: ArcArray<f32, _> = array![[1., 2., 3.]].into();
/// // OwnedArcRepr --> OwnedArcRepr.
/// let a_plus_one = a.mapv_into_any(|a| a + 1.);
/// // We can convert to OwnedRepr if we want.
/// let a_plus_one = a_plus_one.into_owned();
/// ```
///
/// This fails to compile:
/// ```compile_fail,E0308
/// # use ndarray::{array, Array, ArcArray};
/// let a: ArcArray<f32, _> = array![[1., 2., 3.]].into();
/// // OwnedArcRepr --> OwnedRepr
/// let a_plus_one: Array<_, _> = a.mapv_into_any(|a| a + 1.);
/// ```
///
/// [`mapv_into`]: ArrayBase::mapv_into
/// [`mapv`]: ArrayBase::mapv
pub fn mapv_into_any<B, F>(self, mut f: F) -> Array<B, D>
pub fn mapv_into_any<B, F>(self, mut f: F) -> ArrayBase<<S as RawDataSubst<B>>::Output, D>
where
S: DataMut,
// Output is same memory representation as input,
// Substituting B for A.
// Need 'static lifetime bounds for TypeId to work.
S: DataMut<Elem = A> + RawDataSubst<B> + 'static,
// Mapping function maps from A to B.
F: FnMut(A) -> B,
// Need 'static lifetime bounds for TypeId to work.
// mapv() requires that A be Clone.
A: Clone + 'static,
B: 'static,
// mapv() always returns ArrayBase<OwnedRepr<_>,_>
// This bound ensures we can convert from OwnedRepr to the output repr.
ArrayBase<<S as RawDataSubst<B>>::Output, D>: From<Array<B,D>>,
{
if core::any::TypeId::of::<A>() == core::any::TypeId::of::<B>() {
// A and B are the same type.
Expand All @@ -2832,16 +2877,20 @@ where
// Safe because A and B are the same type.
unsafe { unlimited_transmute::<B, A>(b) }
};
// Delegate to mapv_into() using the wrapped closure.
// Convert output to a uniquely owned array of type Array<A, D>.
let output = self.mapv_into(f).into_owned();
// Change the return type from Array<A, D> to Array<B, D>.
// Again, safe because A and B are the same type.
unsafe { unlimited_transmute::<Array<A, D>, Array<B, D>>(output) }
// Delegate to mapv_into() to map from element type A to type A.
let output = self.mapv_into(f);
// If A and B are the same type, and if the input and output arrays
// have the same kind of memory representation (OwnedRepr vs
// OwnedArcRepr), then their memory representations should be the
// same type, e.g. OwnedRepr<A> == OwnedRepr<B>
debug_assert!(core::any::TypeId::of::<S>() == core::any::TypeId::of::<<S as RawDataSubst<B>>::Output>());
// Now we can safely transmute the element type from A to the
// identical type B, keeping the same memory representation.
unsafe { unlimited_transmute::<ArrayBase<S, D>, ArrayBase<<S as RawDataSubst<B>>::Output, D>>(output) }
} else {
// A and B are not the same type.
// Fallback to mapv().
self.mapv(f)
self.mapv(f).into()
}
}

Expand Down
38 changes: 36 additions & 2 deletions tests/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1054,15 +1054,49 @@ fn mapv_into_any_same_type()
{
let a: Array<f64, _> = array![[1., 2., 3.], [4., 5., 6.]];
let a_plus_one: Array<f64, _> = array![[2., 3., 4.], [5., 6., 7.]];
assert_eq!(a.mapv_into_any(|a| a + 1.), a_plus_one);
let b = a.mapv_into_any(|a| a + 1.);
assert_eq!(b, a_plus_one);
}

#[test]
fn mapv_into_any_diff_types()
{
let a: Array<f64, _> = array![[1., 2., 3.], [4., 5., 6.]];
let a_even: Array<bool, _> = array![[false, true, false], [true, false, true]];
assert_eq!(a.mapv_into_any(|a| a.round() as i32 % 2 == 0), a_even);
let b = a.mapv_into_any(|a| a.round() as i32 % 2 == 0);
assert_eq!(b, a_even);
}

#[test]
fn mapv_into_any_arcarray_same_type() {
let a: ArcArray<f64, _> = array![[1., 2., 3.], [4., 5., 6.]].into_shared();
let a_plus_one: Array<f64, _> = array![[2., 3., 4.], [5., 6., 7.]];
let b = a.mapv_into_any(|a| a + 1.);
assert_eq!(b, a_plus_one);
}

#[test]
fn mapv_into_any_arcarray_diff_types() {
let a: ArcArray<f64, _> = array![[1., 2., 3.], [4., 5., 6.]].into_shared();
let a_even: Array<bool, _> = array![[false, true, false], [true, false, true]];
let b = a.mapv_into_any(|a| a.round() as i32 % 2 == 0);
assert_eq!(b, a_even);
}

#[test]
fn mapv_into_any_cowarray_same_type() {
let a: CowArray<f64, _> = array![[1., 2., 3.], [4., 5., 6.]].into();
let a_plus_one: Array<f64, _> = array![[2., 3., 4.], [5., 6., 7.]];
let b = a.mapv_into_any(|a| a + 1.);
assert_eq!(b, a_plus_one);
}

#[test]
fn mapv_into_any_cowarray_diff_types() {
let a: CowArray<f64, _> = array![[1., 2., 3.], [4., 5., 6.]].into();
let a_even: Array<bool, _> = array![[false, true, false], [true, false, true]];
let b = a.mapv_into_any(|a| a.round() as i32 % 2 == 0);
assert_eq!(b, a_even);
}

#[test]
Expand Down

0 comments on commit 10ba041

Please sign in to comment.