-
Notifications
You must be signed in to change notification settings - Fork 505
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add gfactored ot * pep8 and add doc * add exmaple for factotred OT * final number of PR * correct test on backends * remove useless loss * better tests
- Loading branch information
Showing
8 changed files
with
303 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,6 +29,7 @@ API and modules | |
partial | ||
sliced | ||
weak | ||
factored | ||
|
||
.. autosummary:: | ||
:toctree: ../modules/generated/ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
========================================== | ||
Optimal transport with factored couplings | ||
========================================== | ||
Illustration of the factored coupling OT between 2D empirical distributions | ||
""" | ||
|
||
# Author: Remi Flamary <[email protected]> | ||
# | ||
# License: MIT License | ||
|
||
# sphinx_gallery_thumbnail_number = 2 | ||
|
||
import numpy as np | ||
import matplotlib.pylab as pl | ||
import ot | ||
import ot.plot | ||
|
||
# %%
# Generate data and plot it
# -------------------------

# parameters and data generation

np.random.seed(42)

n = 100  # nb samples

xs = np.random.rand(n, 2) - .5

# push every coordinate away from 0 by one unit (in the direction of its
# sign) so that the source samples split into four separated clusters
xs = xs + np.sign(xs)

xt = np.random.rand(n, 2) - .5

a, b = ot.unif(n), ot.unif(n)  # uniform distribution on samples

#%% plot samples

pl.figure(1)
pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.legend(loc=0)
pl.title('Source and target distributions')


# %%
# Compute factored OT and exact OT solutions
# ------------------------------------------

#%% EMD
M = ot.dist(xs, xt)
G0 = ot.emd(a, b, M)

#%% factored OT

# r=4 template samples: one per source cluster
Ga, Gb, xb = ot.factored_optimal_transport(xs, xt, a, b, r=4)


# %%
# Plot factored OT and exact OT solutions
# ---------------------------------------

pl.figure(2, (14, 4))

# left: exact OT plan between source and target samples
pl.subplot(1, 3, 1)
ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.2, .2, .2], alpha=0.1)
pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.title('Exact OT with samples')

# middle: the two factors, source -> template and template -> target
pl.subplot(1, 3, 2)
ot.plot.plot2D_samples_mat(xs, xb, Ga, c=[.6, .6, .9], alpha=0.5)
ot.plot.plot2D_samples_mat(xb, xt, Gb, c=[.9, .6, .6], alpha=0.5)
pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.plot(xb[:, 0], xb[:, 1], 'og', label='Template samples')
pl.title('Factored OT with template samples')

# right: product of the two factors, a source -> target matrix of rank <= r
pl.subplot(1, 3, 3)
ot.plot.plot2D_samples_mat(xs, xt, Ga.dot(Gb), c=[.2, .2, .2], alpha=0.1)
pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.title('Factored OT low rank OT plan')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
""" | ||
Factored OT solvers (low rank, cost or OT plan) | ||
""" | ||
|
||
# Author: Remi Flamary <[email protected]> | ||
# | ||
# License: MIT License | ||
|
||
from .backend import get_backend | ||
from .utils import dist | ||
from .lp import emd | ||
from .bregman import sinkhorn | ||
|
||
__all__ = ['factored_optimal_transport'] | ||
|
||
|
||
def factored_optimal_transport(Xa, Xb, a=None, b=None, reg=0.0, r=100, X0=None, stopThr=1e-7, numItermax=100, verbose=False, log=False, **kwargs):
    r"""Solves factored OT problem and return OT plans and intermediate distribution

    This function solves the following OT problem [40]_

    .. math::
        \mathop{\arg \min}_\mu \quad  W_2^2(\mu_a,\mu)+ W_2^2(\mu,\mu_b)

    where :

    - :math:`\mu_a` and :math:`\mu_b` are empirical distributions.
    - :math:`\mu` is an empirical distribution with r samples

    And returns the two OT plans between :math:`\mu_a` and :math:`\mu`, and
    between :math:`\mu` and :math:`\mu_b`, together with the support of
    :math:`\mu`.

    .. note:: This function is backend-compatible and will work on arrays
        from all compatible backends. But the algorithm uses the C++ CPU backend
        which can lead to copy overhead on GPU arrays.

    The problem is solved by alternating two steps until the support stops
    moving: with the support of :math:`\mu` fixed, solve the two OT problems
    (exact when `reg` is 0, Sinkhorn otherwise); then replace every support
    point by the average of its barycentric projections on the source and on
    the target. See :ref:`[40] <references-factored>`.

    Parameters
    ----------
    Xa : (ns,d) array-like, float
        Source samples
    Xb : (nt,d) array-like, float
        Target samples
    a : (ns,) array-like, float, optional
        Source histogram (uniform weights if None)
    b : (nt,) array-like, float, optional
        Target histogram (uniform weights if None)
    reg : float, optional
        Entropic regularization strength. When `reg > 0` the inner OT
        problems are solved with :any:`ot.bregman.sinkhorn`, otherwise with
        the exact solver :any:`ot.lp.emd`.
    r : int, optional
        Number of samples of the intermediate distribution. Ignored when
        `X0` is given (the number of rows of `X0` is used instead).
    X0 : (r,d) array-like, float, optional
        Initial support of the intermediate distribution (drawn with the
        backend `randn` if None)
    stopThr : float, optional
        Stop threshold on the norm of the variation of the support between
        two iterations (>0)
    numItermax : int, optional
        Max number of iterations (must be at least 1)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    **kwargs : dict
        Extra parameters forwarded to the inner OT solver

    Returns
    -------
    Ga: array-like, shape (ns, r)
        Optimal transportation matrix between source and the intermediate
        distribution
    Gb: array-like, shape (r, nt)
        Optimal transportation matrix between the intermediate and target
        distribution
    X: array-like, shape (r, d)
        Support of the intermediate distribution
    log: dict, optional
        If input log is true, a dictionary containing the variation of the
        support at each iteration (`delta_iter`), and the dual variables
        (`ua`, `va`, `ub`, `vb`) and costs (`costa`, `costb`) of the last
        two OT problems solved

    Raises
    ------
    ValueError
        If `numItermax` is lower than 1.


    .. _references-factored:
    References
    ----------
    .. [40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger,
        G., & Weed, J. (2019, April). Statistical optimal transport via factored
        couplings. In The 22nd International Conference on Artificial
        Intelligence and Statistics (pp. 2454-2465). PMLR.

    See Also
    --------
    ot.bregman.sinkhorn : Entropic regularized OT
    ot.lp.emd : Exact OT solver
    """

    # without at least one iteration there is no plan to return
    if numItermax < 1:
        raise ValueError("numItermax must be at least 1")

    nx = get_backend(Xa, Xb)

    n_a = Xa.shape[0]
    n_b = Xb.shape[0]
    d = Xa.shape[1]

    if a is None:
        a = nx.ones(n_a, type_as=Xa) / n_a
    if b is None:
        b = nx.ones(n_b, type_as=Xb) / n_b

    if X0 is None:
        X = nx.randn(r, d, type_as=Xa)
    else:
        X = X0
        # the weights below must have one entry per support point
        r = X.shape[0]

    # the intermediate distribution always has uniform weights
    w = nx.ones(r, type_as=Xa) / r

    def solve_ot(X1, X2, w1, w2):
        # squared Euclidean cost between the two supports (default of dist)
        M = dist(X1, X2)
        if reg > 0:
            G, log_ot = sinkhorn(w1, w2, M, reg, log=True, **kwargs)
            # sinkhorn does not report the transport cost: add it so both
            # solvers expose the same 'cost' entry
            log_ot['cost'] = nx.sum(G * M)
            return G, log_ot
        else:
            return emd(w1, w2, M, log=True, **kwargs)

    norm_delta = []

    # solve the barycenter
    for i in range(numItermax):

        old_X = X

        # solve OT with template
        Ga, loga = solve_ot(Xa, X, a, w)
        Gb, logb = solve_ot(X, Xb, w, b)

        # barycentric projections: column j of Ga (row j of Gb) sums to 1/r,
        # hence the multiplication by r to obtain weighted means
        X = 0.5 * (nx.dot(Ga.T, Xa) + nx.dot(Gb, Xb)) * r

        delta = nx.norm(X - old_X)

        # record before testing convergence so the last step is logged too
        if log:
            norm_delta.append(delta)
        if verbose:
            print('{:5d}|{:8e}'.format(i, float(delta)))

        if delta < stopThr:
            break

    if log:
        log_dic = {'delta_iter': norm_delta,
                   'ua': loga['u'],
                   'va': loga['v'],
                   'ub': logb['u'],
                   'vb': logb['v'],
                   'costa': loga['cost'],
                   'costb': logb['cost'],
                   }
        return Ga, Gb, X, log_dic

    return Ga, Gb, X
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
"""Tests for main module ot.weak """ | ||
|
||
# Author: Remi Flamary <[email protected]> | ||
# | ||
# License: MIT License | ||
|
||
import ot | ||
import numpy as np | ||
|
||
|
||
def test_factored_ot():
    # the factored OT plans must keep the prescribed source and target
    # marginals, with the default solver and with reg=1
    n_samples = 50
    prng = np.random.RandomState(0)

    src = prng.randn(n_samples, 2)
    tgt = prng.randn(n_samples, 2)
    weights = ot.utils.unif(n_samples)

    for solver_args in ({}, {'reg': 1}):
        plan_a, plan_b, support, log = ot.factored_optimal_transport(
            src, tgt, weights, weights, r=10, log=True, **solver_args)

        # check constraints
        np.testing.assert_allclose(weights, plan_a.sum(1))
        np.testing.assert_allclose(weights, plan_b.sum(0))
|
||
|
||
def test_factored_ot_backends(nx):
    # factored OT solver on backend arrays: the returned plans must keep the
    # uniform source and target marginals
    n_samples = 50
    prng = np.random.RandomState(0)

    src = prng.randn(n_samples, 2)
    tgt = prng.randn(n_samples, 2)
    weights = ot.utils.unif(n_samples)

    src_nx = nx.from_numpy(src)
    tgt_nx = nx.from_numpy(tgt)
    weights_nx = nx.from_numpy(weights)

    def check_marginals(plan_a, plan_b):
        # check constraints
        np.testing.assert_allclose(weights, nx.to_numpy(plan_a).sum(1))
        np.testing.assert_allclose(weights, nx.to_numpy(plan_b).sum(0))

    # explicit weights, default solver
    plan_a, plan_b, support = ot.factored_optimal_transport(
        src_nx, tgt_nx, weights_nx, weights_nx, r=10)
    check_marginals(plan_a, plan_b)

    # default weights, reg=1, initialized from the previous support
    plan_a, plan_b, support = ot.factored_optimal_transport(
        src_nx, tgt_nx, reg=1, r=10, X0=support)
    check_marginals(plan_a, plan_b)