Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

<numeric>: check for gcd / lcm overflows #4776

Merged
merged 14 commits into from
Jul 11, 2024
29 changes: 28 additions & 1 deletion stl/inc/numeric
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,19 @@ _NODISCARD constexpr common_type_t<_Mt, _Nt> gcd(const _Mt _Mx, const _Nt _Nx) n
using _Common = common_type_t<_Mt, _Nt>;
using _Common_unsigned = make_unsigned_t<_Common>;

if constexpr (is_signed_v<_Common>) {
#ifndef _DEBUG
if (_STD _Is_constant_evaluated())
#endif // ^^^ !defined(_DEBUG) ^^^
{
constexpr auto _Min_common = _STD _Min_limit<_Common>();
if (_Mx == _Min_common || _Nx == _Min_common) {
_STL_REPORT_ERROR("Preconditions: |m| and |n| are representable as a value of common_type_t<M, N>. "
"(N4981 [numeric.ops.gcd]/2, N4981 [numeric.ops.lcm]/2)");
}
}
}

return _Select_countr_zero_impl<_Common_unsigned>([=](auto _Countr_zero_impl) {
_Common_unsigned _Mx_magnitude = _Abs_u(_Mx);
_Common_unsigned _Nx_magnitude = _Abs_u(_Nx);
Expand Down Expand Up @@ -670,7 +683,21 @@ _NODISCARD constexpr common_type_t<_Mt, _Nt> lcm(const _Mt _Mx, const _Nt _Nx) n
return 0;
}

return static_cast<_Common>((_Mx_magnitude / _STD gcd(_Mx_magnitude, _Nx_magnitude)) * _Nx_magnitude);
#ifndef _DEBUG
if (!_STD _Is_constant_evaluated()) {
return static_cast<_Common>((_Mx_magnitude / _STD gcd(_Mx_magnitude, _Nx_magnitude)) * _Nx_magnitude);
}
#endif // ^^^ !defined(_DEBUG) ^^^

_Common_unsigned _Result = 0;
_Common_unsigned _Tmp = static_cast<_Common_unsigned>(_Mx_magnitude / _STD gcd(_Mx_magnitude, _Nx_magnitude));

if (_Mul_overflow(_Tmp, _Nx_magnitude, _Result) || !_In_range<_Common>(_Result)) {
_STL_REPORT_ERROR("Preconditions: The least common multiple of |m| and |n| is representable as a value of "
"type common_type_t<M, N>. (N4981 [numeric.ops.lcm]/2)");
}

return static_cast<_Common>(_Result);
}
#endif // _HAS_CXX17

Expand Down
31 changes: 21 additions & 10 deletions stl/inc/utility
Original file line number Diff line number Diff line change
Expand Up @@ -784,14 +784,9 @@ _EXPORT_STD template <size_t _Idx>
constexpr in_place_index_t<_Idx> in_place_index{};
#endif // _HAS_CXX17

template <class _Ty>
constexpr bool _Is_standard_integer = _Is_any_of_v<remove_cv_t<_Ty>, signed char, short, int, long, long long,
unsigned char, unsigned short, unsigned int, unsigned long, unsigned long long>;

template <class _Ty1, class _Ty2>
_NODISCARD constexpr bool _Cmp_equal(const _Ty1 _Left, const _Ty2 _Right) noexcept {
static_assert(_Is_standard_integer<_Ty1> && _Is_standard_integer<_Ty2>,
"The integer comparison functions only accept standard and extended integer types.");
_STL_INTERNAL_STATIC_ASSERT(_Is_nonbool_integral<_Ty1> && _Is_nonbool_integral<_Ty2>); // allows character types
if constexpr (is_signed_v<_Ty1> == is_signed_v<_Ty2>) {
return _Left == _Right;
} else if constexpr (is_signed_v<_Ty2>) {
Expand All @@ -808,8 +803,7 @@ _NODISCARD constexpr bool _Cmp_not_equal(const _Ty1 _Left, const _Ty2 _Right) no

template <class _Ty1, class _Ty2>
_NODISCARD constexpr bool _Cmp_less(const _Ty1 _Left, const _Ty2 _Right) noexcept {
static_assert(_Is_standard_integer<_Ty1> && _Is_standard_integer<_Ty2>,
"The integer comparison functions only accept standard and extended integer types.");
_STL_INTERNAL_STATIC_ASSERT(_Is_nonbool_integral<_Ty1> && _Is_nonbool_integral<_Ty2>); // allows character types
if constexpr (is_signed_v<_Ty1> == is_signed_v<_Ty2>) {
return _Left < _Right;
} else if constexpr (is_signed_v<_Ty2>) {
Expand Down Expand Up @@ -858,8 +852,7 @@ _NODISCARD constexpr _Ty _Max_limit() noexcept { // same as (numeric_limits<_Ty>

template <class _Rx, class _Ty>
_NODISCARD constexpr bool _In_range(const _Ty _Value) noexcept {
static_assert(_Is_standard_integer<_Rx> && _Is_standard_integer<_Ty>,
"The integer comparison functions only accept standard and extended integer types.");
_STL_INTERNAL_STATIC_ASSERT(_Is_nonbool_integral<_Rx> && _Is_nonbool_integral<_Ty>); // allows character types

constexpr auto _Ty_min = _Min_limit<_Ty>();
constexpr auto _Rx_min = _Min_limit<_Rx>();
Expand All @@ -883,38 +876,56 @@ _NODISCARD constexpr bool _In_range(const _Ty _Value) noexcept {
}

#if _HAS_CXX20
template <class _Ty>
constexpr bool _Is_standard_integer = _Is_any_of_v<remove_cv_t<_Ty>, signed char, short, int, long, long long,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could be useful to comment that this can be used to test for standard and extended integer types, because there are no extended integer types (probably in a subsequent change)

unsigned char, unsigned short, unsigned int, unsigned long, unsigned long long>;

_EXPORT_STD template <class _Ty1, class _Ty2>
_NODISCARD constexpr bool cmp_equal(const _Ty1 _Left, const _Ty2 _Right) noexcept {
static_assert(_Is_standard_integer<_Ty1> && _Is_standard_integer<_Ty2>,
"The integer comparison functions only accept standard and extended integer types.");
return _STD _Cmp_equal(_Left, _Right);
}

_EXPORT_STD template <class _Ty1, class _Ty2>
_NODISCARD constexpr bool cmp_not_equal(const _Ty1 _Left, const _Ty2 _Right) noexcept {
static_assert(_Is_standard_integer<_Ty1> && _Is_standard_integer<_Ty2>,
"The integer comparison functions only accept standard and extended integer types.");
return _STD _Cmp_not_equal(_Left, _Right);
}

_EXPORT_STD template <class _Ty1, class _Ty2>
_NODISCARD constexpr bool cmp_less(const _Ty1 _Left, const _Ty2 _Right) noexcept {
static_assert(_Is_standard_integer<_Ty1> && _Is_standard_integer<_Ty2>,
"The integer comparison functions only accept standard and extended integer types.");
return _STD _Cmp_less(_Left, _Right);
}

_EXPORT_STD template <class _Ty1, class _Ty2>
_NODISCARD constexpr bool cmp_greater(const _Ty1 _Left, const _Ty2 _Right) noexcept {
static_assert(_Is_standard_integer<_Ty1> && _Is_standard_integer<_Ty2>,
"The integer comparison functions only accept standard and extended integer types.");
return _STD _Cmp_greater(_Left, _Right);
}

_EXPORT_STD template <class _Ty1, class _Ty2>
_NODISCARD constexpr bool cmp_less_equal(const _Ty1 _Left, const _Ty2 _Right) noexcept {
static_assert(_Is_standard_integer<_Ty1> && _Is_standard_integer<_Ty2>,
"The integer comparison functions only accept standard and extended integer types.");
return _STD _Cmp_less_equal(_Left, _Right);
}

_EXPORT_STD template <class _Ty1, class _Ty2>
_NODISCARD constexpr bool cmp_greater_equal(const _Ty1 _Left, const _Ty2 _Right) noexcept {
static_assert(_Is_standard_integer<_Ty1> && _Is_standard_integer<_Ty2>,
"The integer comparison functions only accept standard and extended integer types.");
return _STD _Cmp_greater_equal(_Left, _Right);
}

_EXPORT_STD template <class _Rx, class _Ty>
_NODISCARD constexpr bool in_range(const _Ty _Value) noexcept {
static_assert(_Is_standard_integer<_Rx> && _Is_standard_integer<_Ty>,
"The integer comparison functions only accept standard and extended integer types.");
return _STD _In_range<_Rx>(_Value);
}
#endif // _HAS_CXX20
Expand Down
16 changes: 13 additions & 3 deletions stl/inc/xutility
Original file line number Diff line number Diff line change
Expand Up @@ -7372,16 +7372,25 @@ _NODISCARD constexpr bool _Add_overflow(const _Int _Left, const _Int _Right, _In
}
}
}
#endif // _HAS_CXX23

#if _HAS_CXX17
#if _HAS_CXX20
template <_Integer_like _Int>
#else // ^^^ _HAS_CXX20 / !_HAS_CXX20 vvv
template <class _Int, enable_if_t<_Is_nonbool_integral<_Int>, int> = 0>
#endif // ^^^ !_HAS_CXX20 ^^^
_NODISCARD constexpr bool _Mul_overflow(const _Int _Left, const _Int _Right, _Int& _Out) {
#if defined(__clang__) && !_HAS_CXX20
return __builtin_mul_overflow(_Left, _Right, &_Out);
#else // ^^^ defined(__clang__) && !_HAS_CXX20 / !defined(__clang__) || _HAS_CXX20 vvv
#ifdef __clang__
if constexpr (integral<_Int>) {
if constexpr (is_integral_v<_Int>) {
StephanTLavavej marked this conversation as resolved.
Show resolved Hide resolved
return __builtin_mul_overflow(_Left, _Right, &_Out);
} else
#endif // defined(__clang__)
{
if constexpr (!_Signed_integer_like<_Int>) {
if constexpr (static_cast<_Int>(-1) > static_cast<_Int>(0)) {
constexpr auto _UInt_max = _STD _Max_limit<_Int>();
const bool _Overflow = _Left != 0 && _Right > _UInt_max / _Left;
if (!_Overflow) {
Expand Down Expand Up @@ -7420,8 +7429,9 @@ _NODISCARD constexpr bool _Mul_overflow(const _Int _Left, const _Int _Right, _In
// ^^^ Based on llvm::MulOverflow ^^^
}
}
#endif // ^^^ !defined(__clang__) || _HAS_CXX20 ^^^
}
#endif // _HAS_CXX23
#endif // _HAS_CXX17

_STD_END

Expand Down
27 changes: 26 additions & 1 deletion tests/std/tests/P0295R0_gcd_lcm/test.compile.pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ static_assert(gcd(1073741824, 536870912) == 536870912);
static_assert(gcd(1073741824, -536870912) == 536870912);
static_assert(gcd(-1073741824, 536870912) == 536870912);
static_assert(gcd(int_max, int_max) == int_max);
static_assert(gcd(int_min, int_max) == 1);
// gcd(int_min, int_max) -> undefined behavior
StephanTLavavej marked this conversation as resolved.
Show resolved Hide resolved
// gcd(int_min, int_min) -> undefined behavior
static_assert(gcd(int_min + 1, int_min + 1) == int_max);

Expand All @@ -49,3 +49,28 @@ static_assert(lcm(1, 0) == 0);
static_assert(lcm(1073741824, 536870912) == 1073741824);
static_assert(lcm(1073741824, -536870912) == 1073741824);
static_assert(lcm(-1073741824, 536870912) == 1073741824);

template <class T>
constexpr bool test_nonbool_integral_type() {
static_assert(gcd(T{60}, T{24}) == T{12});
static_assert(lcm(T{60}, T{24}) == T{120});
return true;
}

static_assert(test_nonbool_integral_type<char>());
static_assert(test_nonbool_integral_type<wchar_t>());
#ifdef __cpp_char8_t
static_assert(test_nonbool_integral_type<char8_t>());
#endif // __cpp_char8_t
static_assert(test_nonbool_integral_type<char16_t>());
static_assert(test_nonbool_integral_type<char32_t>());
static_assert(test_nonbool_integral_type<signed char>());
static_assert(test_nonbool_integral_type<short>());
static_assert(test_nonbool_integral_type<int>());
static_assert(test_nonbool_integral_type<long>());
static_assert(test_nonbool_integral_type<long long>());
static_assert(test_nonbool_integral_type<unsigned char>());
static_assert(test_nonbool_integral_type<unsigned short>());
static_assert(test_nonbool_integral_type<unsigned int>());
static_assert(test_nonbool_integral_type<unsigned long>());
static_assert(test_nonbool_integral_type<unsigned long long>());