Merge pull request #986 from rust-ndarray/intoiterator
Implement by-value iterator for owned arrays
bluss authored Apr 22, 2021
2 parents 4e31d2f + 5766f4b commit 9f868f7
Showing 9 changed files with 395 additions and 128 deletions.
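For orientation, a minimal usage sketch (not part of the diff) of what this change enables, assuming the by-value `IntoIterator` for owned arrays that this PR adds:

// Sketch only: by-value iteration over an owned array moves each element out
// instead of borrowing it.
use ndarray::Array2;

fn main() {
    let a = Array2::from_shape_vec((2, 3), vec![1, 2, 3, 4, 5, 6]).unwrap();
    let doubled: Vec<i32> = a.into_iter().map(|x| x * 2).collect();
    assert_eq!(doubled, vec![2, 4, 6, 8, 10, 12]);
}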
11 changes: 11 additions & 0 deletions src/data_repr.rs
@@ -53,6 +53,10 @@ impl<A> OwnedRepr<A> {
self.ptr.as_ptr()
}

pub(crate) fn as_ptr_mut(&self) -> *mut A {
self.ptr.as_ptr()
}

pub(crate) fn as_nonnull_mut(&mut self) -> NonNull<A> {
self.ptr
}
@@ -88,6 +92,13 @@ impl<A> OwnedRepr<A> {
self.len = new_len;
}

/// Release ownership of all elements: set the valid length to 0 and return the previous length (total number of elements)
pub(crate) fn release_all_elements(&mut self) -> usize {
let ret = self.len;
self.len = 0;
ret
}
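A hedged sketch of why such a "release" operation is needed (illustration only; `SimpleRepr` is a hypothetical stand-in, not ndarray's `OwnedRepr`): once a by-value consumer takes over responsibility for the elements, the owning repr must not drop them again.

// Illustration only: `SimpleRepr` is a hypothetical stand-in for `OwnedRepr`.
struct SimpleRepr<A> {
    elems: Vec<A>,
}

impl<A> SimpleRepr<A> {
    // Analogue of `release_all_elements`: the caller takes over responsibility
    // for the elements; afterwards this repr owns none of them, so dropping it
    // cannot double-drop anything.
    fn release_all_elements(&mut self) -> Vec<A> {
        std::mem::take(&mut self.elems)
    }
}

fn example() {
    let mut repr = SimpleRepr { elems: vec!["a".to_string(), "b".to_string()] };
    let taken = repr.release_all_elements();
    assert_eq!(taken.len(), 2);
    assert!(repr.elems.is_empty());
}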

/// Cast self into equivalent repr of other element type
///
/// ## Safety
22 changes: 22 additions & 0 deletions src/impl_constructors.rs
@@ -29,6 +29,7 @@ use crate::indices;
#[cfg(feature = "std")]
use crate::iterators::to_vec;
use crate::iterators::to_vec_mapped;
use crate::iterators::TrustedIterator;
use crate::StrideShape;
#[cfg(feature = "std")]
use crate::{geomspace, linspace, logspace};
@@ -495,6 +496,27 @@ where
ArrayBase::from_data_ptr(DataOwned::new(v), ptr).with_strides_dim(strides, dim)
}

/// Create an array from an iterator, mapped by `map`, and interpret it according to the
/// provided shape and strides.
///
/// # Safety
///
/// See `from_shape_vec_unchecked`.
pub(crate) unsafe fn from_shape_trusted_iter_unchecked<Sh, I, F>(shape: Sh, iter: I, map: F)
-> Self
where
Sh: Into<StrideShape<D>>,
I: TrustedIterator + ExactSizeIterator,
F: FnMut(I::Item) -> A,
{
let shape = shape.into();
let dim = shape.dim;
let strides = shape.strides.strides_for_dim(&dim);
let v = to_vec_mapped(iter, map);
Self::from_vec_dim_stride_unchecked(dim, strides, v)
}
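For comparison, the safe public constructor checks the element count against the shape at runtime; a small sketch using only the public `Array::from_shape_vec`:

use ndarray::Array;

fn example() {
    // Safe counterpart: the length of the Vec is checked against the shape.
    let a = Array::from_shape_vec((2, 3), (0..6).collect::<Vec<i32>>()).unwrap();
    assert_eq!(a.shape(), &[2, 3]);
}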


/// Create an array with uninitialized elements, shape `shape`.
///
/// The uninitialized elements of type `A` are represented by the type `MaybeUninit<A>`,
24 changes: 11 additions & 13 deletions src/impl_methods.rs
@@ -498,6 +498,7 @@ where
///
/// **Panics** if an index is out of bounds or step size is zero.<br>
/// **Panics** if `axis` is out of bounds.
#[must_use = "slice_axis returns an array view with the sliced result"]
pub fn slice_axis(&self, axis: Axis, indices: Slice) -> ArrayView<'_, A, D>
where
S: Data,
@@ -511,6 +512,7 @@
///
/// **Panics** if an index is out of bounds or step size is zero.<br>
/// **Panics** if `axis` is out of bounds.
#[must_use = "slice_axis_mut returns an array view with the sliced result"]
pub fn slice_axis_mut(&mut self, axis: Axis, indices: Slice) -> ArrayViewMut<'_, A, D>
where
S: DataMut,
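These `must_use` hints matter because `slice_axis` returns a new view and leaves the original array untouched; a short sketch with the public API:

use ndarray::{array, Axis, Slice};

fn example() {
    let a = array![[1, 2, 3], [4, 5, 6]];
    // Ignoring the return value would do nothing; the view must be used.
    let first_two_cols = a.slice_axis(Axis(1), Slice::from(0..2));
    assert_eq!(first_two_cols, array![[1, 2], [4, 5]]);
}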
@@ -2224,17 +2226,14 @@ where
A: 'a,
S: Data,
{
-        if let Some(slc) = self.as_slice_memory_order() {
-            let v = crate::iterators::to_vec_mapped(slc.iter(), f);
-            unsafe {
-                ArrayBase::from_shape_vec_unchecked(
-                    self.dim.clone().strides(self.strides.clone()),
-                    v,
-                )
-            }
-        } else {
-            let v = crate::iterators::to_vec_mapped(self.iter(), f);
-            unsafe { ArrayBase::from_shape_vec_unchecked(self.dim.clone(), v) }
+        unsafe {
+            if let Some(slc) = self.as_slice_memory_order() {
+                ArrayBase::from_shape_trusted_iter_unchecked(
+                    self.dim.clone().strides(self.strides.clone()),
+                    slc.iter(), f)
+            } else {
+                ArrayBase::from_shape_trusted_iter_unchecked(self.dim.clone(), self.iter(), f)
+            }
         }
     }

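The `map` rewrite above (old lines marked `-`, new lines `+`) routes through the new trusted-iterator constructor; the public behavior is unchanged. A small usage sketch:

use ndarray::array;

fn example() {
    let a = array![[1., 2.], [3., 4.]];
    // `map` borrows each element and returns a new owned array with the same shape.
    let b = a.map(|&x| x * 10.);
    assert_eq!(b, array![[10., 20.], [30., 40.]]);
}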
@@ -2254,11 +2253,10 @@ where
         if self.is_contiguous() {
             let strides = self.strides.clone();
             let slc = self.as_slice_memory_order_mut().unwrap();
-            let v = crate::iterators::to_vec_mapped(slc.iter_mut(), f);
-            unsafe { ArrayBase::from_shape_vec_unchecked(dim.strides(strides), v) }
+            unsafe { ArrayBase::from_shape_trusted_iter_unchecked(dim.strides(strides),
+                                                                  slc.iter_mut(), f) }
         } else {
-            let v = crate::iterators::to_vec_mapped(self.iter_mut(), f);
-            unsafe { ArrayBase::from_shape_vec_unchecked(dim, v) }
+            unsafe { ArrayBase::from_shape_trusted_iter_unchecked(dim, self.iter_mut(), f) }
         }
     }

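Likewise for the mutable variant above: both branches exist because the data may or may not be contiguous in memory. For example, a stepped slice is not contiguous, so mapping it takes the element-iterator branch rather than the slice fast path; a sketch with the public API:

use ndarray::{array, s};

fn example() {
    let a = array![[1, 2, 3], [4, 5, 6]];
    // Every second column: a non-contiguous view of `a`.
    let v = a.slice(s![.., ..;2]);
    assert_eq!(v.map(|&x| x + 1), array![[2, 4], [5, 7]]);
}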
151 changes: 78 additions & 73 deletions src/impl_owned_array.rs
@@ -223,89 +223,18 @@ impl<A, D> Array<A, D>
     fn drop_unreachable_elements_slow(mut self) -> OwnedRepr<A> {
         // "deconstruct" self; the owned repr releases ownership of all elements and we
         // carry on with raw view methods
-        let self_len = self.len();
         let data_len = self.data.len();
         let data_ptr = self.data.as_nonnull_mut().as_ptr();
 
-        let mut self_;
-
         unsafe {
             // Safety: self.data releases ownership of the elements. Any panics below this point
             // will result in leaking elements instead of double drops.
-            self_ = self.raw_view_mut();
+            let self_ = self.raw_view_mut();
             self.data.set_len(0);
-        }
-
-        // uninvert axes where needed, so that stride > 0
-        for i in 0..self_.ndim() {
-            if self_.stride_of(Axis(i)) < 0 {
-                self_.invert_axis(Axis(i));
-            }
+            drop_unreachable_raw(self_, data_ptr, data_len);
         }
 
-        // Sort axes to standard order, Axis(0) has biggest stride and Axis(n - 1) least stride
-        // Note that self_ has holes, so self_ is not C-contiguous
-        sort_axes_in_default_order(&mut self_);
-
-        unsafe {
-            // with uninverted axes this is now the element with lowest address
-            let array_memory_head_ptr = self_.ptr.as_ptr();
-            let data_end_ptr = data_ptr.add(data_len);
-            debug_assert!(data_ptr <= array_memory_head_ptr);
-            debug_assert!(array_memory_head_ptr <= data_end_ptr);
-
-            // The idea is simply this: the iterator will yield the elements of self_ in
-            // increasing address order.
-            //
-            // The pointers produced by the iterator are those that we *do not* touch.
-            // The pointers *not mentioned* by the iterator are those we have to drop.
-            //
-            // We have to drop elements in the range from `data_ptr` until (not including)
-            // `data_end_ptr`, except those that are produced by `iter`.
-
-            // As an optimization, the innermost axis is removed if it has stride 1, because
-            // we then have a long stretch of contiguous elements we can skip as one.
-            let inner_lane_len;
-            if self_.ndim() > 1 && self_.strides.last_elem() == 1 {
-                self_.dim.slice_mut().rotate_right(1);
-                self_.strides.slice_mut().rotate_right(1);
-                inner_lane_len = self_.dim[0];
-                self_.dim[0] = 1;
-                self_.strides[0] = 1;
-            } else {
-                inner_lane_len = 1;
-            }
-
-            // iter is a raw pointer iterator traversing the array in memory order now with the
-            // sorted axes.
-            let mut iter = Baseiter::new(self_.ptr.as_ptr(), self_.dim, self_.strides);
-            let mut dropped_elements = 0;
-
-            let mut last_ptr = data_ptr;
-
-            while let Some(elem_ptr) = iter.next() {
-                // The interval from last_ptr up until (not including) elem_ptr
-                // should now be dropped. This interval may be empty, then we just skip this loop.
-                while last_ptr != elem_ptr {
-                    debug_assert!(last_ptr < data_end_ptr);
-                    std::ptr::drop_in_place(last_ptr);
-                    last_ptr = last_ptr.add(1);
-                    dropped_elements += 1;
-                }
-                // Next interval will continue one past the current lane
-                last_ptr = elem_ptr.add(inner_lane_len);
-            }
-
-            while last_ptr < data_end_ptr {
-                std::ptr::drop_in_place(last_ptr);
-                last_ptr = last_ptr.add(1);
-                dropped_elements += 1;
-            }
-
-            assert_eq!(data_len, dropped_elements + self_len,
-                       "Internal error: inconsistency in move_into");
-        }
         self.data
     }

@@ -594,6 +523,82 @@ impl<A, D> Array<A, D>
}
}

/// This drops all "unreachable" elements in `self_` given the data pointer and data length.
///
/// # Safety
///
/// This is an internal function for use by `move_into` and `IntoIter` only; safety invariants
/// must be upheld across the calls from those implementations.
pub(crate) unsafe fn drop_unreachable_raw<A, D>(mut self_: RawArrayViewMut<A, D>, data_ptr: *mut A, data_len: usize)
where
D: Dimension,
{
let self_len = self_.len();

for i in 0..self_.ndim() {
if self_.stride_of(Axis(i)) < 0 {
self_.invert_axis(Axis(i));
}
}
sort_axes_in_default_order(&mut self_);
// with uninverted axes this is now the element with lowest address
let array_memory_head_ptr = self_.ptr.as_ptr();
let data_end_ptr = data_ptr.add(data_len);
debug_assert!(data_ptr <= array_memory_head_ptr);
debug_assert!(array_memory_head_ptr <= data_end_ptr);

// The idea is simply this: the iterator will yield the elements of self_ in
// increasing address order.
//
// The pointers produced by the iterator are those that we *do not* touch.
// The pointers *not mentioned* by the iterator are those we have to drop.
//
// We have to drop elements in the range from `data_ptr` until (not including)
// `data_end_ptr`, except those that are produced by `iter`.

// As an optimization, the innermost axis is removed if it has stride 1, because
// we then have a long stretch of contiguous elements we can skip as one.
let inner_lane_len;
if self_.ndim() > 1 && self_.strides.last_elem() == 1 {
self_.dim.slice_mut().rotate_right(1);
self_.strides.slice_mut().rotate_right(1);
inner_lane_len = self_.dim[0];
self_.dim[0] = 1;
self_.strides[0] = 1;
} else {
inner_lane_len = 1;
}

// iter is a raw pointer iterator traversing the array in memory order now with the
// sorted axes.
let mut iter = Baseiter::new(self_.ptr.as_ptr(), self_.dim, self_.strides);
let mut dropped_elements = 0;

let mut last_ptr = data_ptr;

while let Some(elem_ptr) = iter.next() {
// The interval from last_ptr up until (not including) elem_ptr
// should now be dropped. This interval may be empty, then we just skip this loop.
while last_ptr != elem_ptr {
debug_assert!(last_ptr < data_end_ptr);
std::ptr::drop_in_place(last_ptr);
last_ptr = last_ptr.add(1);
dropped_elements += 1;
}
// Next interval will continue one past the current lane
last_ptr = elem_ptr.add(inner_lane_len);
}

while last_ptr < data_end_ptr {
std::ptr::drop_in_place(last_ptr);
last_ptr = last_ptr.add(1);
dropped_elements += 1;
}

assert_eq!(data_len, dropped_elements + self_len,
"Internal error: inconsistency in move_into");
}
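For intuition about "unreachable" elements: when an owned array no longer covers its whole allocation (e.g. after `slice_move`), consuming it by value moves out only the reachable elements, and the remaining ones must still be dropped exactly once, which is what `drop_unreachable_raw` handles. A hedged sketch, assuming the by-value `IntoIterator` added elsewhere in this PR:

use ndarray::{array, s, Array2};

fn example() {
    let a: Array2<String> = array![["a".to_string(), "b".to_string()],
                                   ["c".to_string(), "d".to_string()]];
    // Keep only the first column; "b" and "d" stay in the allocation but
    // become unreachable through the array.
    let col = a.slice_move(s![.., ..1]);
    // Iterating by value moves "a" and "c" out; the unreachable elements are
    // dropped by machinery like the function above, each exactly once.
    let kept: Vec<String> = col.into_iter().collect();
    assert_eq!(kept, vec!["a".to_string(), "c".to_string()]);
}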

/// Sort axes to standard order, i.e. Axis(0) has the biggest stride and Axis(n - 1) the least stride
///
/// The axes should have stride >= 0 before calling this method.