Skip to content

Commit

Permalink
[MRG] Add factored coupling (#358)
Browse files Browse the repository at this point in the history
* add factored ot

* pep8 and add doc

* add example for factored OT

* final number of PR

* correct test on backends

* remove useless loss

* better tests
  • Loading branch information
rflamary authored Mar 24, 2022
1 parent 7671715 commit 82452e0
Show file tree
Hide file tree
Showing 8 changed files with 303 additions and 2 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -305,4 +305,6 @@ Conference on Machine Learning, PMLR 119:4692-4701, 2020
[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, [Online Graph
Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Conference on Machine Learning (ICML), 2021.

[39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). [Kantorovich duality for general transport costs and applications](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.712.1825&rep=rep1&type=pdf). Journal of Functional Analysis, 273(11), 3327-3405.
[39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). [Kantorovich duality for general transport costs and applications](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.712.1825&rep=rep1&type=pdf). Journal of Functional Analysis, 273(11), 3327-3405.

[40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger, G., & Weed, J. (2019, April). [Statistical optimal transport via factored couplings](http://proceedings.mlr.press/v89/forrow19a/forrow19a.pdf). In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2454-2465). PMLR.
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#### New features

- Implementation of factored OT with emd and sinkhorn (PR #358).
- A brand new logo for POT (PR #357)
- Better list of related examples in quick start guide with `minigallery` (PR #334).
- Add optional log-domain Sinkhorn implementation in WDA to support smaller values
Expand Down
1 change: 1 addition & 0 deletions docs/source/all.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ API and modules
partial
sliced
weak
factored

.. autosummary::
:toctree: ../modules/generated/
Expand Down
86 changes: 86 additions & 0 deletions examples/others/plot_factored_coupling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# -*- coding: utf-8 -*-
"""
==========================================
Optimal transport with factored couplings
==========================================

Illustration of the factored coupling OT between 2D empirical distributions.

The source and target samples are coupled through a small set of template
points returned by :any:`ot.factored_optimal_transport`, and the result is
compared with the exact OT plan computed by :any:`ot.emd`.
"""

# Author: Remi Flamary <[email protected]>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 2

import numpy as np
import matplotlib.pylab as pl
import ot
import ot.plot

# %%
# Generate data and plot it
# -------------------------

# parameters and data generation

# fixed seed so that the generated samples are reproducible
np.random.seed(42)

n = 100 # nb samples

# source samples: uniform in [-0.5, 0.5]^2 ...
xs = np.random.rand(n, 2) - .5

# ... then each coordinate is pushed away from 0 by its sign, which splits
# the source samples into four groups located around the corners (+-1, +-1)
xs = xs + np.sign(xs)

# target samples: uniform in [-0.5, 0.5]^2
xt = np.random.rand(n, 2) - .5

a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples

#%% plot samples

pl.figure(1)
pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.legend(loc=0)
pl.title('Source and target distributions')


# %%
# Compute factored OT and exact OT solutions
# ------------------------------------------

#%% EMD
# exact OT plan between source and target (cost matrix from ot.dist)
M = ot.dist(xs, xt)
G0 = ot.emd(a, b, M)

#%% factored OT

# Ga couples the source samples with the r=4 template points xb,
# Gb couples the template points with the target samples
Ga, Gb, xb = ot.factored_optimal_transport(xs, xt, a, b, r=4)


# %%
# Plot factored OT and exact OT solutions
# ---------------------------------------

pl.figure(2, (14, 4))

# left: exact OT plan between source and target samples
pl.subplot(1, 3, 1)
ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.2, .2, .2], alpha=0.1)
pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.title('Exact OT with samples')

# middle: the two factored plans, drawn through the template points
pl.subplot(1, 3, 2)
ot.plot.plot2D_samples_mat(xs, xb, Ga, c=[.6, .6, .9], alpha=0.5)
ot.plot.plot2D_samples_mat(xb, xt, Gb, c=[.9, .6, .6], alpha=0.5)
pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.plot(xb[:, 0], xb[:, 1], 'og', label='Template samples')
pl.title('Factored OT with template samples')

# right: source-to-target coupling obtained by composing the two plans
# through the template points (an (n, n) matrix of rank at most r)
pl.subplot(1, 3, 3)
ot.plot.plot2D_samples_mat(xs, xt, Ga.dot(Gb), c=[.2, .2, .2], alpha=0.1)
pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.title('Factored OT low rank OT plan')
5 changes: 5 additions & 0 deletions ot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from . import backend
from . import regpath
from . import weak
from . import factored

# OT functions
from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
Expand All @@ -44,6 +45,9 @@
from .gromov import (gromov_wasserstein, gromov_wasserstein2,
gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2)
from .weak import weak_optimal_transport
from .factored import factored_optimal_transport


# utils functions
from .utils import dist, unif, tic, toc, toq

Expand All @@ -57,4 +61,5 @@
'sinkhorn_unbalanced2', 'sliced_wasserstein_distance',
'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2',
'max_sliced_wasserstein_distance', 'weak_optimal_transport',
'factored_optimal_transport',
'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath']
145 changes: 145 additions & 0 deletions ot/factored.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
"""
Factored OT solvers (low rank, cost or OT plan)
"""

# Author: Remi Flamary <[email protected]>
#
# License: MIT License

from .backend import get_backend
from .utils import dist
from .lp import emd
from .bregman import sinkhorn

__all__ = ['factored_optimal_transport']


def factored_optimal_transport(Xa, Xb, a=None, b=None, reg=0.0, r=100, X0=None, stopThr=1e-7, numItermax=100, verbose=False, log=False, **kwargs):
    r"""Solves factored OT problem and return OT plans and intermediate distribution

    This function solves the following OT problem [40]_

    .. math::
        \mathop{\arg \min}_\mu \quad W_2^2(\mu_a,\mu)+ W_2^2(\mu,\mu_b)

    where :

    - :math:`\mu_a` and :math:`\mu_b` are empirical distributions.
    - :math:`\mu` is an empirical distribution with r samples

    And returns the two OT plans between :math:`\mu_a` and :math:`\mu`, and
    between :math:`\mu` and :math:`\mu_b`, together with the support of
    :math:`\mu`.

    The problem is solved by alternating between (i) computing the two OT
    plans for a fixed support of :math:`\mu` (with uniform weights
    :math:`1/r`) and (ii) moving each support point to the average of its
    two barycentric projections, until the support moves by less than
    `stopThr` or `numItermax` iterations have been done.

    .. note:: This function is backend-compatible and will work on arrays
        from all compatible backends. But when `reg` is 0 the plans are
        computed with :any:`ot.lp.emd`, which uses the C++ CPU backend and
        can lead to copy overhead on GPU arrays.

    Parameters
    ----------
    Xa : (ns,d) array-like, float
        Source samples
    Xb : (nt,d) array-like, float
        Target samples
    a : (ns,) array-like, float, optional
        Source histogram (uniform weights if None)
    b : (nt,) array-like, float, optional
        Target histogram (uniform weights if None)
    reg : float, optional
        Entropic regularization strength. If `reg > 0` both plans are
        computed with :any:`ot.bregman.sinkhorn`, otherwise with
        :any:`ot.lp.emd`.
    r : int, optional
        Number of samples in the intermediate distribution. Ignored when
        `X0` is given, in which case the number of rows of `X0` is used.
    X0 : (r,d) array-like, float, optional
        Initial support of the intermediate distribution (drawn with the
        backend `randn` if None)
    stopThr : float, optional
        Stop threshold on the norm of the variation of the support (>0)
    numItermax : int, optional
        Max number of iterations (must be at least 1)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    **kwargs
        Extra keyword arguments forwarded to the inner OT solver
        (:any:`ot.lp.emd` or :any:`ot.bregman.sinkhorn`)

    Returns
    -------
    Ga: array-like, shape (ns, r)
        Optimal transportation matrix between source and the intermediate
        distribution
    Gb: array-like, shape (r, nt)
        Optimal transportation matrix between the intermediate and target
        distribution
    X: array-like, shape (r, d)
        Support of the intermediate distribution
    log: dict, optional
        Only returned if input log is true. Contains `delta_iter` (the
        variations of the support for the iterations that did not meet the
        stopping criterion), the dual variables `ua`, `va`, `ub`, `vb` and
        the transport costs `costa`, `costb` of the last two OT problems.

    Raises
    ------
    ValueError
        If `numItermax` is smaller than 1 (no plan would be computed).


    .. _references-factored:
    References
    ----------
    .. [40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger,
        G., & Weed, J. (2019, April). Statistical optimal transport via factored
        couplings. In The 22nd International Conference on Artificial
        Intelligence and Statistics (pp. 2454-2465). PMLR.

    See Also
    --------
    ot.bregman.sinkhorn : Entropic regularized OT
    ot.lp.emd : Unregularized OT
    """

    if numItermax < 1:
        # with no iteration the plans Ga and Gb would never be computed
        raise ValueError("numItermax must be at least 1, got {}".format(numItermax))

    nx = get_backend(Xa, Xb)

    n_a = Xa.shape[0]
    n_b = Xb.shape[0]
    d = Xa.shape[1]

    if a is None:
        a = nx.ones((n_a), type_as=Xa) / n_a
    if b is None:
        b = nx.ones((n_b), type_as=Xb) / n_b

    if X0 is None:
        X = nx.randn(r, d, type_as=Xa)
    else:
        X = X0
        # the weights below must match the provided support, whatever r is
        r = X0.shape[0]

    # the intermediate distribution has fixed uniform weights
    w = nx.ones(r, type_as=Xa) / r

    def solve_ot(X1, X2, w1, w2):
        # OT plan and solver log between two weighted point clouds
        M = dist(X1, X2)
        if reg > 0:
            G, log_ot = sinkhorn(w1, w2, M, reg, log=True, **kwargs)
            # emd already stores the cost in its log; add it for sinkhorn
            log_ot['cost'] = nx.sum(G * M)
            return G, log_ot
        return emd(w1, w2, M, log=True, **kwargs)

    norm_delta = []

    # solve the barycenter
    for i in range(numItermax):

        old_X = X

        # solve OT with template
        Ga, loga = solve_ot(Xa, X, a, w)
        Gb, logb = solve_ot(X, Xb, w, b)

        # average of the two barycentric projections on the template; the
        # factor r is the inverse of the uniform template weights 1 / r
        X = 0.5 * (nx.dot(Ga.T, Xa) + nx.dot(Gb, Xb)) * r

        delta = nx.norm(X - old_X)

        if verbose:
            if i % 200 == 0:
                print('{:5s}|{:12s}'.format('It.', 'Delta') + '\n' + '-' * 19)
            print('{:5d}|{:8e}|'.format(i, float(delta)))

        if delta < stopThr:
            break
        if log:
            norm_delta.append(delta)

    if log:
        log_dic = {'delta_iter': norm_delta,
                   'ua': loga['u'],
                   'va': loga['v'],
                   'ub': logb['u'],
                   'vb': logb['v'],
                   'costa': loga['cost'],
                   'costb': logb['cost'],
                   }
        return Ga, Gb, X, log_dic

    return Ga, Gb, X
7 changes: 6 additions & 1 deletion ot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,13 @@ def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs):
if ('color' not in kwargs) and ('c' not in kwargs):
kwargs['color'] = 'k'
mx = G.max()
if 'alpha' in kwargs:
scale = kwargs['alpha']
del kwargs['alpha']
else:
scale = 1
for i in range(xs.shape[0]):
for j in range(xt.shape[0]):
if G[i, j] / mx > thr:
pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]],
alpha=G[i, j] / mx, **kwargs)
alpha=G[i, j] / mx * scale, **kwargs)
56 changes: 56 additions & 0 deletions test/test_factored.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""Tests for main module ot.factored """

# Author: Remi Flamary <[email protected]>
#
# License: MIT License

import ot
import numpy as np


def test_factored_ot():
    # test factored OT solver with the exact (emd) and the entropic
    # (sinkhorn) inner solvers
    n = 50
    r = 10
    rng = np.random.RandomState(0)

    xs = rng.randn(n, 2)
    xt = rng.randn(n, 2)
    u = ot.utils.unif(n)

    log_keys = ('delta_iter', 'ua', 'va', 'ub', 'vb', 'costa', 'costb')

    # exact inner solver (reg=0 by default)
    Ga, Gb, X, log = ot.factored_optimal_transport(xs, xt, u, u, r=r, log=True)

    # check shapes of the plans and of the intermediate support
    assert Ga.shape == (n, r)
    assert Gb.shape == (r, n)
    assert X.shape == (r, 2)

    # check constraints on the source and target marginals
    np.testing.assert_allclose(u, Ga.sum(1))
    np.testing.assert_allclose(u, Gb.sum(0))

    # with the exact solver both plans must also share the uniform
    # intermediate marginal
    np.testing.assert_allclose(ot.utils.unif(r), Ga.sum(0))
    np.testing.assert_allclose(ot.utils.unif(r), Gb.sum(1))

    # check that the log is complete
    for key in log_keys:
        assert key in log

    # entropic inner solver
    Ga, Gb, X, log = ot.factored_optimal_transport(xs, xt, u, u, reg=1, r=r, log=True)

    # check shapes of the plans and of the intermediate support
    assert Ga.shape == (n, r)
    assert Gb.shape == (r, n)
    assert X.shape == (r, 2)

    # check constraints on the source and target marginals
    np.testing.assert_allclose(u, Ga.sum(1))
    np.testing.assert_allclose(u, Gb.sum(0))

    # check that the log is complete
    for key in log_keys:
        assert key in log


def test_factored_ot_backends(nx):
    # run the factored OT solver on backend arrays and check the marginals
    # of the returned plans against the (numpy) uniform weights
    n = 50
    rng = np.random.RandomState(0)

    xs = rng.randn(n, 2)
    xt = rng.randn(n, 2)
    u = ot.utils.unif(n)

    # same data converted to the backend under test
    xs_nx = nx.from_numpy(xs)
    xt_nx = nx.from_numpy(xt)
    u_nx = nx.from_numpy(u)

    def check_marginals(plan_a, plan_b):
        # rows of the first plan and columns of the second one must sum to u
        np.testing.assert_allclose(u, nx.to_numpy(plan_a).sum(1))
        np.testing.assert_allclose(u, nx.to_numpy(plan_b).sum(0))

    # exact inner solver with explicit weights
    plan_a, plan_b, support = ot.factored_optimal_transport(xs_nx, xt_nx, u_nx, u_nx, r=10)
    check_marginals(plan_a, plan_b)

    # entropic inner solver with default (uniform) weights, warm-started
    # from the support found above
    plan_a, plan_b, support = ot.factored_optimal_transport(xs_nx, xt_nx, reg=1, r=10, X0=support)
    check_marginals(plan_a, plan_b)

0 comments on commit 82452e0

Please sign in to comment.