[MRG] Make partial_wasserstein, partial_wasserstein2 and entropic_partial_wasserstein work with backend #449

Merged · 15 commits · Mar 21, 2023
2 changes: 1 addition & 1 deletion RELEASES.md
@@ -14,7 +14,7 @@
- Added Generalized Wasserstein Barycenter solver + example (PR #372), fixed graphical details on the example (PR #376)
- Added Free Support Sinkhorn Barycenter + example (PR #387)
- New API for OT solver using function `ot.solve` (PR #388)
- Backend version of `ot.partial` and `ot.smooth` (PR #388)
- Backend version of `ot.partial` and `ot.smooth` (PR #388 and #449)
- Added argument for warmstart of dual potentials in Sinkhorn-based methods in `ot.bregman` (PR #437)
- Add parameters method in `ot.da.SinkhornTransport` (PR #440)
- `ot.dr` now uses the new Pymanopt API and POT is compatible with current Pymanopt (PR #443)
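As context for the changelog entry above, a minimal usage sketch of what the backend support added in this PR enables (assuming PyTorch is installed; the data here is illustrative, not from the PR):

```python
import numpy as np
import torch
import ot

n = 20
xs = np.random.randn(n, 2)
xt = np.random.randn(n, 2) + 2

# Inputs stay torch tensors end to end; no manual conversion to numpy.
M = torch.tensor(ot.dist(xs, xt))
p = torch.tensor(ot.unif(n))
q = torch.tensor(ot.unif(n))

# Transport half of the total mass; the returned plan is itself a torch tensor.
T = ot.partial.partial_wasserstein(p, q, M, m=0.5)
print(type(T), float(T.sum()))  # ~0.5 units of mass transported
```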
95 changes: 55 additions & 40 deletions ot/partial.py
@@ -120,7 +120,7 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,

nx = get_backend(a, b, M)

if nx.sum(a) > 1 or nx.sum(b) > 1:
if nx.sum(a) > 1 + 1e-15 or nx.sum(b) > 1 + 1e-15: # 1e-15 for numerical errors
raise ValueError("Problem infeasible. Check that a and b are in the "
"simplex")

@@ -270,36 +270,43 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):

nx = get_backend(a, b, M)

dim_a, dim_b = M.shape
if len(a) == 0:
a = nx.ones(dim_a, type_as=a) / dim_a
if len(b) == 0:
b = nx.ones(dim_b, type_as=b) / dim_b

if m is None:
return partial_wasserstein_lagrange(a, b, M, log=log, **kwargs)
elif m < 0:
raise ValueError("Problem infeasible. Parameter m should be greater"
" than 0.")
elif m > nx.min((nx.sum(a), nx.sum(b))):
elif m > nx.min(nx.stack((nx.sum(a), nx.sum(b)))):
raise ValueError("Problem infeasible. Parameter m should lower or"
" equal than min(|a|_1, |b|_1).")

a0, b0, M0 = a, b, M
# convert to numpy
a, b, M = nx.to_numpy(a, b, M)

b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies)
a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies)
M_extended = np.zeros((len(a_extended), len(b_extended)))
M_extended[-nb_dummies:, -nb_dummies:] = np.max(M) * 2
M_extended[:len(a), :len(b)] = M
b_extension = nx.ones(nb_dummies, type_as=b) * (nx.sum(a) - m) / nb_dummies
b_extended = nx.concatenate((b, b_extension))
a_extension = nx.ones(nb_dummies, type_as=a) * (nx.sum(b) - m) / nb_dummies
a_extended = nx.concatenate((a, a_extension))
M_extension = nx.ones((nb_dummies, nb_dummies), type_as=M) * nx.max(M) * 2
M_extended = nx.concatenate(
(nx.concatenate((M, nx.zeros((M.shape[0], M_extension.shape[1]), type_as=M)), axis=1),
nx.concatenate((nx.zeros((M_extension.shape[0], M.shape[1]), type_as=M), M_extension), axis=1)),
axis=0
)

gamma, log_emd = emd(a_extended, b_extended, M_extended, log=True,
**kwargs)

gamma = nx.from_numpy(gamma[:len(a), :len(b)], type_as=M)
gamma = gamma[:len(a), :len(b)]

if log_emd['warning'] is not None:
raise ValueError("Error in the EMD resolution: try to increase the"
" number of dummy points")
log_emd['partial_w_dist'] = nx.sum(M0 * gamma)
log_emd['u'] = nx.from_numpy(log_emd['u'][:len(a)], type_as=a0)
log_emd['v'] = nx.from_numpy(log_emd['v'][:len(b)], type_as=b0)
log_emd['partial_w_dist'] = nx.sum(M * gamma)
log_emd['u'] = log_emd['u'][:len(a)]
log_emd['v'] = log_emd['v'][:len(b)]

if log:
return gamma, log_emd
@@ -389,14 +396,18 @@ def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
NeurIPS.
"""

a, b, M = list_to_array(a, b, M)

nx = get_backend(a, b, M)

partial_gw, log_w = partial_wasserstein(a, b, M, m, nb_dummies, log=True,
**kwargs)
log_w['T'] = partial_gw

if log:
return np.sum(partial_gw * M), log_w
return nx.sum(partial_gw * M), log_w
else:
return np.sum(partial_gw * M)
return nx.sum(partial_gw * M)


def gwgrad_partial(C1, C2, T):
@@ -838,60 +849,64 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
ot.partial.partial_wasserstein: exact Partial Wasserstein
"""

a = np.asarray(a, dtype=np.float64)
b = np.asarray(b, dtype=np.float64)
M = np.asarray(M, dtype=np.float64)
a, b, M = list_to_array(a, b, M)

nx = get_backend(a, b, M)

dim_a, dim_b = M.shape
dx = np.ones(dim_a, dtype=np.float64)
dy = np.ones(dim_b, dtype=np.float64)
dx = nx.ones(dim_a, type_as=a)
dy = nx.ones(dim_b, type_as=b)

if len(a) == 0:
a = np.ones(dim_a, dtype=np.float64) / dim_a
a = nx.ones(dim_a, type_as=a) / dim_a
if len(b) == 0:
b = np.ones(dim_b, dtype=np.float64) / dim_b
b = nx.ones(dim_b, type_as=b) / dim_b

if m is None:
m = np.min((np.sum(a), np.sum(b))) * 1.0
m = nx.min(nx.stack((nx.sum(a), nx.sum(b)))) * 1.0
if m < 0:
raise ValueError("Problem infeasible. Parameter m should be greater"
" than 0.")
if m > np.min((np.sum(a), np.sum(b))):
if m > nx.min(nx.stack((nx.sum(a), nx.sum(b)))):
raise ValueError("Problem infeasible. Parameter m should lower or"
" equal than min(|a|_1, |b|_1).")

log_e = {'err': []}

# Next 3 lines equivalent to K=np.exp(-M/reg), but faster to compute
K = np.empty(M.shape, dtype=M.dtype)
np.divide(M, -reg, out=K)
np.exp(K, out=K)
np.multiply(K, m / np.sum(K), out=K)
if type(a) == type(b) == type(M) == np.ndarray:
# Next 3 lines equivalent to K=nx.exp(-M/reg), but faster to compute
K = np.empty(M.shape, dtype=M.dtype)
np.divide(M, -reg, out=K)
np.exp(K, out=K)
np.multiply(K, m / np.sum(K), out=K)
else:
K = nx.exp(-M / reg)
K = K * m / nx.sum(K)

err, cpt = 1, 0
q1 = np.ones(K.shape)
q2 = np.ones(K.shape)
q3 = np.ones(K.shape)
q1 = nx.ones(K.shape, type_as=K)
q2 = nx.ones(K.shape, type_as=K)
q3 = nx.ones(K.shape, type_as=K)

while (err > stopThr and cpt < numItermax):
Kprev = K
K = K * q1
K1 = np.dot(np.diag(np.minimum(a / np.sum(K, axis=1), dx)), K)
K1 = nx.dot(nx.diag(nx.minimum(a / nx.sum(K, axis=1), dx)), K)
q1 = q1 * Kprev / K1
K1prev = K1
K1 = K1 * q2
K2 = np.dot(K1, np.diag(np.minimum(b / np.sum(K1, axis=0), dy)))
K2 = nx.dot(K1, nx.diag(nx.minimum(b / nx.sum(K1, axis=0), dy)))
q2 = q2 * K1prev / K2
K2prev = K2
K2 = K2 * q3
K = K2 * (m / np.sum(K2))
K = K2 * (m / nx.sum(K2))
q3 = q3 * K2prev / K

if np.any(np.isnan(K)) or np.any(np.isinf(K)):
if nx.any(nx.isnan(K)) or nx.any(nx.isinf(K)):
print('Warning: numerical errors at iteration', cpt)
break
if cpt % 10 == 0:
err = np.linalg.norm(Kprev - K)
err = nx.norm(Kprev - K)
if log:
log_e['err'].append(err)
if verbose:
@@ -901,7 +916,7 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
print('{:5d}|{:8e}|'.format(cpt, err))

cpt = cpt + 1
log_e['partial_w_dist'] = np.sum(M * K)
log_e['partial_w_dist'] = nx.sum(M * K)
if log:
return K, log_e
else:
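To make the `nx.concatenate` construction above easier to follow, here is the same dummy-point trick written out in plain NumPy with made-up marginals (a sketch, not library code): each marginal is padded with dummy points that absorb the untransported mass, the dummy-to-dummy cost is made expensive so exactly `m` units move between the real points, and the now-balanced problem is handed to `ot.emd`.

```python
import numpy as np
from ot import emd

a = np.array([0.4, 0.6])                 # source marginal, |a|_1 = 1
b = np.array([0.5, 0.5])                 # target marginal, |b|_1 = 1
M = np.array([[0.0, 1.0],
              [1.0, 0.5]])               # cost matrix
m, nb_dummies = 0.5, 1                   # transport only half of the mass

# Pad each marginal with dummy points that absorb the untransported mass.
b_ext = np.concatenate((b, np.full(nb_dummies, (a.sum() - m) / nb_dummies)))
a_ext = np.concatenate((a, np.full(nb_dummies, (b.sum() - m) / nb_dummies)))

# Dummy-to-real transport is free (cost 0); dummy-to-dummy is made expensive
# so the solver really moves m units of mass between the original points.
M_ext = np.zeros((len(a_ext), len(b_ext)))
M_ext[:len(a), :len(b)] = M
M_ext[-nb_dummies:, -nb_dummies:] = 2 * M.max()

# The extended problem is balanced (both marginals sum to 1.5), so plain emd
# applies; restricting the plan to the real points gives the partial plan.
gamma = emd(a_ext, b_ext, M_ext)
partial_plan = gamma[:len(a), :len(b)]
print(partial_plan.sum())                # ~ m = 0.5
```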
113 changes: 92 additions & 21 deletions test/test_partial.py
@@ -8,6 +8,7 @@
import numpy as np
import scipy as sp
import ot
from ot.backend import to_numpy, torch
import pytest


@@ -82,7 +83,7 @@ def test_partial_wasserstein_lagrange():
w0, log0 = ot.partial.partial_wasserstein_lagrange(p, q, M, 100, log=True)


def test_partial_wasserstein():
def test_partial_wasserstein(nx):

n_samples = 20 # nb samples (gaussian)
n_noise = 20 # nb of samples (noise)
@@ -102,25 +103,20 @@ def test_partial_wasserstein():

m = 0.5

p, q, M = nx.from_numpy(p, q, M)

w0, log0 = ot.partial.partial_wasserstein(p, q, M, m=m, log=True)
w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=m,
log=True, verbose=True)
w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=m, log=True, verbose=True)

# check constraints
np.testing.assert_equal(
w0.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
np.testing.assert_equal(
w0.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein
np.testing.assert_equal(
w.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
np.testing.assert_equal(
w.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein
np.testing.assert_equal(to_numpy(nx.sum(w0, axis=1) - p) <= 1e-5, [True] * len(p))
np.testing.assert_equal(to_numpy(nx.sum(w0, axis=0) - q) <= 1e-5, [True] * len(q))
np.testing.assert_equal(to_numpy(nx.sum(w, axis=1) - p) <= 1e-5, [True] * len(p))
np.testing.assert_equal(to_numpy(nx.sum(w, axis=0) - q) <= 1e-5, [True] * len(q))

# check transported mass
np.testing.assert_allclose(
np.sum(w0), m, atol=1e-04)
np.testing.assert_allclose(
np.sum(w), m, atol=1e-04)
np.testing.assert_allclose(np.sum(to_numpy(w0)), m, atol=1e-04)
np.testing.assert_allclose(np.sum(to_numpy(w)), m, atol=1e-04)

w0, log0 = ot.partial.partial_wasserstein2(p, q, M, m=m, log=True)
w0_val = ot.partial.partial_wasserstein2(p, q, M, m=m, log=False)
@@ -130,12 +126,87 @@ def test_partial_wasserstein():
np.testing.assert_allclose(w0, w0_val, atol=1e-1, rtol=1e-1)

# check constraints
np.testing.assert_equal(
G.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
np.testing.assert_equal(
G.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein
np.testing.assert_allclose(
np.sum(G), m, atol=1e-04)
np.testing.assert_equal(to_numpy(nx.sum(G, axis=1) - p) <= 1e-5, [True] * len(p))
np.testing.assert_equal(to_numpy(nx.sum(G, axis=0) - q) <= 1e-5, [True] * len(q))
np.testing.assert_allclose(np.sum(to_numpy(G)), m, atol=1e-04)

empty_array = nx.zeros(0, type_as=M)
w = ot.partial.partial_wasserstein(empty_array, empty_array, M=M, m=None)

# check constraints
np.testing.assert_equal(to_numpy(nx.sum(w, axis=1) - p) <= 1e-5, [True] * len(p))
np.testing.assert_equal(to_numpy(nx.sum(w, axis=0) - q) <= 1e-5, [True] * len(q))

# check transported mass
np.testing.assert_allclose(np.sum(to_numpy(w)), 1, atol=1e-04)

w0 = ot.partial.entropic_partial_wasserstein(empty_array, empty_array, M=M, reg=10, m=None)

# check constraints
np.testing.assert_equal(to_numpy(nx.sum(w0, axis=1) - p) <= 1e-5, [True] * len(p))
np.testing.assert_equal(to_numpy(nx.sum(w0, axis=0) - q) <= 1e-5, [True] * len(q))

# check transported mass
np.testing.assert_allclose(np.sum(to_numpy(w0)), 1, atol=1e-04)


def test_partial_wasserstein2_gradient():
if torch:
n_samples = 40

mu = np.array([0, 0])
cov = np.array([[1, 0], [0, 2]])

xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)

M = torch.tensor(ot.dist(xs, xt), requires_grad=True, dtype=torch.float64)

p = torch.tensor(ot.unif(n_samples), dtype=torch.float64)
q = torch.tensor(ot.unif(n_samples), dtype=torch.float64)

m = 0.5

w, log = ot.partial.partial_wasserstein2(p, q, M, m=m, log=True)

w.backward()

assert M.grad is not None
assert M.grad.shape == M.shape


def test_entropic_partial_wasserstein_gradient():
if torch:
n_samples = 40

mu = np.array([0, 0])
cov = np.array([[1, 0], [0, 2]])

xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)

M = torch.tensor(ot.dist(xs, xt), requires_grad=True, dtype=torch.float64)

p = torch.tensor(ot.unif(n_samples), requires_grad=True, dtype=torch.float64)
q = torch.tensor(ot.unif(n_samples), requires_grad=True, dtype=torch.float64)

m = 0.5
reg = 1

_, log = ot.partial.entropic_partial_wasserstein(p, q, M, m=m, reg=reg, log=True)

log['partial_w_dist'].backward()

assert M.grad is not None
assert p.grad is not None
assert q.grad is not None
assert M.grad.shape == M.shape
assert p.grad.shape == p.shape
assert q.grad.shape == q.shape


def test_partial_gromov_wasserstein():
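The two new gradient tests above boil down to the following standalone usage sketch (assuming PyTorch is installed): thanks to the backend rewrite, autograd can differentiate the partial Wasserstein loss with respect to the cost matrix.

```python
import numpy as np
import torch
import ot

n = 40
xs = ot.datasets.make_2D_samples_gauss(n, np.array([0, 0]), np.eye(2))
xt = ot.datasets.make_2D_samples_gauss(n, np.array([0, 0]), np.eye(2))

M = torch.tensor(ot.dist(xs, xt), requires_grad=True, dtype=torch.float64)
p = torch.tensor(ot.unif(n), dtype=torch.float64)
q = torch.tensor(ot.unif(n), dtype=torch.float64)

# Differentiable partial OT loss; gradients flow back to the cost matrix.
loss = ot.partial.partial_wasserstein2(p, q, M, m=0.5)
loss.backward()
print(M.grad.shape)  # torch.Size([40, 40]), same shape as M
```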