Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Semi-relaxed (fused) gromov-wasserstein divergence and improvements of gromov-wasserstein solvers #431

Merged
merged 39 commits into from
Mar 9, 2023
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
574b003
maj gw/ srgw/ generic cg solver
cedricvincentcuaz Sep 22, 2022
ff1e28d
correct pep8 on current state
cedricvincentcuaz Sep 22, 2022
6d879fe
fix bug previous tests
cedricvincentcuaz Sep 22, 2022
f80671a
fix pep8
cedricvincentcuaz Sep 22, 2022
ee330c0
fix bug srGW constC in loss and gradient
cedricvincentcuaz Sep 23, 2022
9670fc9
Merge branch 'master' into semirelaxed_gromov
rflamary Sep 27, 2022
adaae47
fix doc html
cedricvincentcuaz Sep 27, 2022
8d9827e
Merge branch 'semirelaxed_gromov' of https://github.com/cedricvincent…
cedricvincentcuaz Sep 27, 2022
7c758e3
fix doc html
cedricvincentcuaz Sep 27, 2022
98a880e
Merge branch 'master' into semirelaxed_gromov
rflamary Jan 4, 2023
c384f45
update generic_cg and dependencies
cedricvincentcuaz Feb 2, 2023
e6f0bb1
start updating test_optim.py
cedricvincentcuaz Feb 2, 2023
efbba2e
update tests gromov and optim - plus fix gromov dependencies
cedricvincentcuaz Feb 6, 2023
71be9d0
add symmetry feature to entropic gw
cedricvincentcuaz Feb 6, 2023
90fcd48
add symmetry feature to entropic gw
cedricvincentcuaz Feb 6, 2023
6ab9514
add exemple for sr(F)GW matchings
cedricvincentcuaz Feb 9, 2023
dc53fe1
Merge branch 'master' into semirelaxed_gromov
rflamary Feb 23, 2023
43bc857
Merge branch 'master' into semirelaxed_gromov
rflamary Feb 23, 2023
a9cbe08
factor linesearch dependencies /transpose + srgw to backend unfinished
cedricvincentcuaz Feb 27, 2023
6202290
merge releases.md
cedricvincentcuaz Feb 27, 2023
f7fa3ee
small stuff
cedricvincentcuaz Feb 27, 2023
54f0ba1
remove (reg,M) from line-search/ complete srgw tests with backend
cedricvincentcuaz Feb 28, 2023
a875a12
remove backend repetitions / rename fG to costG/ fix innerlog to True
cedricvincentcuaz Feb 28, 2023
92c69d4
fix pep8
cedricvincentcuaz Feb 28, 2023
be55ea2
Merge branch 'master' into semirelaxed_gromov
rflamary Feb 28, 2023
6069b0a
take comments into account / new nx parameters still to test
cedricvincentcuaz Mar 1, 2023
e50a750
Merge branch 'semirelaxed_gromov' of https://github.com/cedricvincent…
cedricvincentcuaz Mar 1, 2023
746bc1d
factor (f)gw2 + test new backend parameters in ot.gromov + harmonize …
cedricvincentcuaz Mar 3, 2023
b028b36
split gromov.py in ot/gromov/ + update test_gromov with helper_backen…
cedricvincentcuaz Mar 3, 2023
49fdbeb
manual documentaion gromov
rflamary Mar 9, 2023
95f2033
remove circular autosummary
rflamary Mar 9, 2023
5882cd1
Merge branch 'master' into semirelaxed_gromov
rflamary Mar 9, 2023
2409983
trying stuff
rflamary Mar 9, 2023
a1f1172
Merge branch 'semirelaxed_gromov' of https://github.com/cedricvincent…
rflamary Mar 9, 2023
dc1ac92
debug documentation
rflamary Mar 9, 2023
97c1e6b
alphabetic ordering of module
rflamary Mar 9, 2023
2d03ba6
merge into branch
cedricvincentcuaz Mar 9, 2023
679b74f
Merge branch 'semirelaxed_gromov' of https://github.com/cedricvincent…
cedricvincentcuaz Mar 9, 2023
fb86e46
add note in entropic gw solvers
cedricvincentcuaz Mar 9, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ The contributors to this library are:
* [Tanguy Kerdoncuff](https://hv0nnus.github.io/) (Sampled Gromov Wasserstein)
* [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance)
* [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends)
* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning)
* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning, semi-relaxed FGW)
* [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein Barycenters)
* [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug)
* [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter)
Expand Down
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ POT provides the following generic OT solvers (links to examples):
* [Wasserstein distance on the circle](https://pythonot.github.io/auto_examples/plot_compute_wasserstein_circle.html) [44, 45]
* [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46]
* [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) [38].
* [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) [48].
* [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays.

POT provides the following Machine Learning related solvers:
Expand Down Expand Up @@ -300,4 +301,8 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer

[45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. [The statistics of circular optimal transport.](https://arxiv.org/abs/2103.15426) Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82.

[46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). [Spherical Sliced-Wasserstein](https://openreview.net/forum?id=jXQ0ipgMdU). International Conference on Learning Representations.
[46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). [Spherical Sliced-Wasserstein](https://openreview.net/forum?id=jXQ0ipgMdU). International Conference on Learning Representations.

[47] Chowdhury, S., & Mémoli, F. (2019). [The gromov–wasserstein distance between networks and stable network invariants](https://academic.oup.com/imaiai/article/8/4/757/5627736). Information and Inference: A Journal of the IMA, 8(4), 757-787.

[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty (2022). [Semi-relaxed Gromov-Wasserstein divergence and applications on graphs](https://openreview.net/pdf?id=RShaMexjc-x). International Conference on Learning Representations (ICLR), 2022.
3 changes: 2 additions & 1 deletion RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
## 0.8.3dev

#### New features

- Added feature to (Fused) Gromov-Wasserstein solvers to handle asymmetric matrices (PR #431)
- Added semi-relaxed (Fused) Gromov-Wasserstein solvers in `ot.gromov` + examples (PR #431)
- Added the spherical sliced-Wasserstein discrepancy in `ot.sliced.sliced_wasserstein_sphere` and `ot.sliced.sliced_wasserstein_sphere_unif` + examples (PR #434)
- Added the Wasserstein distance on the circle in ``ot.lp.solver_1d.wasserstein_circle`` (PR #434)
- Added the Wasserstein distance on the circle (for p>=1) in `ot.lp.solver_1d.binary_search_circle` + examples (PR #434)
Expand Down
1 change: 0 additions & 1 deletion examples/gromov/plot_gromov.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
==========================
Gromov-Wasserstein example
==========================

This example is designed to show how to use the Gromov-Wassertsein distance
computation in POT.
"""
Expand Down
300 changes: 300 additions & 0 deletions examples/gromov/plot_semirelaxed_fgw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,300 @@
# -*- coding: utf-8 -*-
"""
==========================
Semi-relaxed (Fused) Gromov-Wasserstein example
==========================

This example is designed to show how to use the semi-relaxed Gromov-Wasserstein
and the semi-relaxed Fused Gromov-Wasserstein divergences.

sr(F)GW between two graphs G1 and G2 searches for a reweighing of the nodes of
G2 at a minimal (F)GW distance from G1.

First, we generate two graphs following Stochastic Block Models, then show
how to compute their srGW matchings and illustrate them. These graphs are then
endowed with node features and we follow the same process with srFGW.

[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
International Conference on Learning Representations (ICLR), 2021.
"""

# Author: Cédric Vincent-Cuaz <[email protected]>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 1

import numpy as np
import matplotlib.pylab as pl
from ot.gromov import semirelaxed_gromov_wasserstein, semirelaxed_fused_gromov_wasserstein, gromov_wasserstein, fused_gromov_wasserstein
import networkx as nx
from networkx.generators.community import stochastic_block_model as sbm

# %%
# =============================================================================
# Generate two graphs following Stochastic Block models of 2 and 3 clusters.
# =============================================================================


N2 = 20 # 2 communities
N3 = 30 # 3 communities
p2 = [[1., 0.1],
[0.1, 0.9]]
p3 = [[1., 0.1, 0.],
[0.1, 0.95, 0.1],
[0., 0.1, 0.9]]
G2 = sbm(seed=0, sizes=[N2 // 2, N2 // 2], p=p2)
G3 = sbm(seed=0, sizes=[N3 // 3, N3 // 3, N3 // 3], p=p3)


C2 = nx.to_numpy_array(G2)
cedricvincentcuaz marked this conversation as resolved.
Show resolved Hide resolved
C3 = nx.to_numpy_array(G3)

h2 = np.ones(C2.shape[0]) / C2.shape[0]
h3 = np.ones(C3.shape[0]) / C3.shape[0]

# Add weights on the edges for visualization later on
weight_intra_G2 = 5
weight_inter_G2 = 0.5
weight_intra_G3 = 1.
weight_inter_G3 = 1.5

weightedG2 = nx.Graph()
part_G2 = [G2.nodes[i]['block'] for i in range(N2)]

for node in G2.nodes():
weightedG2.add_node(node)
for i, j in G2.edges():
if part_G2[i] == part_G2[j]:
weightedG2.add_edge(i, j, weight=weight_intra_G2)
else:
weightedG2.add_edge(i, j, weight=weight_inter_G2)

weightedG3 = nx.Graph()
part_G3 = [G3.nodes[i]['block'] for i in range(N3)]

for node in G3.nodes():
weightedG3.add_node(node)
for i, j in G3.edges():
if part_G3[i] == part_G3[j]:
weightedG3.add_edge(i, j, weight=weight_intra_G3)
else:
weightedG3.add_edge(i, j, weight=weight_inter_G3)
# %%
# =============================================================================
# Compute their semi-relaxed Gromov-Wasserstein divergences
# =============================================================================

# 0) GW(C2, h2, C3, h3) for reference
OT, log = gromov_wasserstein(C2, C3, h2, h3, symmetric=True, log=True)
gw = log['gw_dist']

# 1) srGW(C2, h2, C3)
OT_23, log_23 = semirelaxed_gromov_wasserstein(C2, C3, h2, symmetric=True,
log=True, G0=None)
srgw_23 = log_23['srgw_dist']

# 2) srGW(C3, h3, C2)

OT_32, log_32 = semirelaxed_gromov_wasserstein(C3, C2, h3, symmetric=None,
log=True, G0=OT.T)
srgw_32 = log_32['srgw_dist']

print('GW(C2, C3) = ', gw)
print('srGW(C2, h2, C3) = ', srgw_23)
print('srGW(C3, h3, C2) = ', srgw_32)


# %%
# =============================================================================
# Visualization of the semi-relaxed Gromov-Wasserstein matchings
# =============================================================================

# We color nodes of the graph on the right - then project its node colors
# based on the optimal transport plan from the srGW matching


def draw_graph(G, C, nodes_color_part, Gweights=None,
pos=None, edge_color='black', node_size=None,
shiftx=0, seed=0):

if (pos is None):
pos = nx.spring_layout(G, scale=1., seed=seed)

if shiftx != 0:
for k, v in pos.items():
v[0] = v[0] + shiftx

alpha_edge = 0.7
width_edge = 1.8
if Gweights is None:
nx.draw_networkx_edges(G, pos, width=width_edge, alpha=alpha_edge, edge_color=edge_color)
else:
# We make more visible connections between activated nodes
n = len(Gweights)
edgelist_activated = []
edgelist_deactivated = []
for i in range(n):
for j in range(n):
if Gweights[i] * Gweights[j] * C[i, j] > 0:
edgelist_activated.append((i, j))
elif C[i, j] > 0:
edgelist_deactivated.append((i, j))

nx.draw_networkx_edges(G, pos, edgelist=edgelist_activated,
width=width_edge, alpha=alpha_edge,
edge_color=edge_color)
nx.draw_networkx_edges(G, pos, edgelist=edgelist_deactivated,
width=width_edge, alpha=0.1,
edge_color=edge_color)

if Gweights is None:
for node, node_color in enumerate(nodes_color_part):
nx.draw_networkx_nodes(G, pos, nodelist=[node],
node_size=node_size, alpha=1,
node_color=node_color)
else:
scaled_Gweights = Gweights / (0.5 * Gweights.max())
nodes_size = node_size * scaled_Gweights
for node, node_color in enumerate(nodes_color_part):
nx.draw_networkx_nodes(G, pos, nodelist=[node],
node_size=nodes_size[node], alpha=1,
node_color=node_color)
return pos


def draw_transp_colored_srGW(G1, C1, G2, C2, part_G1,
p1, p2, T, pos1=None, pos2=None,
shiftx=4, switchx=False, node_size=70,
seed_G1=0, seed_G2=0):
starting_color = 0
# get graphs partition and their coloring
part1 = part_G1.copy()
unique_colors = ['C%s' % (starting_color + i) for i in np.unique(part1)]
nodes_color_part1 = []
for cluster in part1:
nodes_color_part1.append(unique_colors[cluster])

nodes_color_part2 = []
# T: getting colors assignment from argmin of columns
for i in range(len(G2.nodes())):
j = np.argmax(T[:, i])
nodes_color_part2.append(nodes_color_part1[j])
pos1 = draw_graph(G1, C1, nodes_color_part1, Gweights=p1,
pos=pos1, node_size=node_size, shiftx=0, seed=seed_G1)
pos2 = draw_graph(G2, C2, nodes_color_part2, Gweights=p2, pos=pos2,
node_size=node_size, shiftx=shiftx, seed=seed_G2)
for k1, v1 in pos1.items():
for k2, v2 in pos2.items():
if (T[k1, k2] > 0):
pl.plot([pos1[k1][0], pos2[k2][0]],
[pos1[k1][1], pos2[k2][1]],
'-', lw=0.8, alpha=0.5,
color=nodes_color_part1[k1])
return pos1, pos2


node_size = 40
fontsize = 10
seed_G2 = 0
seed_G3 = 4

pl.figure(1, figsize=(8, 2.5))
pl.clf()
pl.subplot(121)
pl.axis('off')
pl.axis
pl.title(r'srGW$(\mathbf{C_2},\mathbf{h_2},\mathbf{C_3}) =%s$' % (np.round(srgw_23, 3)), fontsize=fontsize)

hbar2 = OT_23.sum(axis=0)
pos1, pos2 = draw_transp_colored_srGW(
weightedG2, C2, weightedG3, C3, part_G2, p1=None, p2=hbar2, T=OT_23,
shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3)
pl.subplot(122)
pl.axis('off')
hbar3 = OT_32.sum(axis=0)
pl.title(r'srGW$(\mathbf{C_3}, \mathbf{h_3},\mathbf{C_2}) =%s$' % (np.round(srgw_32, 3)), fontsize=fontsize)
pos1, pos2 = draw_transp_colored_srGW(
weightedG3, C3, weightedG2, C2, part_G3, p1=None, p2=hbar3, T=OT_32,
pos1=pos2, pos2=pos1, shiftx=3., node_size=node_size, seed_G1=0, seed_G2=0)
pl.tight_layout()

pl.show()

# %%
# =============================================================================
# Add node features
# =============================================================================

# We add node features with given mean - by clusters
# and inversely proportional to clusters' intra-connectivity

F2 = np.zeros((N2, 1))
for i, c in enumerate(part_G2):
F2[i, 0] = np.random.normal(loc=c, scale=0.01)

F3 = np.zeros((N3, 1))
for i, c in enumerate(part_G3):
F3[i, 0] = np.random.normal(loc=2. - c, scale=0.01)

# %%
# =============================================================================
# Compute their semi-relaxed Fused Gromov-Wasserstein divergences
# =============================================================================

alpha = 0.5
# Compute pairwise euclidean distance between node features
M = (F2 ** 2).dot(np.ones((1, N3))) + np.ones((N2, 1)).dot((F3 ** 2).T) - 2 * F2.dot(F3.T)

# 0) FGW_alpha(C2, F2, h2, C3, F3, h3) for reference

OT, log = fused_gromov_wasserstein(
M, C2, C3, h2, h3, symmetric=True, alpha=alpha, log=True)
fgw = log['fgw_dist']

# 1) srFGW(C2, F2, h2, C3, F3)
OT_23, log_23 = semirelaxed_fused_gromov_wasserstein(
M, C2, C3, h2, symmetric=True, alpha=0.5, log=True, G0=None)
srfgw_23 = log_23['srfgw_dist']

# 2) srFGW(C3, F3, h3, C2, F2)

OT_32, log_32 = semirelaxed_fused_gromov_wasserstein(
M.T, C3, C2, h3, symmetric=None, alpha=alpha, log=True, G0=None)
srfgw_32 = log_32['srfgw_dist']

print('FGW(C2, F2, C3, F3) = ', fgw)
print('srGW(C2, F2, h2, C3, F3) = ', srfgw_23)
print('srGW(C3, F3, h3, C2, F2) = ', srfgw_32)

# %%
# =============================================================================
# Visualization of the semi-relaxed Fused Gromov-Wasserstein matchings
# =============================================================================

# We color nodes of the graph on the right - then project its node colors
# based on the optimal transport plan from the srFGW matching
# NB: colors refer to clusters - not to node features

pl.figure(2, figsize=(8, 2.5))
pl.clf()
pl.subplot(121)
pl.axis('off')
pl.axis
pl.title(r'srFGW$(\mathbf{C_2},\mathbf{F_2},\mathbf{h_2},\mathbf{C_3},\mathbf{F_3}) =%s$' % (np.round(srfgw_23, 3)), fontsize=fontsize)

hbar2 = OT_23.sum(axis=0)
pos1, pos2 = draw_transp_colored_srGW(
weightedG2, C2, weightedG3, C3, part_G2, p1=None, p2=hbar2, T=OT_23,
shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3)
pl.subplot(122)
pl.axis('off')
hbar3 = OT_32.sum(axis=0)
pl.title(r'srFGW$(\mathbf{C_3}, \mathbf{F_3}, \mathbf{h_3}, \mathbf{C_2}, \mathbf{F_2}) =%s$' % (np.round(srfgw_32, 3)), fontsize=fontsize)
pos1, pos2 = draw_transp_colored_srGW(
weightedG3, C3, weightedG2, C2, part_G3, p1=None, p2=hbar3, T=OT_32,
pos1=pos2, pos2=pos1, shiftx=3., node_size=node_size, seed_G1=0, seed_G2=0)
pl.tight_layout()

pl.show()
Loading