Skip to content

Commit

Permalink
Implementation of Low Rank Gromov-Wasserstein (#614)
Browse files Browse the repository at this point in the history
* new file for lr sinkhorn

* lr sinkhorn, solve_sample, OTResultLazy

* add test functions + small modif lr_sin/solve_sample

* add import to __init__

* modify low rank, remove solve_sample,OTResultLazy

* new file for lr sinkhorn

* lr sinkhorn, solve_sample, OTResultLazy

* add test functions + small modif lr_sin/solve_sample

* add import to __init__

* remove test solve_sample

* add value, value_linear, lazy_plan

* add comments to lr algorithm

* modify test functions + add comments to lowrank

* modify __init__ with lowrank

* debug lowrank + test

* debug test function low_rank

* error test

* final debug of lowrank + add new test functions

* Debug tests + add lowrank to solve_sample

* fix torch backend for lowrank

* fix jax backend and skip tf

* fix pep 8 tests

* add lowrank init + test functions

* Add init strategies in lowrank + example (#588)

* modified lowrank

* changes from code review

* fix error test pep8

* fix linux-minimal-deps + code review

* Implementation of LR GW + add method in __init__

* add LR gw paper in README.md

* add tests for low rank GW

* add examples for Low Rank GW

* fix __init__

* change atol of lr backends

* fix pep8 errors

* modif for code review
  • Loading branch information
laudavid authored May 29, 2024
1 parent e01c4e6 commit 2472dd4
Show file tree
Hide file tree
Showing 8 changed files with 622 additions and 5 deletions.
2 changes: 1 addition & 1 deletion CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ The contributors to this library are:
* [Ronak Mehta](https://ronakrm.github.io) (Efficient Discrete Multi Marginal Optimal Transport Regularization)
* [Xizheng Yu](https://github.com/x12hengyu) (Efficient Discrete Multi Marginal Optimal Transport Regularization)
* [Sonia Mazelet](https://github.com/SoniaMaz8) (Template based GNN layers)
* [Laurène David](https://github.com/laudavid) (Low rank sinkhorn)
* [Laurène David](https://github.com/laudavid) (Low rank sinkhorn, Low rank Gromov-Wasserstein samples)

## Acknowledgments

Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -357,3 +357,5 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil
[65] Scetbon, M., Cuturi, M., & Peyré, G. (2021). [Low-Rank Sinkhorn Factorization](https://arxiv.org/pdf/2103.04737.pdf).

[66] Pooladian, Aram-Alexandre, and Jonathan Niles-Weed. [Entropic estimation of optimal transport maps](https://arxiv.org/pdf/2109.12004.pdf). arXiv preprint arXiv:2109.12004 (2021).

[67] Scetbon, M., Peyré, G. & Cuturi, M. (2022). [Linear-Time GromovWasserstein Distances using Low Rank Couplings and Costs](https://proceedings.mlr.press/v162/scetbon22b/scetbon22b.pdf). In International Conference on Machine Learning (ICML), 2022.
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
+ Continuous entropic mapping (PR #613)
+ New general unbalanced solvers for `ot.solve` and BFGS solver and illustrative example (PR #620)
+ Add gradient computation with envelope theorem to sinkhorn solver of `ot.solve` with `grad='envelope'` (PR #605).
+ Added support for [Low rank Gromov-Wasserstein](https://proceedings.mlr.press/v162/scetbon22b/scetbon22b.pdf) with `ot.gromov.lowrank_gromov_wasserstein_samples` (PR #614)

#### Closed issues
- Fix gpu compatibility of sr(F)GW solvers when `G0 is not None`(PR #596)
Expand Down
173 changes: 173 additions & 0 deletions examples/others/plot_lowrank_GW.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# -*- coding: utf-8 -*-
"""
========================================
Low rank Gromov-Wasterstein between samples
========================================
Comparaison between entropic Gromov-Wasserstein and Low Rank Gromov Wasserstein [67]
on two curves in 2D and 3D, both sampled with 200 points.
The squared Euclidean distance is considered as the ground cost for both samples.
[67] Scetbon, M., Peyré, G. & Cuturi, M. (2022).
"Linear-Time GromovWasserstein Distances using Low Rank Couplings and Costs".
In International Conference on Machine Learning (ICML), 2022.
"""

# Author: Laurène David <[email protected]>
#
# License: MIT License
#
# sphinx_gallery_thumbnail_number = 3

#%%
import numpy as np
import matplotlib.pylab as pl
import ot.plot
import time

##############################################################################
# Generate data
# -------------

#%% parameters
n_samples = 200

# Generate 2D and 3D curves
theta = np.linspace(-4 * np.pi, 4 * np.pi, n_samples)
z = np.linspace(1, 2, n_samples)
r = z**2 + 1
x = r * np.sin(theta)
y = r * np.cos(theta)

# Source and target distribution
X = np.concatenate([x.reshape(-1, 1), z.reshape(-1, 1)], axis=1)
Y = np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], axis=1)


##############################################################################
# Plot data
# ------------

#%%
# Plot the source and target samples
fig = pl.figure(1, figsize=(10, 4))

ax = fig.add_subplot(121)
ax.plot(X[:, 0], X[:, 1], color="blue", linewidth=6)
ax.tick_params(left=False, right=False, labelleft=False,
labelbottom=False, bottom=False)
ax.set_title("2D curve (source)")

ax2 = fig.add_subplot(122, projection="3d")
ax2.plot(Y[:, 0], Y[:, 1], Y[:, 2], c='red', linewidth=6)
ax2.tick_params(left=False, right=False, labelleft=False,
labelbottom=False, bottom=False)
ax2.view_init(15, -50)
ax2.set_title("3D curve (target)")

pl.tight_layout()
pl.show()


##############################################################################
# Entropic Gromov-Wasserstein
# ------------

#%%

# Compute cost matrices
C1 = ot.dist(X, X, metric="sqeuclidean")
C2 = ot.dist(Y, Y, metric="sqeuclidean")

# Scale cost matrices
r1 = C1.max()
r2 = C2.max()

C1 = C1 / r1
C2 = C2 / r2


# Solve entropic gw
reg = 5 * 1e-3

start = time.time()
gw, log = ot.gromov.entropic_gromov_wasserstein(
C1, C2, tol=1e-3, epsilon=reg,
log=True, verbose=False)

end = time.time()
time_entropic = end - start

entropic_gw_loss = np.round(log['gw_dist'], 3)

# Plot entropic gw
pl.figure(2)
pl.imshow(gw, interpolation="nearest", aspect="auto")
pl.title("Entropic Gromov-Wasserstein (loss={})".format(entropic_gw_loss))
pl.show()


##############################################################################
# Low rank squared euclidean cost matrices
# ------------
# %%

# Compute the low rank sqeuclidean cost decompositions
A1, A2 = ot.lowrank.compute_lr_sqeuclidean_matrix(X, X, rescale_cost=False)
B1, B2 = ot.lowrank.compute_lr_sqeuclidean_matrix(Y, Y, rescale_cost=False)

# Scale the low rank cost matrices
A1, A2 = A1 / np.sqrt(r1), A2 / np.sqrt(r1)
B1, B2 = B1 / np.sqrt(r2), B2 / np.sqrt(r2)


##############################################################################
# Low rank Gromov-Wasserstein
# ------------
# %%

# Solve low rank gromov-wasserstein with different ranks
list_rank = [10, 50]
list_P_GW = []
list_loss_GW = []
list_time_GW = []

for rank in list_rank:
start = time.time()

Q, R, g, log = ot.lowrank_gromov_wasserstein_samples(
X, Y, reg=0, rank=rank, rescale_cost=False, cost_factorized_Xs=(A1, A2),
cost_factorized_Xt=(B1, B2), seed_init=49, numItermax=1000, log=True, stopThr=1e-6,
)
end = time.time()

P = log["lazy_plan"][:]
loss = log["value"]

list_P_GW.append(P)
list_loss_GW.append(np.round(loss, 3))
list_time_GW.append(end - start)


# %%
# Plot low rank GW with different ranks
pl.figure(3, figsize=(10, 4))

pl.subplot(1, 2, 1)
pl.imshow(list_P_GW[0], interpolation="nearest", aspect="auto")
pl.title('Low rank GW (rank=10, loss={})'.format(list_loss_GW[0]))

pl.subplot(1, 2, 2)
pl.imshow(list_P_GW[1], interpolation="nearest", aspect="auto")
pl.title('Low rank GW (rank=50, loss={})'.format(list_loss_GW[1]))

pl.tight_layout()
pl.show()


# %%
# Compare computation time between entropic GW and low rank GW
print("Entropic GW: {:.2f}s".format(time_entropic))
print("Low rank GW (rank=10): {:.2f}s".format(list_time_GW[0]))
print("Low rank GW (rank=50): {:.2f}s".format(list_time_GW[1]))
7 changes: 4 additions & 3 deletions ot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@
from .sliced import (sliced_wasserstein_distance, max_sliced_wasserstein_distance,
sliced_wasserstein_sphere, sliced_wasserstein_sphere_unif)
from .gromov import (gromov_wasserstein, gromov_wasserstein2,
gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2)
gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2,
lowrank_gromov_wasserstein_samples)
from .weak import weak_optimal_transport
from .factored import factored_optimal_transport
from .solvers import solve, solve_gromov, solve_sample
Expand All @@ -71,5 +72,5 @@
'factored_optimal_transport', 'solve', 'solve_gromov', 'solve_sample',
'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers',
'binary_search_circle', 'wasserstein_circle',
'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif',
'lowrank_sinkhorn']
'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif', 'lowrank_sinkhorn',
'lowrank_gromov_wasserstein_samples']
4 changes: 3 additions & 1 deletion ot/gromov/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
fused_gromov_wasserstein_dictionary_learning,
fused_gromov_wasserstein_linear_unmixing)

from ._lowrank import (_flat_product_operator, lowrank_gromov_wasserstein_samples)


__all__ = ['init_matrix', 'tensor_product', 'gwloss', 'gwggrad', 'update_square_loss',
'update_kl_loss', 'update_feature_matrix', 'init_matrix_semirelaxed',
Expand All @@ -64,4 +66,4 @@
'entropic_semirelaxed_gromov_wasserstein2', 'entropic_semirelaxed_fused_gromov_wasserstein',
'entropic_semirelaxed_fused_gromov_wasserstein2', 'gromov_wasserstein_dictionary_learning',
'gromov_wasserstein_linear_unmixing', 'fused_gromov_wasserstein_dictionary_learning',
'fused_gromov_wasserstein_linear_unmixing']
'fused_gromov_wasserstein_linear_unmixing', 'lowrank_gromov_wasserstein_samples']
Loading

0 comments on commit 2472dd4

Please sign in to comment.