Skip to content

Commit

Permalink
added stride support to windows
Browse files Browse the repository at this point in the history
  • Loading branch information
LazaroHurtado committed Dec 29, 2022
1 parent 0740695 commit 2c0ae32
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 1 deletion.
52 changes: 52 additions & 0 deletions src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1457,6 +1457,58 @@ where
Windows::new(self.view(), window_size)
}

/// Return a window producer and iterable.
///
/// The windows are all distinct views of size `window_size`
/// that fit into the array's shape.
///
/// The stride is ordered by the outermost axis.<br>
/// Hence, a (x₀, x₁, ..., xₙ) stride will be applied to
/// (A₀, A₁, ..., Aₙ) where Aₓ stands for `Axis(x)`.
///
/// This produces all windows that fit within the array for the given stride,
/// assuming the window size is not larger than the array size.
///
/// The produced element is an `ArrayView<A, D>` with exactly the dimension
/// `window_size`.
///
/// Note that passing a stride of only ones is similar to
/// calling [`ArrayBase::windows()`].
///
/// **Panics** if any dimension of `window_size` or `stride` is zero.<br>
/// (**Panics** if `D` is `IxDyn` and `window_size` or `stride` does not match the
/// number of array axes.)
///
/// This is the same illustration found in [`ArrayBase::windows()`],
/// 2×2 windows in a 3×4 array, but now with a (1, 2) stride:
///
/// ```text
/// ──▶ Axis(1)
///
/// │ ┏━━━━━┳━━━━━┱─────┬─────┐ ┌─────┬─────┲━━━━━┳━━━━━┓
/// ▼ ┃ a₀₀ ┃ a₀₁ ┃ │ │ │ │ ┃ a₀₂ ┃ a₀₃ ┃
/// Axis(0) ┣━━━━━╋━━━━━╉─────┼─────┤ ├─────┼─────╊━━━━━╋━━━━━┫
/// ┃ a₁₀ ┃ a₁₁ ┃ │ │ │ │ ┃ a₁₂ ┃ a₁₃ ┃
/// ┡━━━━━╇━━━━━╃─────┼─────┤ ├─────┼─────╄━━━━━╇━━━━━┩
/// │ │ │ │ │ │ │ │ │ │
/// └─────┴─────┴─────┴─────┘ └─────┴─────┴─────┴─────┘
///
/// ┌─────┬─────┬─────┬─────┐ ┌─────┬─────┬─────┬─────┐
/// │ │ │ │ │ │ │ │ │ │
/// ┢━━━━━╈━━━━━╅─────┼─────┤ ├─────┼─────╆━━━━━╈━━━━━┪
/// ┃ a₁₀ ┃ a₁₁ ┃ │ │ │ │ ┃ a₁₂ ┃ a₁₃ ┃
/// ┣━━━━━╋━━━━━╉─────┼─────┤ ├─────┼─────╊━━━━━╋━━━━━┫
/// ┃ a₂₀ ┃ a₂₁ ┃ │ │ │ │ ┃ a₂₂ ┃ a₂₃ ┃
/// ┗━━━━━┻━━━━━┹─────┴─────┘ └─────┴─────┺━━━━━┻━━━━━┛
/// ```
pub fn windows_with_stride<E>(&self, window_size: E, stride: E) -> Windows<'_, A, D>
where
E: IntoDimension<Dim = D>,
S: Data,
{
Windows::new_with_stride(self.view(), window_size, stride)
}

/// Returns a producer which traverses over all windows of a given length along an axis.
///
/// The windows are all distinct, possibly-overlapping views. The shape of each window
Expand Down
58 changes: 58 additions & 0 deletions src/iterators/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,64 @@ impl<'a, A, D: Dimension> Windows<'a, A, D> {
}
}
}

pub(crate) fn new_with_stride<E>(a: ArrayView<'a, A, D>, window_size: E, strides: E) -> Self
where
E: IntoDimension<Dim = D>,
{
let window = window_size.into_dimension();
let strides_d = strides.into_dimension();
ndassert!(
a.ndim() == window.ndim(),
concat!(
"Window dimension {} does not match array dimension {} ",
"(with array of shape {:?})"
),
window.ndim(),
a.ndim(),
a.shape()
);
ndassert!(
a.ndim() == strides_d.ndim(),
concat!(
"Stride dimension {} does not match array dimension {} ",
"(with array of shape {:?})"
),
strides_d.ndim(),
a.ndim(),
a.shape()
);
let mut size = a.dim;
for ((sz, &ws), &stride) in size
.slice_mut()
.iter_mut()
.zip(window.slice())
.zip(strides_d.slice())
{
assert_ne!(ws, 0, "window-size must not be zero!");
assert_ne!(stride, 0, "stride cannot have a dimension as zero!");
// cannot use std::cmp::max(0, ..) since arithmetic underflow panics
*sz = if *sz < ws {
0
} else {
((*sz - (ws - 1) - 1) / stride) + 1
};
}
let window_strides = a.strides.clone();

let mut array_strides = a.strides.clone();
for (arr_stride, ix_stride) in array_strides.slice_mut().iter_mut().zip(strides_d.slice()) {
*arr_stride *= ix_stride;
}

unsafe {
Windows {
base: ArrayView::new(a.ptr, size, array_strides),
window,
strides: window_strides,
}
}
}
}

impl_ndproducer! {
Expand Down
72 changes: 71 additions & 1 deletion tests/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ fn windows_iterator_zero_size() {
a.windows(Dim((0, 0, 0)));
}

/// Test that verifites that no windows are yielded on oversized window sizes.
/// Test that verifies that no windows are yielded on oversized window sizes.
#[test]
fn windows_iterator_oversized() {
let a = Array::from_iter(10..37).into_shape((3, 3, 3)).unwrap();
Expand Down Expand Up @@ -95,6 +95,76 @@ fn windows_iterator_3d() {
);
}

/// Test that verifies the `Windows` iterator panics when stride has an axis equal to zero.
#[test]
#[should_panic]
fn windows_iterator_stride_axis_zero() {
let a = Array::from_iter(10..37).into_shape((3, 3, 3)).unwrap();
a.windows_with_stride(Dim((2, 2, 2)), Dim((0,2,2)));
}

/// Test that verifies that only first window is yielded when stride is oversized on every axis.
#[test]
fn windows_iterator_only_one_valid_window_for_oversized_stride() {
let a = Array::from_iter(10..135).into_shape((5, 5, 5)).unwrap();
let mut iter = a.windows_with_stride((2, 2, 2), (8, 8, 8)).into_iter(); // (4,3,2) doesn't fit into (3,3,3) => oversized!
itertools::assert_equal(
iter.next(),
Some(arr3(&[[[10, 11], [15, 16]],[[35, 36], [40, 41]]]))
);
}

/// Simple test for iterating 1d-arrays via `Windows` with stride.
#[test]
fn windows_iterator_1d_with_stride() {
let a = Array::from_iter(10..20).into_shape(10).unwrap();
itertools::assert_equal(
a.windows_with_stride(Dim(4), Dim(2)),
vec![
arr1(&[10, 11, 12, 13]),
arr1(&[12, 13, 14, 15]),
arr1(&[14, 15, 16, 17]),
arr1(&[16, 17, 18, 19]),
],
);
}

/// Simple test for iterating 2d-arrays via `Windows` with stride.
#[test]
fn windows_iterator_2d_with_stride() {
let a = Array::from_iter(10..30).into_shape((5, 4)).unwrap();
itertools::assert_equal(
a.windows_with_stride(Dim((3, 2)), Dim((2,1))),
vec![
arr2(&[[10, 11], [14, 15], [18, 19]]),
arr2(&[[11, 12], [15, 16], [19, 20]]),
arr2(&[[12, 13], [16, 17], [20, 21]]),
arr2(&[[18, 19], [22, 23], [26, 27]]),
arr2(&[[19, 20], [23, 24], [27, 28]]),
arr2(&[[20, 21], [24, 25], [28, 29]]),
],
);
}

/// Simple test for iterating 3d-arrays via `Windows` with stride.
#[test]
fn windows_iterator_3d_with_stride() {
let a = Array::from_iter(10..74).into_shape((4, 4, 4)).unwrap();
itertools::assert_equal(
a.windows_with_stride(Dim((2, 2, 2)), Dim((2,2,2))),
vec![
arr3(&[[[10, 11], [14, 15]], [[26, 27], [30, 31]]]),
arr3(&[[[12, 13], [16, 17]], [[28, 29], [32, 33]]]),
arr3(&[[[18, 19], [22, 23]], [[34, 35], [38, 39]]]),
arr3(&[[[20, 21], [24, 25]], [[36, 37], [40, 41]]]),
arr3(&[[[42, 43], [46, 47]], [[58, 59], [62, 63]]]),
arr3(&[[[44, 45], [48, 49]], [[60, 61], [64, 65]]]),
arr3(&[[[50, 51], [54, 55]], [[66, 67], [70, 71]]]),
arr3(&[[[52, 53], [56, 57]], [[68, 69], [72, 73]]]),
],
);
}

#[test]
fn test_window_zip() {
let a = Array::from_iter(0..64).into_shape((4, 4, 4)).unwrap();
Expand Down

0 comments on commit 2c0ae32

Please sign in to comment.