From 3dc80815d965b56b9a975dc27229361955bf66fe Mon Sep 17 00:00:00 2001 From: Dick Carter Date: Sun, 27 Jan 2019 21:24:38 -0800 Subject: [PATCH] Add round-to-nearest-even rounding to float2half(). (#368) * Add round-to-nearest-even to float2half(). Disable with -DMSHADOW_HALF_ROUND_TO_NEAREST=0 build. * Correct #if guard name. * Fix lint. * Minor syntax fix for MXNet CI. --- mshadow/half.h | 122 +++++++++++++++++++++++++++++++++++++------------ 1 file changed, 94 insertions(+), 28 deletions(-) diff --git a/mshadow/half.h b/mshadow/half.h index 75d8e5d0..2dded0a7 100644 --- a/mshadow/half.h +++ b/mshadow/half.h @@ -13,6 +13,12 @@ #include <x86intrin.h> #endif // MSHADOW_USE_F16C +// This flag dictates rounding for the float2half() routine only (used generally on Windows), +// not the f16c lib or cuda v7.5 (or later) behavior which is fixed at round-to-nearest-even. +#ifndef MSHADOW_HALF_ROUND_TO_NEAREST +#define MSHADOW_HALF_ROUND_TO_NEAREST 1 +#endif + #if (MSHADOW_USE_CUDA && CUDA_VERSION >= 7050) #define MSHADOW_CUDA_HALF 1 #include <cuda_fp16.h> @@ -159,12 +165,18 @@ class MSHADOW_ALIGNED(2) half_t { uint32_t ui; }; - static int const shift = 13; + static int const fp16FractionBits = 10; + static int const fp32FractionBits = 23; + static int32_t const fp32FractionMask = ~(~0u << fp32FractionBits); // == 0x7fffff + static int32_t const fp32HiddenBit = 1 << fp32FractionBits; // == 0x800000 + static int const shift = fp32FractionBits - fp16FractionBits; // == 13 static int const shiftSign = 16; + static int32_t const expAdjust = 127 - 15; // exp32-127 = exp16-15, so exp16 = exp32 - (127-15) static int32_t const infN = 0x7F800000; // flt32 infinity - static int32_t const maxN = 0x477FE000; // max flt16 normal as a flt32 + static int32_t const maxN = 0x477FFFFF; // max flt32 that's a flt16 normal after >> by shift static int32_t const minN = 0x38800000; // min flt16 normal as a flt32 + static int32_t const maxZ = 0x33000000; // max fp32 number that's still rounded to zero in fp16 static int32_t
const signN = 0x80000000; // flt32 sign bit static int32_t const infC = infN >> shift; @@ -183,37 +195,91 @@ class MSHADOW_ALIGNED(2) half_t { static int32_t const minD = minC - subC - 1; MSHADOW_XINLINE uint16_t float2half(const float& value) const { - Bits v, s; + Bits v; v.f = value; - uint32_t sign = v.si & signN; - v.si ^= sign; - sign >>= shiftSign; // logical shift - s.si = mulN; - s.si = s.f * v.f; // correct subnormals - v.si ^= (s.si ^ v.si) & -(minN > v.si); - v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN)); - v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN)); - v.ui >>= shift; // logical shift - v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC); - v.si ^= ((v.si - minD) ^ v.si) & -(v.si > subC); - return v.ui | sign; + uint32_t sign = v.si & signN; // grab sign bit + v.si ^= sign; // clear sign bit from v + sign >>= shiftSign; // logical shift sign to fp16 position + + if (v.si <= maxZ) { + // Handle eventual zeros here to ensure vshift will not exceed 32 below. + v.ui = 0; + } else if (v.si < minN) { + // Handle denorms + uint32_t exp32 = v.ui >> fp32FractionBits; + int32_t exp16 = exp32 - expAdjust; + // If exp16 == 0 (just into the denorm range), then significand should be shifted right 1. + // Smaller (so negative) exp16 values should result in greater right shifts. + uint32_t vshift = 1 - exp16; + uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask); + v.ui = significand >> vshift; + // The only time it's *not* OK to add 0x1000 (i.e. half the flt16 fraction lsb) is + // when the lsb of the flt16 fraction == 0 (so not rounding up to even) and the additional + // bits to the right of the lsb are 1000... (including flt32 significand bits + // that may be lost during the above vshift). The first term below will always + // be true for vshift >= 12 (since even the 'hidden bit' has been shifted to the + // right of the '1' bit in 0x1000).
And when vshift <= 11, both terms combine to make + // the proper test of the flt32 significand bits, including those lost during the vshift. +#if MSHADOW_HALF_ROUND_TO_NEAREST == 1 + // Rounding may increase the exponent to 1, but that's OK. + v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0; +#endif + } else if (v.si <= maxN) { + // Handle norms +#if MSHADOW_HALF_ROUND_TO_NEAREST == 1 + // Rounding may increase the exponent, possibly creating an inf, but that's OK. + v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0; +#endif + v.ui -= expAdjust << fp32FractionBits; + } else if (v.si <= infN) { + v.si = infN; + } else if (v.si < nanN) { + v.si = nanN; + } + + v.ui >>= shift; + return sign | (v.ui & 0x7fff); } + // Same as above routine, except for addition of volatile keyword MSHADOW_XINLINE uint16_t float2half(const volatile float& value) const volatile { // NOLINT (*) - Bits v, s; + Bits v; v.f = value; - uint32_t sign = v.si & signN; - v.si ^= sign; - sign >>= shiftSign; // logical shift - s.si = mulN; - s.si = s.f * v.f; // correct subnormals - v.si ^= (s.si ^ v.si) & -(minN > v.si); - v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN)); - v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN)); - v.ui >>= shift; // logical shift - v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC); - v.si ^= ((v.si - minD) ^ v.si) & -(v.si > subC); - return v.ui | sign; + uint32_t sign = v.si & signN; // grab sign bit + v.si ^= sign; // clear sign bit from v + sign >>= shiftSign; // logical shift sign to fp16 position + + if (v.si <= maxZ) { + // Handle eventual zeros here to ensure vshift will not exceed 32 below. + v.ui = 0; + } else if (v.si < minN) { + // Handle denorms + uint32_t exp32 = v.ui >> fp32FractionBits; + int32_t exp16 = exp32 - expAdjust; + // If exp16 == 0 (just into the denorm range), then significand should be shifted right 1. + // Smaller (so negative) exp16 values should result in greater right shifts.
+ uint32_t vshift = 1 - exp16; + uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask); + v.ui = significand >> vshift; +#if MSHADOW_HALF_ROUND_TO_NEAREST == 1 + // Rounding may increase the exponent to 1, but that's OK. + v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0; +#endif + } else if (v.si <= maxN) { + // Handle norms +#if MSHADOW_HALF_ROUND_TO_NEAREST == 1 + // Rounding may increase the exponent, possibly creating an inf, but that's OK. + v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0; +#endif + v.ui -= expAdjust << fp32FractionBits; + } else if (v.si <= infN) { + v.si = infN; + } else if (v.si < nanN) { + v.si = nanN; + } + + v.ui >>= shift; + return sign | (v.ui & 0x7fff); } MSHADOW_XINLINE float half2float(const uint16_t& value) const {