Skip to content
This repository has been archived by the owner on Aug 11, 2020. It is now read-only.

Commit

Permalink
Add round-to-nearest-even rounding to float2half(). (#368)
Browse files Browse the repository at this point in the history
* Add round-to-nearest-even to float2half().  Disable with -DMSHADOW_HALF_ROUND_TO_NEAREST=0 build.

* Correct #if guard name.

* Fix lint.

* Minor syntax fix for MXNet CI.
  • Loading branch information
DickJC123 authored and eric-haibin-lin committed Jan 28, 2019
1 parent 6dc04f7 commit 3dc8081
Showing 1 changed file with 94 additions and 28 deletions.
122 changes: 94 additions & 28 deletions mshadow/half.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@
#include <x86intrin.h>
#endif // MSHADOW_USE_F16C

// This flag dictates rounding for the float2half() routine only (used generally on Windows),
// not the f16c lib or cuda v7.5 (or later) behavior which is fixed at round-to-nearest-even.
// The default of 1 enables the rounding step in float2half(). That step is guarded by
// `#if MSHADOW_HALF_ROUND_TO_NEAREST == 1`, so building with any other value
// (e.g. -DMSHADOW_HALF_ROUND_TO_NEAREST=0) compiles it out and the low fp32 fraction
// bits are simply dropped by the final right shift.
#ifndef MSHADOW_HALF_ROUND_TO_NEAREST
#define MSHADOW_HALF_ROUND_TO_NEAREST 1
#endif

#if (MSHADOW_USE_CUDA && CUDA_VERSION >= 7050)
#define MSHADOW_CUDA_HALF 1
#include <cuda_fp16.h>
Expand Down Expand Up @@ -159,12 +165,18 @@ class MSHADOW_ALIGNED(2) half_t {
uint32_t ui;
};

// Bit-layout constants used by the fp32 <-> fp16 conversion routines.
// NOTE(review): `shift` (first line below and again a few lines later) and `maxN` (two
// adjacent lines) each appear twice in this excerpt -- presumably the pre- and post-change
// sides of the diff shown together. Confirm that only one definition of each exists in
// the actual header.
static int const shift = 13;
// Fraction (mantissa) field widths of the two formats.
static int const fp16FractionBits = 10;
static int const fp32FractionBits = 23;
static int32_t const fp32FractionMask = ~(~0u << fp32FractionBits); // == 0x7fffff
// Implicit leading 1 of a normal fp32 significand, placed just above the fraction field.
static int32_t const fp32HiddenBit = 1 << fp32FractionBits; // == 0x800000
// Number of fp32 fraction bits that do not fit into the fp16 fraction.
static int const shift = fp32FractionBits - fp16FractionBits; // == 13
// Moves the fp32 sign bit (signN, bit 31) down to bit 15 of the 16-bit result.
static int const shiftSign = 16;
static int32_t const expAdjust = 127 - 15; // exp32-127 = exp16-15, so exp16 = exp32 - (127-15)

// Thresholds below are fp32 bit patterns; they are compared against the sign-cleared
// bits of the input (v.si) as signed integers.
static int32_t const infN = 0x7F800000; // flt32 infinity
static int32_t const maxN = 0x477FE000; // max flt16 normal as a flt32
static int32_t const maxN = 0x477FFFFF; // max flt32 that's a flt16 normal after >> by shift
static int32_t const minN = 0x38800000; // min flt16 normal as a flt32
// 0x33000000 has exponent field 0x66 (102) and zero fraction, i.e. 2^-25.
static int32_t const maxZ = 0x33000000; // max fp32 number that's still rounded to zero in fp16
static int32_t const signN = 0x80000000; // flt32 sign bit

// infN after the right shift by `shift`.
static int32_t const infC = infN >> shift;
Expand All @@ -183,37 +195,91 @@ class MSHADOW_ALIGNED(2) half_t {
// minC and subC are defined outside this excerpt. minD is consumed by the mask-select
// expression `((v.si - minD) ^ v.si) & -(v.si > subC)` in the float2half() lines shown
// in this view.
static int32_t const minD = minC - subC - 1;

// Converts the fp32 `value` to fp16 and returns the 16 result bits.
//
// NOTE(review): this excerpt appears to show the removed and the added lines of the diff
// interleaved (two `Bits` declarations, two `sign` declarations, and a `return` ahead of
// the if/else chain). Confirm against the real header before compiling.
//
// The if/else-chain version classifies the sign-cleared bits of `value` (v.si):
//   v.si <= maxZ  -> result magnitude is 0
//   v.si <  minN  -> fp16 denormal: significand (with hidden bit) shifted right by vshift
//   v.si <= maxN  -> fp16 normal: exponent field rebiased by expAdjust
//   v.si <= infN  -> replaced by infN
//   v.si <  nanN  -> replaced by nanN (nanN is defined outside this excerpt; presumably
//                    this keeps such values from turning into infinity after the shift)
//   otherwise     -> bits are left as they are
// Afterwards the bits are shifted right by `shift`, masked to 15 bits and OR-ed with the
// sign. When MSHADOW_HALF_ROUND_TO_NEAREST == 1, denormals and normals are first rounded
// to nearest (ties to even) by conditionally adding 0x1000 before that shift.
MSHADOW_XINLINE uint16_t float2half(const float& value) const {
Bits v, s;
Bits v;
v.f = value;
uint32_t sign = v.si & signN;
v.si ^= sign;
sign >>= shiftSign; // logical shift
s.si = mulN;
s.si = s.f * v.f; // correct subnormals
v.si ^= (s.si ^ v.si) & -(minN > v.si);
v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN));
v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN));
v.ui >>= shift; // logical shift
v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC);
v.si ^= ((v.si - minD) ^ v.si) & -(v.si > subC);
return v.ui | sign;
uint32_t sign = v.si & signN; // grab sign bit
v.si ^= sign; // clear sign bit from v
sign >>= shiftSign; // logical shift sign to fp16 position

if (v.si <= maxZ) {
// Handle eventual zeros here to ensure vshift will not exceed 32 below.
v.ui = 0;
} else if (v.si < minN) {
// Handle denorms
uint32_t exp32 = v.ui >> fp32FractionBits;
int32_t exp16 = exp32 - expAdjust;
// If exp16 == 0 (just into the denorm range), then significand should be shifted right 1.
// Smaller (so negative) exp16 values should result in greater right shifts.
uint32_t vshift = 1 - exp16;
uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
v.ui = significand >> vshift;
// The only time it's *not* OK to add 0x1000 (i.e. half the flt16 fraction lsb) is
// when the lsb of the flt16 fraction == 0 (so not rounding up to even) and the additional
// bits to the right of the lsb are 1000... (including flt32 significand bits
// that may be lost during the above vshift). The first term below will always
// be true for vshift >=12 (since even the 'hidden bit' has been shifted to the
// right of the '1' bit in 0x1000). And when vshift <= 11, both terms combine to make
// the proper test of the flt32 significand bits, including those lost during the vshift.
#if MSHADOW_HALF_ROUND_TO_NEAREST == 1
// Rounding may increase the exponent to 1, but that's OK.
v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
#endif
} else if (v.si <= maxN) {
// Handle norms
#if MSHADOW_HALF_ROUND_TO_NEAREST == 1
// 0x3fff covers the future fp16 fraction lsb (0x2000) and the 13 bits below it. The value
// 0x1000 in that field means "lsb is 0 and the discarded bits are exactly one half", the
// one case where ties-to-even must not round up.
// Rounding may increase the exponent, possibly creating an inf, but that's OK.
v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
#endif
// Rebias: subtract (127 - 15) from the exponent field.
v.ui -= expAdjust << fp32FractionBits;
} else if (v.si <= infN) {
// Anything above maxN up to and including infinity becomes infinity.
v.si = infN;
} else if (v.si < nanN) {
v.si = nanN;
}

// Drop the low `shift` fraction bits, keep the 15 magnitude bits and merge the sign.
v.ui >>= shift;
return sign | (v.ui & 0x7fff);
}

// Same as above routine, except for addition of volatile keyword
// Volatile-qualified overload: converts the fp32 `value` to fp16 and returns the 16 bits.
//
// NOTE(review): this excerpt appears to show the removed and the added lines of the diff
// interleaved (two `Bits` declarations, two `sign` declarations, and a `return` ahead of
// the if/else chain). Confirm against the real header before compiling.
//
// The if/else-chain version classifies the sign-cleared bits of `value` (v.si) as zero
// (<= maxZ), fp16 denormal (< minN), fp16 normal (<= maxN), infinity (<= infN) or a value
// replaced by nanN (< nanN, defined outside this excerpt). It then shifts right by
// `shift`, masks to 15 bits and ORs in the sign. With MSHADOW_HALF_ROUND_TO_NEAREST == 1,
// denormals and normals are rounded to nearest (ties to even) before that shift.
MSHADOW_XINLINE uint16_t float2half(const volatile float& value) const volatile { // NOLINT (*)
Bits v, s;
Bits v;
v.f = value;
uint32_t sign = v.si & signN;
v.si ^= sign;
sign >>= shiftSign; // logical shift
s.si = mulN;
s.si = s.f * v.f; // correct subnormals
v.si ^= (s.si ^ v.si) & -(minN > v.si);
v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN));
v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN));
v.ui >>= shift; // logical shift
v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC);
v.si ^= ((v.si - minD) ^ v.si) & -(v.si > subC);
return v.ui | sign;
uint32_t sign = v.si & signN; // grab sign bit
v.si ^= sign; // clear sign bit from v
sign >>= shiftSign; // logical shift sign to fp16 position

if (v.si <= maxZ) {
// Handle eventual zeros here to ensure vshift will not exceed 32 below.
v.ui = 0;
} else if (v.si < minN) {
// Handle denorms
uint32_t exp32 = v.ui >> fp32FractionBits;
int32_t exp16 = exp32 - expAdjust;
// If exp16 == 0 (just into the denorm range), then significand should be shifted right 1.
// Smaller (so negative) exp16 values should result in greater right shifts.
uint32_t vshift = 1 - exp16;
uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
v.ui = significand >> vshift;
#if MSHADOW_HALF_ROUND_TO_NEAREST == 1
// Add 0x1000 (half the fp16 fraction lsb) unless the fp16 lsb is 0 and the bits below it,
// including the significand bits lost in the vshift (tested via `significand & 0x7ff`),
// are exactly one half.
// Rounding may increase the exponent to 1, but that's OK.
v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
#endif
} else if (v.si <= maxN) {
// Handle norms
#if MSHADOW_HALF_ROUND_TO_NEAREST == 1
// 0x3fff covers the future fp16 fraction lsb (0x2000) and the 13 bits below it. The value
// 0x1000 in that field is the tie case with an even lsb, which must not round up.
// Rounding may increase the exponent, possibly creating an inf, but that's OK.
v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
#endif
// Rebias: subtract (127 - 15) from the exponent field.
v.ui -= expAdjust << fp32FractionBits;
} else if (v.si <= infN) {
// Anything above maxN up to and including infinity becomes infinity.
v.si = infN;
} else if (v.si < nanN) {
v.si = nanN;
}

// Drop the low `shift` fraction bits, keep the 15 magnitude bits and merge the sign.
v.ui >>= shift;
return sign | (v.ui & 0x7fff);
}

MSHADOW_XINLINE float half2float(const uint16_t& value) const {
Expand Down

0 comments on commit 3dc8081

Please sign in to comment.