Skip to content

Commit

Permalink
ADL-proof implementation of [alg.merge], [alg.set.operations], [alg.heap.operations], and [alg.permutation.generators] (#4347)
Browse files Browse the repository at this point in the history

Co-authored-by: Stephan T. Lavavej <[email protected]>
  • Loading branch information
frederick-vs-ja and StephanTLavavej authored Jan 30, 2024
1 parent c5e9583 commit b89a780
Show file tree
Hide file tree
Showing 6 changed files with 491 additions and 321 deletions.
546 changes: 278 additions & 268 deletions stl/inc/algorithm

Large diffs are not rendered by default.

68 changes: 35 additions & 33 deletions stl/inc/execution
Original file line number Diff line number Diff line change
Expand Up @@ -3286,26 +3286,27 @@ struct _Static_partitioned_is_heap_until2 {

static void __stdcall _Threadpool_callback(
__std_PTP_CALLBACK_INSTANCE, void* const _Context, __std_PTP_WORK) noexcept /* terminates */ {
_Run_available_chunked_work(*static_cast<_Static_partitioned_is_heap_until2*>(_Context));
_STD _Run_available_chunked_work(*static_cast<_Static_partitioned_is_heap_until2*>(_Context));
}
};

_EXPORT_STD template <class _ExPo, class _RanIt, class _Pr, _Enable_if_execution_policy_t<_ExPo> /* = 0 */>
_NODISCARD _RanIt is_heap_until(_ExPo&&, _RanIt _First, _RanIt _Last, _Pr _Pred) noexcept /* terminates */ {
// find extent of range that is a heap
_REQUIRE_PARALLEL_ITERATOR(_RanIt);
_Adl_verify_range(_First, _Last);
const auto _UFirst = _Get_unwrapped(_First);
const auto _ULast = _Get_unwrapped(_Last);
_STD _Adl_verify_range(_First, _Last);
const auto _UFirst = _STD _Get_unwrapped(_First);
const auto _ULast = _STD _Get_unwrapped(_Last);
if constexpr (remove_reference_t<_ExPo>::_Parallelize) {
const size_t _Hw_threads = __std_parallel_algorithms_hw_threads();
if (_Hw_threads > 1) { // parallelize on multiprocessor machines
const auto _Count = _ULast - _UFirst;
if (_Count >= 3) { // ... with at least 3 elements
_TRY_BEGIN
_Static_partitioned_is_heap_until2 _Operation{_UFirst, _ULast, _Hw_threads, _Count, _Pass_fn(_Pred)};
_Run_chunked_parallel_work(_Hw_threads, _Operation);
_Seek_wrapped(_First, _Operation._Results._Get_result());
_Static_partitioned_is_heap_until2 _Operation{
_UFirst, _ULast, _Hw_threads, _Count, _STD _Pass_fn(_Pred)};
_STD _Run_chunked_parallel_work(_Hw_threads, _Operation);
_STD _Seek_wrapped(_First, _Operation._Results._Get_result());
return _First;
_CATCH(const _Parallelism_resources_exhausted&)
// fall through to serial case below
Expand All @@ -3314,7 +3315,7 @@ _NODISCARD _RanIt is_heap_until(_ExPo&&, _RanIt _First, _RanIt _Last, _Pr _Pred)
}
}

_Seek_wrapped(_First, _STD is_heap_until(_UFirst, _ULast, _Pass_fn(_Pred)));
_STD _Seek_wrapped(_First, _STD is_heap_until(_UFirst, _ULast, _STD _Pass_fn(_Pred)));
return _First;
}

Expand Down Expand Up @@ -3765,7 +3766,7 @@ struct _Static_partitioned_set_subtraction {
// Get chunk in _Range2 that corresponds to our current chunk from _Range1
auto _Range2_chunk_first = _STD lower_bound(_Range2._First, _Range2._Last, *_Range1_chunk_first, _Pred);
auto _Range2_chunk_last =
_STD upper_bound(_Range2_chunk_first, _Range2._Last, *_Prev_iter(_Range1_chunk_last), _Pred);
_STD upper_bound(_Range2_chunk_first, _Range2._Last, *_STD _Prev_iter(_Range1_chunk_last), _Pred);

// Publish results to rest of chunks.
if (_Chunk_number == 0) {
Expand Down Expand Up @@ -3816,14 +3817,14 @@ struct _Static_partitioned_set_subtraction {

// Place elements from _Range1 in _Dest according to the offsets previously calculated.
auto _Chunk_specific_dest = _Dest + static_cast<_Iter_diff_t<_RanIt3>>(_Prev_chunk_sum);
_Place_elements_from_indices(
_STD _Place_elements_from_indices(
_Range1_chunk_first, _Chunk_specific_dest, _Index_chunk_first, static_cast<ptrdiff_t>(_Num_results));
return _Cancellation_status::_Running;
}

static void __stdcall _Threadpool_callback(
__std_PTP_CALLBACK_INSTANCE, void* const _Context, __std_PTP_WORK) noexcept /* terminates */ {
_Run_available_chunked_work(*static_cast<_Static_partitioned_set_subtraction*>(_Context));
_STD _Run_available_chunked_work(*static_cast<_Static_partitioned_set_subtraction*>(_Context));
}
};

Expand Down Expand Up @@ -3877,13 +3878,13 @@ _FwdIt3 set_intersection(_ExPo&&, _FwdIt1 _First1, _FwdIt1 _Last1, _FwdIt2 _Firs
_REQUIRE_PARALLEL_ITERATOR(_FwdIt1);
_REQUIRE_PARALLEL_ITERATOR(_FwdIt2);
_REQUIRE_CPP17_MUTABLE_ITERATOR(_FwdIt3);
_Adl_verify_range(_First1, _Last1);
_Adl_verify_range(_First2, _Last2);
auto _UFirst1 = _Get_unwrapped(_First1);
const auto _ULast1 = _Get_unwrapped(_Last1);
auto _UFirst2 = _Get_unwrapped(_First2);
const auto _ULast2 = _Get_unwrapped(_Last2);
auto _UDest = _Get_unwrapped_unverified(_Dest);
_STD _Adl_verify_range(_First1, _Last1);
_STD _Adl_verify_range(_First2, _Last2);
auto _UFirst1 = _STD _Get_unwrapped(_First1);
const auto _ULast1 = _STD _Get_unwrapped(_Last1);
auto _UFirst2 = _STD _Get_unwrapped(_First2);
const auto _ULast2 = _STD _Get_unwrapped(_Last2);
auto _UDest = _STD _Get_unwrapped_unverified(_Dest);
using _Diff = _Common_diff_t<_FwdIt1, _FwdIt2, _FwdIt3>;
if constexpr (remove_reference_t<_ExPo>::_Parallelize && _Is_ranges_random_iter_v<_FwdIt1>
&& _Is_ranges_random_iter_v<_FwdIt2> && _Is_cpp17_random_iter_v<_FwdIt3>) {
Expand All @@ -3895,10 +3896,10 @@ _FwdIt3 set_intersection(_ExPo&&, _FwdIt1 _First1, _FwdIt1 _Last1, _FwdIt2 _Firs
if (_Count1 >= 2 && _Count2 >= 2) { // ... with each range containing at least 2 elements
_TRY_BEGIN
_Static_partitioned_set_subtraction _Operation(_Hw_threads, _Count1, _UFirst1, _UFirst2, _ULast2,
_UDest, _Pass_fn(_Pred), _Set_intersection_per_chunk());
_Run_chunked_parallel_work(_Hw_threads, _Operation);
_UDest, _STD _Pass_fn(_Pred), _Set_intersection_per_chunk());
_STD _Run_chunked_parallel_work(_Hw_threads, _Operation);
_UDest += static_cast<_Iter_diff_t<_FwdIt3>>(_Operation._Lookback.back()._Sum._Ref());
_Seek_wrapped(_Dest, _UDest);
_STD _Seek_wrapped(_Dest, _UDest);
return _Dest;
_CATCH(const _Parallelism_resources_exhausted&)
// fall through to serial case below
Expand All @@ -3907,7 +3908,8 @@ _FwdIt3 set_intersection(_ExPo&&, _FwdIt1 _First1, _FwdIt1 _Last1, _FwdIt2 _Firs
}
}

_Seek_wrapped(_Dest, _STD set_intersection(_UFirst1, _ULast1, _UFirst2, _ULast2, _UDest, _Pass_fn(_Pred)));
_STD _Seek_wrapped(
_Dest, _STD set_intersection(_UFirst1, _ULast1, _UFirst2, _ULast2, _UDest, _STD _Pass_fn(_Pred)));
return _Dest;
}

Expand Down Expand Up @@ -3967,13 +3969,13 @@ _FwdIt3 set_difference(_ExPo&&, _FwdIt1 _First1, _FwdIt1 _Last1, _FwdIt2 _First2
_REQUIRE_PARALLEL_ITERATOR(_FwdIt1);
_REQUIRE_PARALLEL_ITERATOR(_FwdIt2);
_REQUIRE_CPP17_MUTABLE_ITERATOR(_FwdIt3);
_Adl_verify_range(_First1, _Last1);
_Adl_verify_range(_First2, _Last2);
auto _UFirst1 = _Get_unwrapped(_First1);
const auto _ULast1 = _Get_unwrapped(_Last1);
auto _UFirst2 = _Get_unwrapped(_First2);
const auto _ULast2 = _Get_unwrapped(_Last2);
auto _UDest = _Get_unwrapped_unverified(_Dest);
_STD _Adl_verify_range(_First1, _Last1);
_STD _Adl_verify_range(_First2, _Last2);
auto _UFirst1 = _STD _Get_unwrapped(_First1);
const auto _ULast1 = _STD _Get_unwrapped(_Last1);
auto _UFirst2 = _STD _Get_unwrapped(_First2);
const auto _ULast2 = _STD _Get_unwrapped(_Last2);
auto _UDest = _STD _Get_unwrapped_unverified(_Dest);
using _Diff = _Common_diff_t<_FwdIt1, _FwdIt2, _FwdIt3>;
if constexpr (remove_reference_t<_ExPo>::_Parallelize && _Is_ranges_random_iter_v<_FwdIt1>
&& _Is_ranges_random_iter_v<_FwdIt2> && _Is_cpp17_random_iter_v<_FwdIt3>) {
Expand All @@ -3984,10 +3986,10 @@ _FwdIt3 set_difference(_ExPo&&, _FwdIt1 _First1, _FwdIt1 _Last1, _FwdIt2 _First2
if (_Count >= 2) { // ... with at least 2 elements in [_First1, _Last1)
_TRY_BEGIN
_Static_partitioned_set_subtraction _Operation(_Hw_threads, _Count, _UFirst1, _UFirst2, _ULast2, _UDest,
_Pass_fn(_Pred), _Set_difference_per_chunk());
_Run_chunked_parallel_work(_Hw_threads, _Operation);
_STD _Pass_fn(_Pred), _Set_difference_per_chunk());
_STD _Run_chunked_parallel_work(_Hw_threads, _Operation);
_UDest += static_cast<_Iter_diff_t<_FwdIt3>>(_Operation._Lookback.back()._Sum._Ref());
_Seek_wrapped(_Dest, _UDest);
_STD _Seek_wrapped(_Dest, _UDest);
return _Dest;
_CATCH(const _Parallelism_resources_exhausted&)
// fall through to serial case below
Expand All @@ -3996,7 +3998,7 @@ _FwdIt3 set_difference(_ExPo&&, _FwdIt1 _First1, _FwdIt1 _Last1, _FwdIt2 _First2
}
}

_Seek_wrapped(_Dest, _STD set_difference(_UFirst1, _ULast1, _UFirst2, _ULast2, _UDest, _Pass_fn(_Pred)));
_STD _Seek_wrapped(_Dest, _STD set_difference(_UFirst1, _ULast1, _UFirst2, _ULast2, _UDest, _STD _Pass_fn(_Pred)));
return _Dest;
}

Expand Down
22 changes: 12 additions & 10 deletions stl/inc/xmemory
Original file line number Diff line number Diff line change
Expand Up @@ -1636,13 +1636,13 @@ struct _NODISCARD _Uninitialized_backout {
_Uninitialized_backout& operator=(const _Uninitialized_backout&) = delete;

_CONSTEXPR20 ~_Uninitialized_backout() {
_Destroy_range(_First, _Last);
_STD _Destroy_range(_First, _Last);
}

template <class... _Types>
_CONSTEXPR20 void _Emplace_back(_Types&&... _Vals) {
// construct a new element at *_Last and increment
_Construct_in_place(*_Last, _STD forward<_Types>(_Vals)...);
_STD _Construct_in_place(*_Last, _STD forward<_Types>(_Vals)...);
++_Last;
}

Expand All @@ -1660,7 +1660,7 @@ _CONSTEXPR20 _NoThrowFwdIt _Uninitialized_move_unchecked(_InIt _First, const _In
if (!_STD is_constant_evaluated())
#endif // _HAS_CXX20
{
return _Copy_memmove(_First, _Last, _Dest);
return _STD _Copy_memmove(_First, _Last, _Dest);
}
}
_Uninitialized_backout<_NoThrowFwdIt> _Backout{_Dest};
Expand Down Expand Up @@ -1746,10 +1746,10 @@ namespace ranges {
template <class _InIt, class _OutIt>
in_out_result<_InIt, _OutIt> _Copy_memcpy_common(
_InIt _IFirst, _InIt _ILast, _OutIt _OFirst, _OutIt _OLast) noexcept {
const auto _IFirstPtr = _To_address(_IFirst);
const auto _ILastPtr = _To_address(_ILast);
const auto _OFirstPtr = _To_address(_OFirst);
const auto _OLastPtr = _To_address(_OLast);
const auto _IFirstPtr = _STD _To_address(_IFirst);
const auto _ILastPtr = _STD _To_address(_ILast);
const auto _OFirstPtr = _STD _To_address(_OFirst);
const auto _OLastPtr = _STD _To_address(_OLast);
const auto _IFirst_ch = const_cast<char*>(reinterpret_cast<const volatile char*>(_IFirstPtr));
const auto _ILast_ch = const_cast<const char*>(reinterpret_cast<const volatile char*>(_ILastPtr));
const auto _OFirst_ch = const_cast<char*>(reinterpret_cast<const volatile char*>(_OFirstPtr));
Expand Down Expand Up @@ -1783,12 +1783,14 @@ namespace ranges {
if constexpr (_Iter_move_cat<_It, _Out>::_Bitcopy_constructible && _Sized_or_unreachable_sentinel_for<_Se, _It>
&& _Sized_or_unreachable_sentinel_for<_OSe, _Out>) {
if constexpr (_Is_sized1 && _Is_sized2) {
return _Copy_memcpy_common(_IFirst, _RANGES next(_IFirst, _STD move(_ILast)), _OFirst,
return _RANGES _Copy_memcpy_common(_IFirst, _RANGES next(_IFirst, _STD move(_ILast)), _OFirst,
_RANGES next(_OFirst, _STD move(_OLast)));
} else if constexpr (_Is_sized1) {
return _Copy_memcpy_distance(_IFirst, _OFirst, _IFirst, _RANGES next(_IFirst, _STD move(_ILast)));
return _RANGES _Copy_memcpy_distance(
_IFirst, _OFirst, _IFirst, _RANGES next(_IFirst, _STD move(_ILast)));
} else if constexpr (_Is_sized2) {
return _Copy_memcpy_distance(_IFirst, _OFirst, _OFirst, _RANGES next(_OFirst, _STD move(_OLast)));
return _RANGES _Copy_memcpy_distance(
_IFirst, _OFirst, _OFirst, _RANGES next(_OFirst, _STD move(_OLast)));
} else {
_STL_ASSERT(false, "Tried to uninitialized_move two ranges with unreachable sentinels");
}
Expand Down
20 changes: 10 additions & 10 deletions stl/inc/xutility
Original file line number Diff line number Diff line change
Expand Up @@ -6978,22 +6978,22 @@ namespace ranges {
_EXPORT_STD template <class _FwdIt, class _Ty, class _Pr>
_NODISCARD _CONSTEXPR20 _FwdIt lower_bound(_FwdIt _First, const _FwdIt _Last, const _Ty& _Val, _Pr _Pred) {
// find first element not before _Val
_Adl_verify_range(_First, _Last);
auto _UFirst = _Get_unwrapped(_First);
_Iter_diff_t<_FwdIt> _Count = _STD distance(_UFirst, _Get_unwrapped(_Last));
_STD _Adl_verify_range(_First, _Last);
auto _UFirst = _STD _Get_unwrapped(_First);
_Iter_diff_t<_FwdIt> _Count = _STD distance(_UFirst, _STD _Get_unwrapped(_Last));

while (0 < _Count) { // divide and conquer, find half that contains answer
const _Iter_diff_t<_FwdIt> _Count2 = _Count / 2;
const auto _UMid = _STD next(_UFirst, _Count2);
if (_Pred(*_UMid, _Val)) { // try top half
_UFirst = _Next_iter(_UMid);
_UFirst = _STD _Next_iter(_UMid);
_Count -= _Count2 + 1;
} else {
_Count = _Count2;
}
}

_Seek_wrapped(_First, _UFirst);
_STD _Seek_wrapped(_First, _UFirst);
return _First;
}

Expand All @@ -7006,22 +7006,22 @@ _NODISCARD _CONSTEXPR20 _FwdIt lower_bound(_FwdIt _First, _FwdIt _Last, const _T
_EXPORT_STD template <class _FwdIt, class _Ty, class _Pr>
_NODISCARD _CONSTEXPR20 _FwdIt upper_bound(_FwdIt _First, _FwdIt _Last, const _Ty& _Val, _Pr _Pred) {
// find first element that _Val is before
_Adl_verify_range(_First, _Last);
auto _UFirst = _Get_unwrapped(_First);
_Iter_diff_t<_FwdIt> _Count = _STD distance(_UFirst, _Get_unwrapped(_Last));
_STD _Adl_verify_range(_First, _Last);
auto _UFirst = _STD _Get_unwrapped(_First);
_Iter_diff_t<_FwdIt> _Count = _STD distance(_UFirst, _STD _Get_unwrapped(_Last));

while (0 < _Count) { // divide and conquer, find half that contains answer
_Iter_diff_t<_FwdIt> _Count2 = _Count / 2;
const auto _UMid = _STD next(_UFirst, _Count2);
if (_Pred(_Val, *_UMid)) {
_Count = _Count2;
} else { // try top half
_UFirst = _Next_iter(_UMid);
_UFirst = _STD _Next_iter(_UMid);
_Count -= _Count2 + 1;
}
}

_Seek_wrapped(_First, _UFirst);
_STD _Seek_wrapped(_First, _UFirst);
return _First;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,48 @@ void test_algorithms() {
// (void) std::shift_right(varr, varr, 0); // requires Cpp17ValueSwappable
#endif // _HAS_CXX20

int iarr3[1]{};
validator varr3[1]{};

(void) std::merge(varr, varr, varr2, varr2, varr3);
(void) std::merge(iarr, iarr, iarr2, iarr2, iarr3, validating_less{});

// std::inplace_merge(varr, varr, varr); // requires Cpp17ValueSwappable
std::inplace_merge(iarr, iarr, iarr, validating_less{});

(void) std::includes(varr, varr, varr, varr);
(void) std::includes(iarr, iarr, iarr, iarr, validating_less{});

(void) std::set_union(varr, varr, varr, varr, varr3);
(void) std::set_union(iarr, iarr, iarr, iarr, iarr3, validating_less{});

(void) std::set_intersection(varr, varr, varr, varr, varr3);
(void) std::set_intersection(iarr, iarr, iarr, iarr, iarr3, validating_less{});

(void) std::set_difference(varr, varr, varr, varr, varr3);
(void) std::set_difference(iarr, iarr, iarr, iarr, iarr3, validating_less{});

(void) std::set_symmetric_difference(varr, varr, varr, varr, varr3);
(void) std::set_symmetric_difference(iarr, iarr, iarr, iarr, iarr3, validating_less{});

std::push_heap(varr3, varr3 + 1); // requires Cpp17ValueSwappable
std::push_heap(iarr3, iarr3 + 1, validating_less{});

std::pop_heap(varr3, varr3 + 1); // requires Cpp17ValueSwappable
std::pop_heap(iarr3, iarr3 + 1, validating_less{});

std::make_heap(varr3, varr3 + 1); // requires Cpp17ValueSwappable
std::make_heap(iarr3, iarr3 + 1, validating_less{});

std::sort_heap(varr3, varr3 + 1); // requires Cpp17ValueSwappable
std::sort_heap(iarr3, iarr3 + 1, validating_less{});

(void) std::is_heap(varr3, varr3 + 1);
(void) std::is_heap(iarr3, iarr3 + 1, validating_less{});

(void) std::is_heap_until(varr3, varr3 + 1);
(void) std::is_heap_until(iarr3, iarr3 + 1, validating_less{});

(void) std::min(+varr, +varr);
(void) std::min(+iarr, +iarr, validating_less{});
(void) std::min({+varr, +varr});
Expand Down Expand Up @@ -299,6 +341,12 @@ void test_algorithms() {
(void) std::lexicographical_compare_three_way(varr, varr, varr, varr);
(void) std::lexicographical_compare_three_way(iarr, iarr, iarr, iarr, validating_compare_three_way{});
#endif // _HAS_CXX20 && defined(__cpp_lib_concepts)

// (void) std::next_permutation(varr, varr); // requires Cpp17ValueSwappable
(void) std::next_permutation(iarr, iarr, validating_less{});

// (void) std::prev_permutation(varr, varr); // requires Cpp17ValueSwappable
(void) std::prev_permutation(iarr, iarr, validating_less{});
}

#if _HAS_CXX17
Expand Down Expand Up @@ -431,6 +479,36 @@ void test_per_execution_policy() {
// (void) std::shift_right(ExecutionPolicy, varr, varr, 0); // requires Cpp17ValueSwappable
#endif // _HAS_CXX20

int iarr3[2]{};
validator varr3[2]{};

(void) std::merge(ExecutionPolicy, varr, varr, varr2, varr2, varr3);
(void) std::merge(ExecutionPolicy, iarr, iarr, iarr2, iarr2, iarr3, validating_less{});

// std::inplace_merge(ExecutionPolicy, varr, varr, varr); // requires Cpp17ValueSwappable
std::inplace_merge(ExecutionPolicy, iarr, iarr, iarr, validating_less{});

(void) std::includes(ExecutionPolicy, varr, varr, varr, varr);
(void) std::includes(ExecutionPolicy, iarr, iarr, iarr, iarr, validating_less{});

(void) std::set_union(ExecutionPolicy, varr, varr, varr, varr, varr3);
(void) std::set_union(ExecutionPolicy, iarr, iarr, iarr, iarr, iarr3, validating_less{});

(void) std::set_intersection(ExecutionPolicy, varr, varr, varr, varr, varr3);
(void) std::set_intersection(ExecutionPolicy, iarr, iarr, iarr, iarr, iarr3, validating_less{});

(void) std::set_difference(ExecutionPolicy, varr, varr, varr, varr, varr3);
(void) std::set_difference(ExecutionPolicy, iarr, iarr, iarr, iarr, iarr3, validating_less{});

(void) std::set_symmetric_difference(ExecutionPolicy, varr, varr, varr, varr, varr3);
(void) std::set_symmetric_difference(ExecutionPolicy, iarr, iarr, iarr, iarr, iarr3, validating_less{});

(void) std::is_heap(ExecutionPolicy, varr3, varr3 + 1);
(void) std::is_heap(ExecutionPolicy, iarr3, iarr3 + 1, validating_less{});

(void) std::is_heap_until(ExecutionPolicy, varr3, varr3 + 1);
(void) std::is_heap_until(ExecutionPolicy, iarr3, iarr3 + 1, validating_less{});

(void) std::min_element(ExecutionPolicy, varr, varr + 1);
(void) std::min_element(ExecutionPolicy, iarr, iarr + 1, validating_less{});

Expand Down
Loading

0 comments on commit b89a780

Please sign in to comment.