Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

vectorize replace 🎭 #4554

Merged
merged 18 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions benchmarks/src/replace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,17 @@ const char src[] =
"euismod eros, ut posuere ligula ullamcorper id. Nullam aliquam malesuada est at dignissim. Pellentesque finibus "
"sagittis libero nec bibendum. Phasellus dolor ipsum, finibus quis turpis quis, mollis interdum felis.";

template <class T>
void r(benchmark::State& state) {
const std::vector<T> a(std::begin(src), std::end(src));
std::vector<T> b(std::size(src));

for (auto _ : state) {
b = a;
std::replace(std::begin(b), std::end(b), T{'m'}, T{'w'});
}
}

template <class T>
void rc(benchmark::State& state) {
const std::vector<T> a(std::begin(src), std::end(src));
Expand All @@ -58,6 +69,10 @@ void rc_if(benchmark::State& state) {
}
}

// replace() is vectorized for 4 and 8 bytes only.
BENCHMARK(r<std::uint32_t>);
BENCHMARK(r<std::uint64_t>);

BENCHMARK(rc<std::uint8_t>);
BENCHMARK(rc<std::uint16_t>);
BENCHMARK(rc<std::uint32_t>);
Expand Down
68 changes: 68 additions & 0 deletions stl/inc/algorithm
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ __declspec(noalias) _Min_max_8i __stdcall __std_minmax_8i(const void* _First, co
__declspec(noalias) _Min_max_8u __stdcall __std_minmax_8u(const void* _First, const void* _Last) noexcept;
__declspec(noalias) _Min_max_f __stdcall __std_minmax_f(const void* _First, const void* _Last) noexcept;
__declspec(noalias) _Min_max_d __stdcall __std_minmax_d(const void* _First, const void* _Last) noexcept;

// TRANSITION, DevCom-10610477
__declspec(noalias) void __stdcall __std_replace_4(
void* _First, void* _Last, uint32_t _Old_val, uint32_t _New_val) noexcept;
__declspec(noalias) void __stdcall __std_replace_8(
void* _First, void* _Last, uint64_t _Old_val, uint64_t _New_val) noexcept;
} // extern "C"

_STD_BEGIN
Expand Down Expand Up @@ -180,6 +186,24 @@ _Ty1* __std_find_first_of_trivial(
}
}

template <class _Ty, class _TVal1, class _TVal2>
__declspec(noalias) void _Replace_vectorized(
_Ty* const _First, _Ty* const _Last, const _TVal1 _Old_val, const _TVal2 _New_val) noexcept {
if constexpr (is_pointer_v<_Ty>) {
#ifdef _WIN64
::__std_replace_8(_First, _Last, reinterpret_cast<uint64_t>(_Old_val), reinterpret_cast<uint64_t>(_New_val));
#else // ^^^ defined(_WIN64) / !defined(_WIN64) vvv
::__std_replace_4(_First, _Last, reinterpret_cast<uint32_t>(_Old_val), reinterpret_cast<uint32_t>(_New_val));
#endif // ^^^ !defined(_WIN64) ^^^
} else if constexpr (sizeof(_Ty) == 4) {
::__std_replace_4(_First, _Last, static_cast<uint32_t>(_Old_val), static_cast<uint32_t>(_New_val));
} else if constexpr (sizeof(_Ty) == 8) {
::__std_replace_8(_First, _Last, static_cast<uint64_t>(_Old_val), static_cast<uint64_t>(_New_val));
} else {
static_assert(_Always_false<_Ty>, "Unexpected size");
}
}

// find_first_of vectorization is likely to be a win after this size (in elements)
_INLINE_VAR constexpr ptrdiff_t _Threshold_find_first_of = 16;

Expand All @@ -188,6 +212,17 @@ template <class _It1, class _It2, class _Pr>
_INLINE_VAR constexpr bool _Vector_alg_in_find_first_of_is_safe =
_Equal_memcmp_is_safe<_It1, _It2, _Pr> // can replace value comparison with bitwise comparison
&& sizeof(_Iter_value_t<_It1>) <= 2; // pcmpestri compatible size

// Can we activate the vector algorithms for replace?
template <class _Iter, class _Ty1>
constexpr bool _Vector_alg_in_replace_is_safe = _Vector_alg_in_find_is_safe<_Iter, _Ty1> // can search for the value
&& sizeof(_Iter_value_t<_Iter>) >= 4; // avx masked op compatible size

// Can we activate the vector algorithms for ranges::replace?
template <class _Iter, class _Ty1, class _Ty2>
constexpr bool _Vector_alg_in_ranges_replace_is_safe =
_Vector_alg_in_replace_is_safe<_Iter, _Ty1> // can search and replace
&& _Vector_alg_in_find_is_safe_elem<_Ty2, _Iter_value_t<_Iter>>; // replacement fits
_STD_END
#endif // _USE_STD_VECTOR_ALGORITHMS

Expand Down Expand Up @@ -3807,6 +3842,22 @@ _CONSTEXPR20 void replace(const _FwdIt _First, const _FwdIt _Last, const _Ty& _O
_STD _Adl_verify_range(_First, _Last);
auto _UFirst = _STD _Get_unwrapped(_First);
const auto _ULast = _STD _Get_unwrapped(_Last);

#if _USE_STD_VECTOR_ALGORITHMS
if constexpr (_Vector_alg_in_replace_is_safe<decltype(_UFirst), _Ty>) {
#if _HAS_CXX20
if (!_STD is_constant_evaluated())
#endif // _HAS_CXX20
{
if (_STD _Could_compare_equal_to_value_type<decltype(_UFirst)>(_Oldval)) {
_STD _Replace_vectorized(_STD _To_address(_UFirst), _STD _To_address(_ULast), _Oldval, _Newval);
}

return;
}
}
#endif // _USE_STD_VECTOR_ALGORITHMS

for (; _UFirst != _ULast; ++_UFirst) {
if (*_UFirst == _Oldval) {
*_UFirst = _Newval;
Expand Down Expand Up @@ -3860,6 +3911,23 @@ namespace ranges {
_STL_INTERNAL_STATIC_ASSERT(indirectly_writable<_It, const _Ty2&>);
_STL_INTERNAL_STATIC_ASSERT(indirect_binary_predicate<ranges::equal_to, projected<_It, _Pj>, const _Ty1*>);

#if _USE_STD_VECTOR_ALGORITHMS
if constexpr (is_same_v<_Pj, identity> && sized_sentinel_for<_Se, _It>
&& _Vector_alg_in_ranges_replace_is_safe<_It, _Ty1, _Ty2>) {
if (!_STD is_constant_evaluated()) {
const auto _Count = _Last - _First;

if (_STD _Could_compare_equal_to_value_type<_It>(_Oldval)) {
const auto _First_ptr = _STD to_address(_First);
const auto _Last_ptr = _First_ptr + _Count;
_STD _Replace_vectorized(_First_ptr, _Last_ptr, _Oldval, _Newval);
}

return _First + _Count;
}
}
#endif // _USE_STD_VECTOR_ALGORITHMS

for (; _First != _Last; ++_First) {
if (_STD invoke(_Proj, *_First) == _Oldval) {
*_First = _Newval;
Expand Down
42 changes: 23 additions & 19 deletions stl/inc/xutility
Original file line number Diff line number Diff line change
Expand Up @@ -5830,30 +5830,34 @@ struct _Vector_alg_in_find_is_safe_object_pointers<_Ty1*, _Ty2*>
// either _Ty1 is the same as _Ty2 (ignoring cv-qualifiers), or one of the two is void
disjunction<is_same<remove_cv_t<_Ty1>, remove_cv_t<_Ty2>>, is_void<_Ty1>, is_void<_Ty2>>> {};

// Can we activate the vector algorithms to find a value in a range of elements?
template <class _Ty, class _Elem>
constexpr bool _Vector_alg_in_find_is_safe_elem = disjunction_v<
#ifdef __cpp_lib_byte
// We're finding a std::byte in a range of std::byte.
conjunction<is_same<_Ty, byte>, is_same<_Elem, byte>>,
#endif // defined(__cpp_lib_byte)
// We're finding an integer in a range of integers.
// This case is the one that requires careful runtime handling in _Could_compare_equal_to_value_type.
conjunction<is_integral<_Ty>, is_integral<_Elem>>,
// We're finding an (object or function) pointer in a range of pointers of the same type.
conjunction<is_pointer<_Ty>, is_same<_Ty, _Elem>>,
// We're finding a nullptr in a range of (object or function) pointers.
conjunction<is_same<_Ty, nullptr_t>, is_pointer<_Elem>>,
// We're finding an object pointer in a range of object pointers, and:
// - One of the pointer types is a cv void*.
// - One of the pointer types is a cv1 U* and the other is a cv2 U*.
_Vector_alg_in_find_is_safe_object_pointers<_Ty, _Elem>>;

// Can we activate the vector algorithms for find/count?
template <class _Iter, class _Ty, class _Elem = _Iter_value_t<_Iter>>
_INLINE_VAR constexpr bool _Vector_alg_in_find_is_safe =
template <class _Iter, class _Ty>
constexpr bool _Vector_alg_in_find_is_safe =
// The iterator must be contiguous so we can get raw pointers.
_Iterator_is_contiguous<_Iter>
// The iterator must not be volatile.
&& !_Iterator_is_volatile<_Iter>
// And one of the following conditions must be met:
&& disjunction_v<
#ifdef __cpp_lib_byte
// We're finding a std::byte in a range of std::byte.
conjunction<is_same<_Ty, byte>, is_same<_Elem, byte>>,
#endif // defined(__cpp_lib_byte)
// We're finding an integer in a range of integers.
// This case is the one that requires careful runtime handling in _Could_compare_equal_to_value_type.
conjunction<is_integral<_Ty>, is_integral<_Elem>>,
// We're finding an (object or function) pointer in a range of pointers of the same type.
conjunction<is_pointer<_Ty>, is_same<_Ty, _Elem>>,
// We're finding a nullptr in a range of (object or function) pointers.
conjunction<is_same<_Ty, nullptr_t>, is_pointer<_Elem>>,
// We're finding an object pointer in a range of object pointers, and:
// - One of the pointer types is a cv void*.
// - One of the pointer types is a cv1 U* and the other is a cv2 U*.
_Vector_alg_in_find_is_safe_object_pointers<_Ty, _Elem>>;
// The type of the value to find must be compatible with the type of the elements.
&& _Vector_alg_in_find_is_safe_elem<_Ty, _Iter_value_t<_Iter>>;

template <class _InIt, class _Ty>
_NODISCARD constexpr bool _Could_compare_equal_to_value_type(const _Ty& _Val) {
Expand Down
77 changes: 77 additions & 0 deletions stl/src/vector_algorithms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2272,6 +2272,83 @@ __declspec(noalias) size_t
return __std_mismatch_impl<_Find_traits_8, uint64_t>(_First1, _First2, _Count);
}

__declspec(noalias) void __stdcall __std_replace_4(
StephanTLavavej marked this conversation as resolved.
Show resolved Hide resolved
void* _First, void* const _Last, const uint32_t _Old_val, const uint32_t _New_val) noexcept {
#ifndef _M_ARM64EC
if (_Use_avx2()) {
const __m256i _Comparand = _mm256_broadcastd_epi32(_mm_cvtsi32_si128(_Old_val));
const __m256i _Replacement = _mm256_broadcastd_epi32(_mm_cvtsi32_si128(_New_val));
const size_t _Full_length = _Byte_length(_First, _Last);

void* _Stop_at = _First;
_Advance_bytes(_Stop_at, _Full_length & ~size_t{0x1F});

while (_First != _Stop_at) {
const __m256i _Data = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(_First));
const __m256i _Mask = _mm256_cmpeq_epi32(_Comparand, _Data);
_mm256_maskstore_epi32(reinterpret_cast<int*>(_First), _Mask, _Replacement);

_Advance_bytes(_First, 32);
}

if (const size_t _Tail_length = _Full_length & 0x1C; _Tail_length != 0) {
const __m256i _Tail_mask = _Avx2_tail_mask_32(_Tail_length >> 2);
const __m256i _Data = _mm256_maskload_epi32(reinterpret_cast<const int*>(_First), _Tail_mask);
const __m256i _Mask = _mm256_and_si256(_mm256_cmpeq_epi32(_Comparand, _Data), _Tail_mask);
_mm256_maskstore_epi32(reinterpret_cast<int*>(_First), _Mask, _Replacement);
}
} else
#endif // !defined(_M_ARM64EC)
{
for (auto _Cur = reinterpret_cast<uint32_t*>(_First); _Cur != _Last; ++_Cur) {
if (*_Cur == _Old_val) {
*_Cur = _New_val;
}
}
}
}

__declspec(noalias) void __stdcall __std_replace_8(
void* _First, void* const _Last, const uint64_t _Old_val, const uint64_t _New_val) noexcept {
#ifndef _M_ARM64EC
if (_Use_avx2()) {
#ifdef _WIN64
const __m256i _Comparand = _mm256_broadcastq_epi64(_mm_cvtsi64_si128(_Old_val));
const __m256i _Replacement = _mm256_broadcastq_epi64(_mm_cvtsi64_si128(_New_val));
#else // ^^^ defined(_WIN64) / !defined(_WIN64), workaround, _mm_cvtsi64_si128 does not compile vvv
const __m256i _Comparand = _mm256_set1_epi64x(_Old_val);
const __m256i _Replacement = _mm256_set1_epi64x(_New_val);
#endif // ^^^ !defined(_WIN64) ^^^
const size_t _Full_length = _Byte_length(_First, _Last);

void* _Stop_at = _First;
_Advance_bytes(_Stop_at, _Full_length & ~size_t{0x1F});

while (_First != _Stop_at) {
const __m256i _Data = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(_First));
const __m256i _Mask = _mm256_cmpeq_epi64(_Comparand, _Data);
_mm256_maskstore_epi64(reinterpret_cast<long long*>(_First), _Mask, _Replacement);

_Advance_bytes(_First, 32);
}

if (const size_t _Tail_length = _Full_length & 0x18; _Tail_length != 0) {
const __m256i _Tail_mask = _Avx2_tail_mask_32(_Tail_length >> 2);
const __m256i _Data = _mm256_maskload_epi64(reinterpret_cast<const long long*>(_First), _Tail_mask);
const __m256i _Mask = _mm256_and_si256(_mm256_cmpeq_epi64(_Comparand, _Data), _Tail_mask);
_mm256_maskstore_epi64(reinterpret_cast<long long*>(_First), _Mask, _Replacement);
}
} else
#endif // !defined(_M_ARM64EC)
{
for (auto _Cur = reinterpret_cast<uint64_t*>(_First); _Cur != _Last; ++_Cur) {
if (*_Cur == _Old_val) {
*_Cur = _New_val;
}
}
}
}

} // extern "C"

#ifndef _M_ARM64EC
Expand Down
52 changes: 52 additions & 0 deletions tests/std/tests/VSO_0000000_vector_algorithms/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,52 @@ namespace test_mismatch_sizes_and_alignments {
}
} // namespace test_mismatch_sizes_and_alignments

template <class FwdIt, class T>
void last_known_good_replace(FwdIt first, FwdIt last, const T old_val, const T new_val) {
for (; first != last; ++first) {
if (*first == old_val) {
*first = new_val;
}
}
}

template <class T>
void test_case_replace(const vector<T>& input, T old_val, T new_val) {
vector<T> replaced_actual(input);
vector<T> replaced_expected(input);
replace(replaced_actual.begin(), replaced_actual.end(), old_val, new_val);
last_known_good_replace(replaced_expected.begin(), replaced_expected.end(), old_val, new_val);
assert(replaced_expected == replaced_actual);

#if _HAS_CXX20
vector<T> replaced_actual_r(input);
ranges::replace(replaced_actual_r, old_val, new_val);
assert(replaced_expected == replaced_actual_r);
#endif // _HAS_CXX20
}

template <class T>
void test_replace(mt19937_64& gen) {
using TD = conditional_t<sizeof(T) == 1, int, T>;
uniform_int_distribution<TD> dis(0, 9);
vector<T> input;

input.reserve(dataCount);

{
const T old_val = static_cast<T>(dis(gen));
const T new_val = static_cast<T>(dis(gen));
test_case_replace(input, old_val, new_val);
}

for (size_t i = 0; i != dataCount; ++i) {
input.push_back(static_cast<T>(dis(gen)));
const T old_val = static_cast<T>(dis(gen));
const T new_val = static_cast<T>(dis(gen));
test_case_replace(input, old_val, new_val);
}
}

template <class BidIt>
void last_known_good_reverse(BidIt first, BidIt last) {
for (; first != last && first != --last; ++first) {
Expand Down Expand Up @@ -728,6 +774,12 @@ void test_vector_algorithms(mt19937_64& gen) {
test_mismatch_sizes_and_alignments::test<int>();
test_mismatch_sizes_and_alignments::test<long long>();

// replace() is vectorized for 4 and 8 bytes only.
test_replace<int>(gen);
test_replace<unsigned int>(gen);
test_replace<long long>(gen);
test_replace<unsigned long long>(gen);

test_reverse<char>(gen);
test_reverse<signed char>(gen);
test_reverse<unsigned char>(gen);
Expand Down