Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
GELU
Browse files Browse the repository at this point in the history
  • Loading branch information
haojin2 committed Mar 16, 2019
1 parent a091d36 commit e0a565c
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 2 deletions.
19 changes: 19 additions & 0 deletions python/mxnet/gluon/nn/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,25 @@ def __init__(self, **kwargs):
def hybrid_forward(self, F, x):
return F.LeakyReLU(x, act_type='selu', name='fwd')

class GELU(HybridBlock):
r"""
Gaussian Exponential Linear Unit (GELU)
"Gaussian Error Linear Units (GELUs)", Hendrycks et al, 2016
https://arxiv.org/abs/1606.08415
Inputs:
- **data**: input tensor with arbitrary shape.
Outputs:
- **out**: output tensor with the same shape as `data`.
"""
def __init__(self, **kwargs):
super(GELU, self).__init__(**kwargs)

def hybrid_forward(self, F, x):
return F.LeakyReLU(x, act_type='gelu', name='fwd')


class Swish(HybridBlock):
r"""
Expand Down
21 changes: 19 additions & 2 deletions src/operator/leaky_relu-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ namespace op {
namespace leakyrelu {
enum LeakyReLUOpInputs {kData, kGamma};
enum LeakyReLUOpOutputs {kOut, kMask};
enum LeakyReLUOpType {kLeakyReLU, kPReLU, kRReLU, kELU, kSELU};
enum LeakyReLUOpType {kLeakyReLU, kPReLU, kRReLU, kELU, kSELU, kGELU};
enum LeakyReLUOpResource {kRandom};
} // namespace leakyrelu

Expand All @@ -64,6 +64,7 @@ struct LeakyReLUParam : public dmlc::Parameter<LeakyReLUParam> {
.add_enum("prelu", leakyrelu::kPReLU)
.add_enum("elu", leakyrelu::kELU)
.add_enum("selu", leakyrelu::kSELU)
.add_enum("gelu", leakyrelu::kGELU)
.describe("Activation function to be applied.");
DMLC_DECLARE_FIELD(slope).set_default(0.25f)
.describe("Init slope for the activation. (For leaky and elu only)");
Expand Down Expand Up @@ -190,6 +191,13 @@ class LeakyReLUOp : public Operator {
});
break;
}
case leakyrelu::kGELU: {
MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kOut], Req, {
mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::gelu, Req>, xpu>::Launch(
s, out.size(0) * out.size(1) * out.size(2), out.dptr_, data.dptr_);
});
break;
}
default:
LOG(FATAL) << "Not implmented";
}
Expand Down Expand Up @@ -223,7 +231,7 @@ class LeakyReLUOp : public Operator {
if (param_.act_type == leakyrelu::kRReLU) {
mask = out_data[leakyrelu::kMask].get_with_shape<xpu, 3, DType>(dshape, s);
}
if (param_.act_type == leakyrelu::kPReLU) {
if (param_.act_type == leakyrelu::kPReLU || param_.act_type == leakyrelu::kGELU) {
data = in_data[leakyrelu::kData].get_with_shape<xpu, 3, DType>(dshape, s);
}
switch (param_.act_type) {
Expand Down Expand Up @@ -287,6 +295,15 @@ class LeakyReLUOp : public Operator {
});
break;
}
case leakyrelu::kGELU: {
MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kData], Req, {
mxnet_op::Kernel<mxnet_op::op_with_req<
mxnet_op::backward_grad_tuned<mshadow_op::gelu_grad>, Req>, xpu>::Launch(
s, gdata.size(0) * gdata.size(1) * gdata.size(2), gdata.dptr_, grad.dptr_,
data.dptr_, output.dptr_);
});
break;
}
default:
LOG(FATAL) << "Not implmented";
}
Expand Down
18 changes: 18 additions & 0 deletions src/operator/mshadow_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,14 @@ namespace mshadow_op {
__constant__ const float PI = 3.14159265358979323846;
__constant__ const float SELU_ALPHA = 1.6732632423543772848170429916717;
__constant__ const float SELU_LAMBDA = 1.0507009873554804934193349852946;
__constant__ const float GELU_CUBIC_CONSTANT = 0.044715;
__constant__ const float GELU_ROOT_2_OVER_PI = 0.7978845608028654;
#else
const float PI = 3.14159265358979323846;
const float SELU_ALPHA = 1.6732632423543772848170429916717;
const float SELU_LAMBDA = 1.0507009873554804934193349852946;
const float GELU_CUBIC_CONSTANT = 0.044715;
const float GELU_ROOT_2_OVER_PI = 0.7978845608028654;
using std::isnan;
#endif
using std::enable_if;
Expand Down Expand Up @@ -127,6 +131,20 @@ MXNET_UNARY_MATH_OP(softsign, a / (1.0f + math::fabs(a)));

MXNET_UNARY_MATH_OP(softsign_grad, 1.0f / math::sqr(1.0f + math::fabs(a)));

#define gelu_gx(a) \
a * (DType(1.0f) + DType(GELU_CUBIC_CONSTANT) * a * a)

#define gelu_gx_grad(a) \
(DType(1.0f) + DType(3.0f * GELU_CUBIC_CONSTANT) * a * a)

#define gelu_tanh(a) \
math::tanh(DType(GELU_ROOT_2_OVER_PI) * gelu_gx(a))

MXNET_UNARY_MATH_OP(gelu, DType(0.5f) * a * (DType(1.0f) + gelu_tanh(a)));

MXNET_BINARY_MATH_OP_NC(gelu_grad,
b / a + b * (DType(1.0f) - gelu_tanh(a)) * DType(GELU_ROOT_2_OVER_PI) * gelu_gx_grad(a));

MXNET_UNARY_MATH_OP_NC(selu, DType(SELU_LAMBDA) *
(a > DType(0) ? a : DType(math::id(SELU_ALPHA) * math::expm1(a))));

Expand Down
2 changes: 2 additions & 0 deletions src/operator/operator_tune.cc
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::relu); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::relu_grad); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::selu); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::selu_grad); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::gelu); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::tanh); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::tanh_grad); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::softrelu); // NOLINT()
Expand Down Expand Up @@ -328,6 +329,7 @@ IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_grad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rpower_grad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_rgrad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::xelu_grad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gelu_grad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::prelu_grad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::elu_grad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::maximum); // NOLINT()
Expand Down
33 changes: 33 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,39 @@ def fselu_grad(grad, x, y):
check_symbolic_backward(y, [xa], [np.ones(shape)], [ga], rtol=rtol, atol=atol, dtype=dtype)


@with_seed()
def test_gelu():
CUBE_CONSTANT = 0.044715
TWO_OVER_PI = 0.7978845608028654
def g(x):
return TWO_OVER_PI * (x + CUBE_CONSTANT * np.power(x, 3))
def g_grad(x):
return TWO_OVER_PI * (1.0 + 3.0 * CUBE_CONSTANT * np.power(x, 2))
def f(x):
return 1.0 + np.tanh(g(x))
def f_grad(x):
return (1.0 - np.tanh(g(x)) * np.tanh(g(x))) * g_grad(x)
def fgelu(x):
return 0.5 * x * f(x)
def fgelu_grad(grad, x, y):
return grad * (y / x + y * (1 - np.tanh(g(x))) * g_grad(x))

shape = (3, 4)
x = mx.sym.Variable("x")
y = mx.sym.LeakyReLU(data=x, act_type="gelu")
for dtype in [np.float16, np.float32, np.float64]:
xa = np.random.uniform(low=-0.1,high=0.1,size=shape).astype(dtype)
eps, rtol, atol = (7.5e-4, 1e-1, 1e-2) if dtype is np.float16 else (1e-4, 1e-2, 1e-4)
if dtype is np.float16:
xa /= 10.0
xa[abs(xa) < eps] = 0.01
ya = fgelu(xa)
ga = fgelu_grad(np.ones(shape).astype(dtype), xa, ya)
check_numeric_gradient(y, [xa], numeric_eps=eps, rtol=rtol, atol=atol, dtype=dtype)
check_symbolic_forward(y, [xa], [ya], rtol=rtol, atol=atol, dtype=dtype)
check_symbolic_backward(y, [xa], [np.ones(shape)], [ga], rtol=rtol, atol=atol, dtype=dtype)


@with_seed()
def test_sigmoid():
def fsigmoid(a):
Expand Down

0 comments on commit e0a565c

Please sign in to comment.