Skip to content

Commit

Permalink
Added logsumexp to backend. (#6346)
Browse files Browse the repository at this point in the history
  • Loading branch information
phipleg authored and fchollet committed Apr 22, 2017
1 parent 70ffba0 commit 7d52af6
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 0 deletions.
22 changes: 22 additions & 0 deletions keras/backend/tensorflow_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1304,6 +1304,28 @@ def log(x):
return tf.log(x)


def logsumexp(x, axis=None, keepdims=False):
"""Computes log(sum(exp(elements across dimensions of a tensor))).
This function is more numerically stable than log(sum(exp(x))).
It avoids overflows caused by taking the exp of large inputs and
underflows caused by taking the log of small inputs.
# Arguments
x: A tensor or variable.
axis: An integer, the axis to reduce over.
keepdims: A boolean, whether to keep the dimensions or not.
If `keepdims` is `False`, the rank of the tensor is reduced
by 1. If `keepdims` is `True`, the reduced dimension is
retained with length 1.
# Returns
The reduced tensor.
"""
axis = _normalize_axis(axis, ndim(x))
return tf.reduce_logsumexp(x, reduction_indices=axis, keep_dims=keepdims)


def round(x):
"""Element-wise rounding to the closest integer.
Expand Down
23 changes: 23 additions & 0 deletions keras/backend/theano_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,29 @@ def log(x):
return T.log(x)


def logsumexp(x, axis=None, keepdims=False):
"""Computes log(sum(exp(elements across dimensions of a tensor))).
This function is more numerically stable than log(sum(exp(x))).
It avoids overflows caused by taking the exp of large inputs and
underflows caused by taking the log of small inputs.
# Arguments
x: A tensor or variable.
axis: An integer, the axis to reduce over.
keepdims: A boolean, whether to keep the dimensions or not.
If `keepdims` is `False`, the rank of the tensor is reduced
by 1. If `keepdims` is `True`, the reduced dimension is
retained with length 1.
# Returns
The reduced tensor.
"""
# Theano has a built-in optimization for logsumexp (see https://github.com/Theano/Theano/pull/4736)
# so we can just write the expression directly:
return T.log(T.sum(T.exp(x), axis=axis, keepdims=keepdims))


def round(x):
return T.round(x, mode='half_to_even')

Expand Down
35 changes: 35 additions & 0 deletions tests/keras/backend/backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,41 @@ def step_function(x, states):
assert_allclose(tf_last_output, th_last_output, atol=1e-04)
assert_allclose(tf_outputs, th_outputs, atol=1e-04)

@pytest.mark.parametrize('x_np,axis,keepdims', [
(np.array([1.1, 0.8, 0.9]), 0, False),
(np.array([[1.1, 0.8, 0.9]]), 0, False),
(np.array([[1.1, 0.8, 0.9]]), 1, False),
(np.array([[1.1, 0.8, 0.9]]), -1, False),
(np.array([[1.1, 0.8, 0.9]]), 1, True),
(np.array([[1.1], [1.2]]), 0, False),
(np.array([[1.1], [1.2]]), 1, False),
(np.array([[1.1], [1.2]]), -1, False),
(np.array([[1.1], [1.2]]), -1, True),
(np.array([[1.1, 1.2, 1.3], [0.9, 0.7, 1.4]]), None, False),
(np.array([[1.1, 1.2, 1.3], [0.9, 0.7, 1.4]]), 0, False),
(np.array([[1.1, 1.2, 1.3], [0.9, 0.7, 1.4]]), 1, False),
(np.array([[1.1, 1.2, 1.3], [0.9, 0.7, 1.4]]), -1, False),
])
@pytest.mark.parametrize('K', [KTH, KTF], ids=["KTH", "KTF"])
def test_logsumexp(self, x_np, axis, keepdims, K):
'''
Check if K.logsumexp works properly for values close to one.
'''
x = K.variable(x_np)
assert_allclose(K.eval(K.logsumexp(x, axis=axis, keepdims=keepdims)),
np.log(np.sum(np.exp(x_np), axis=axis, keepdims=keepdims)),
rtol=1e-5)

@pytest.mark.parametrize('K', [KTH, KTF], ids=["KTH", "KTF"])
def test_logsumexp_optim(self, K):
'''
Check if optimization works.
'''
x_np = np.array([1e+4, 1e-4])
assert_allclose(K.eval(K.logsumexp(K.variable(x_np), axis=0)),
1e4,
rtol=1e-5)

def test_switch(self):
val = np.random.random()
xth = KTH.variable(val)
Expand Down

0 comments on commit 7d52af6

Please sign in to comment.