From 2cabecf85096e4dd921f0f5017158c2e2479ffbe Mon Sep 17 00:00:00 2001 From: Benjamin Kay Date: Wed, 3 May 2023 12:15:03 -0500 Subject: [PATCH] Make mapv_into_any() work for ArcArray, resolves #1280 --- src/impl_methods.rs | 71 ++++++++++++++++++++++++++++++++++++++------- tests/array.rs | 22 ++++++++++++-- 2 files changed, 81 insertions(+), 12 deletions(-) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 4a00ea000..688d21b64 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -17,6 +17,7 @@ use std::mem::{size_of, ManuallyDrop}; use crate::imp_prelude::*; use crate::argument_traits::AssignElem; +use crate::data_traits::RawDataSubst; use crate::dimension; use crate::dimension::broadcast::co_broadcast; use crate::dimension::reshape_dim; @@ -2814,15 +2815,58 @@ where /// map is performed as in [`mapv`]. /// /// Elements are visited in arbitrary order. - /// + /// + /// Example: + /// + /// ```rust + /// # use ndarray::{array, Array}; + /// let a: Array = array![[1., 2., 3.]]; + /// let b = a.clone(); + /// // Same type, no new memory allocation. + /// let a_plus_one = a.mapv_into_any(|a| a + 1.); + /// // Different types, allocates new memory. + /// let rounded = b.mapv_into_any(|a| a.round() as i32); + /// ``` + /// + /// Note that this method works on arrays with different memory + /// representations (e.g. [`OwnedRepr`](crate::OwnedRepr) vs + /// [`OwnedArcRepr`](crate::OwnedArcRepr)) but it does *not* convert between + /// different memory representations. + /// + /// This compiles: + /// ```rust + /// # use ndarray::{array, ArcArray}; + /// let a: ArcArray = array![[1., 2., 3.]].into(); + /// // OwnedArcRepr --> OwnedArcRepr. + /// let a_plus_one = a.mapv_into_any(|a| a + 1.); + /// // We can convert to OwnedRepr if we want. + /// let a_plus_one = a_plus_one.into_owned(); + /// ``` + /// + /// This fails to compile: + /// ```compile_fail,E0308 + /// # use ndarray::{array, Array, ArcArray}; + /// let a: ArcArray = array![[1., 2., 3.]].into(); + /// // OwnedArcRepr --> OwnedRepr + /// let a_plus_one: Array<_, _> = a.mapv_into_any(|a| a + 1.); + /// ``` + /// /// [`mapv_into`]: ArrayBase::mapv_into /// [`mapv`]: ArrayBase::mapv - pub fn mapv_into_any(self, mut f: F) -> Array + pub fn mapv_into_any(self, mut f: F) -> ArrayBase<>::Output, D> where - S: DataMut, + // Output is same memory representation as input, + // Substituting B for A. + S: DataMut + RawDataSubst, + // Mapping function maps from A to B. F: FnMut(A) -> B, + // Need 'static lifetime bounds for TypeId to work. + // mapv() requires that A be Clone. A: Clone + 'static, B: 'static, + // mapv() always returns ArrayBase,_> + // This bound ensures we can convert from OwnedRepr to the output repr. + ArrayBase<>::Output, D>: From>, { if core::any::TypeId::of::() == core::any::TypeId::of::() { // A and B are the same type. @@ -2832,16 +2876,23 @@ where // Safe because A and B are the same type. unsafe { unlimited_transmute::(b) } }; - // Delegate to mapv_into() using the wrapped closure. - // Convert output to a uniquely owned array of type Array. - let output = self.mapv_into(f).into_owned(); - // Change the return type from Array to Array. - // Again, safe because A and B are the same type. - unsafe { unlimited_transmute::, Array>(output) } + // Delegate to mapv_into() to map from element type A to type A. + let output = self.mapv_into(f); + // // Convert from S's data storage to T's data storage. + // // Suppose `T is `OwnedRepr`. + // // Then `>::Output` is `OwnedRepr`. + // let output: ArrayBase<>::Output, D> = output.into(); + // // Since A == B and T stores elements of type B, it should be true + // // that >::Output == T. + // // Verify that this is indeed the case. + // assert!(core::any::TypeId::of::<>::Output>() == core::any::TypeId::of::()); + // Now we can safely transmute the element type from A to the + // identical type B, keeping the same data storage. + unsafe { unlimited_transmute::, ArrayBase<>::Output, D>>(output) } } else { // A and B are not the same type. // Fallback to mapv(). - self.mapv(f) + self.mapv(f).into() } } diff --git a/tests/array.rs b/tests/array.rs index 696904dab..6efd73514 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -1054,7 +1054,8 @@ fn mapv_into_any_same_type() { let a: Array = array![[1., 2., 3.], [4., 5., 6.]]; let a_plus_one: Array = array![[2., 3., 4.], [5., 6., 7.]]; - assert_eq!(a.mapv_into_any(|a| a + 1.), a_plus_one); + let b: Array<_, _> = a.mapv_into_any(|a| a + 1.); + assert_eq!(b, a_plus_one); } #[test] @@ -1062,7 +1063,24 @@ fn mapv_into_any_diff_types() { let a: Array = array![[1., 2., 3.], [4., 5., 6.]]; let a_even: Array = array![[false, true, false], [true, false, true]]; - assert_eq!(a.mapv_into_any(|a| a.round() as i32 % 2 == 0), a_even); + let b= a.mapv_into_any(|a| a.round() as i32 % 2 == 0); + assert_eq!(b, a_even); +} + +#[test] +fn mapv_into_any_arcarray_same_type() { + let a: ArcArray = array![[1., 2., 3.], [4., 5., 6.]].into_shared(); + let a_plus_one: Array = array![[2., 3., 4.], [5., 6., 7.]]; + let b = a.mapv_into_any(|a| a + 1.); + assert_eq!(b, a_plus_one); +} + +#[test] +fn mapv_into_any_arcarray_diff_types() { + let a: ArcArray = array![[1., 2., 3.], [4., 5., 6.]].into_shared(); + let a_even: Array = array![[false, true, false], [true, false, true]]; + let b: ArcArray<_, _> = a.mapv_into_any(|a| a.round() as i32 % 2 == 0); + assert_eq!(b, a_even); } #[test]