Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] quantized gromov wasserstein solver #603

Merged
merged 26 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
fa787c2
first commit : quantized gromov wasserstein solver
cedricvincentcuaz Feb 11, 2024
ce25f93
start setting up tests
cedricvincentcuaz Feb 11, 2024
e66dce4
fix build OT for all backends - nb: concatenation procedure is less e…
cedricvincentcuaz Feb 12, 2024
5c0e368
dealing with edge cases
cedricvincentcuaz Feb 12, 2024
5075488
fix pep8
cedricvincentcuaz Feb 12, 2024
9adfa21
Merge branch 'master' of https://github.com/cedricvincentcuaz/POT int…
cedricvincentcuaz Feb 24, 2024
8be5dfb
updates + start setting exemple
cedricvincentcuaz Feb 24, 2024
096eb1a
updates + start setting exemple
cedricvincentcuaz Feb 24, 2024
3b99d08
updating code + exemple + test + docs
cedricvincentcuaz Feb 25, 2024
6e0322a
fix sklearn imports
cedricvincentcuaz Feb 25, 2024
94e4d23
fix
cedricvincentcuaz Feb 25, 2024
bfd2cd8
Merge branch 'master' into quantized
rflamary Feb 29, 2024
331bacf
Merge branch 'master' into quantized
cedricvincentcuaz Mar 4, 2024
9c660c3
merge with master
cedricvincentcuaz Apr 24, 2024
f453a06
setting up new API for qGW
cedricvincentcuaz Apr 25, 2024
1b46bbb
fix pep8
cedricvincentcuaz Apr 25, 2024
876a7c5
tests
cedricvincentcuaz Apr 26, 2024
0bf49d8
merge with master
cedricvincentcuaz May 27, 2024
fc0bfbb
Merge branch 'master' into quantized
rflamary May 27, 2024
604acff
update qFGW plots
cedricvincentcuaz May 28, 2024
7834a44
update qFGW plots
cedricvincentcuaz May 28, 2024
8b9bd8b
up tests
cedricvincentcuaz May 28, 2024
d83c426
update example
cedricvincentcuaz May 28, 2024
4965c90
merge
cedricvincentcuaz May 29, 2024
bb1b342
merge master
cedricvincentcuaz May 29, 2024
096955a
complete tests
cedricvincentcuaz May 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ The contributors to this library are:
* [Tanguy Kerdoncuff](https://hv0nnus.github.io/) (Sampled Gromov Wasserstein)
* [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance)
* [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends)
* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning, FGW, semi-relaxed FGW)
* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning, FGW, semi-relaxed FGW, quantized FGW)
* [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein Barycenters)
* [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug)
* [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter)
Expand Down
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ POT provides the following generic OT solvers (links to examples):
* [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46]
* [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) [38].
* [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) (exact and regularized [48]).
* [Quantized Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) [66].
cedricvincentcuaz marked this conversation as resolved.
Show resolved Hide resolved
* [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://pythonot.github.io/auto_examples/others/plot_demd_gradient_minimize.html) [50].
* [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays.
* Smooth Strongly Convex Nearest Brenier Potentials [58], with an extension to bounding potentials using [59].
Expand Down Expand Up @@ -355,3 +356,5 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil
[64] Ma, X., Chu, X., Wang, Y., Lin, Y., Zhao, J., Ma, L., & Zhu, W. (2023). [Fused Gromov-Wasserstein Graph Mixup for Graph-level Classifications](https://openreview.net/pdf?id=uqkUguNu40). In Thirty-seventh Conference on Neural Information Processing Systems.

[65] Scetbon, M., Cuturi, M., & Peyré, G. (2021). [Low-Rank Sinkhorn Factorization](https://arxiv.org/pdf/2103.04737.pdf).

[66] Chowdhury, S., Miller, D., & Needham, T. (2021). [Quantized gromov-wasserstein](https://link.springer.com/chapter/10.1007/978-3-030-86523-8_49). ECML PKDD 2021. Springer International Publishing.
4 changes: 4 additions & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
#### New features
+ `ot.gromov._gw.solve_gromov_linesearch` now has an argument to specifify if the matrices are symmetric in which case the computation can be done faster.

#### New features
- Add implicit sinkhorn gradients in `ot.solve` and `ot.solve_sample` (PR #605)
cedricvincentcuaz marked this conversation as resolved.
Show resolved Hide resolved
- New quantized GW solvers `ot.gromov.quantized_gromov_wasserstein` and `ot.gromov.quantized_gromov_wasserstein_partitioned` (PR #603)

#### Closed issues
- Fixed an issue with cost correction for mismatched labels in `ot.da.BaseTransport` fit methods. This fix addresses the original issue introduced PR #587 (PR #593)
- Fix gpu compatibility of sr(F)GW solvers when `G0 is not None`(PR #596)
Expand Down
267 changes: 267 additions & 0 deletions examples/gromov/plot_quantized_gromov_wasserstein.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
# -*- coding: utf-8 -*-
"""
===============================================
Quantized Gromov-Wasserstein example
===============================================

This example is designed to show how to use the quantized Gromov-Wasserstein
solvers [66]. POT provides a wrapper `quantized_gromov_wasserstein` operating other
graphs, and a generic solver `quantized_gromov_wasserstein_partitioned` that allows
the user to precompute any partitioning and representant selection methods.

We generate two graphs following Stochastic Block Models encoded as shortest path
matrices. Then show how to compute their quantized gromov-wasserstein
matchings using both solvers.

[66] Chowdhury, S., Miller, D., & Needham, T. (2021). Quantized gromov-wasserstein.
ECML PKDD 2021. Springer International Publishing.
"""

# Author: Cédric Vincent-Cuaz <[email protected]>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 1

import numpy as np
import matplotlib.pylab as pl
import networkx
from networkx.generators.community import stochastic_block_model as sbm
from scipy.sparse.csgraph import shortest_path

from ot.gromov import quantized_gromov_wasserstein, quantized_gromov_wasserstein_partitioned
from ot.gromov._quantized import _get_partition, _get_representants, _formate_partitioned_graph

#############################################################################
#
# Generate two graphs following Stochastic Block models of 2 and 3 clusters.
# --------------------------------------------------------------------------


N2 = 30 # 2 communities
N3 = 45 # 3 communities
p2 = [[0.8, 0.1],
[0.1, 0.7]]
p3 = [[0.8, 0.1, 0.],
[0.1, 0.75, 0.1],
[0., 0.1, 0.7]]
G2 = sbm(seed=0, sizes=[N2 // 2, N2 // 2], p=p2)
G3 = sbm(seed=0, sizes=[N3 // 3, N3 // 3, N3 // 3], p=p3)


C2 = networkx.to_numpy_array(G2)
C3 = networkx.to_numpy_array(G3)

spC2 = shortest_path(C2)
spC3 = shortest_path(C3)

h2 = np.ones(C2.shape[0]) / C2.shape[0]
h3 = np.ones(C3.shape[0]) / C3.shape[0]

# Add weights on the edges for visualization later on
weight_intra_G2 = 5
weight_inter_G2 = 0.5
weight_intra_G3 = 1.
weight_inter_G3 = 1.5

weightedG2 = networkx.Graph()
part_G2 = [G2.nodes[i]['block'] for i in range(N2)]

for node in G2.nodes():
weightedG2.add_node(node)
for i, j in G2.edges():
if part_G2[i] == part_G2[j]:
weightedG2.add_edge(i, j, weight=weight_intra_G2)
else:
weightedG2.add_edge(i, j, weight=weight_inter_G2)

weightedG3 = networkx.Graph()
part_G3 = [G3.nodes[i]['block'] for i in range(N3)]

for node in G3.nodes():
weightedG3.add_node(node)
for i, j in G3.edges():
if part_G3[i] == part_G3[j]:
weightedG3.add_edge(i, j, weight=weight_intra_G3)
else:
weightedG3.add_edge(i, j, weight=weight_inter_G3)

#############################################################################
#
# Compute their quantized Gromov-Wasserstein distance using the wrapper
# ---------------------------------------------------------

# 0) qGW(spC2, h2, spC3, h3) while partitioning the adjacency matrices C2 and C3
cedricvincentcuaz marked this conversation as resolved.
Show resolved Hide resolved
# in 2 and 3 clusters respectively, using the Fluid algorithm and selecting
# representant in each partition using maximal pagerank.
# Notice that C2 and C3 are optional and if they are not specified these
# pre-processing algorithms will be applied to spC2 and spC3.

part_method = 'louvain'
rep_method = 'pagerank'
OT_global, OTs_local, OT, log = quantized_gromov_wasserstein(
spC2, spC3, 2, 3, C2, C3, h2, h3, part_method=part_method,
rep_method=rep_method, log=True)

qGW_dist = log['qGW_dist']
cedricvincentcuaz marked this conversation as resolved.
Show resolved Hide resolved

#############################################################################
#
# Compute their quantized Gromov-Wasserstein distance using any partitioning and representant selection methods
# ---------------------------------------------------------

# 1-a) Partition C2 and C3 in 2 and 3 clusters respectively using the Fluid
# algorithm implementation from networkx. Encode these partitions via vectors of assignments.

part2 = _get_partition(C2, npart=2, part_method=part_method)
cedricvincentcuaz marked this conversation as resolved.
Show resolved Hide resolved
cedricvincentcuaz marked this conversation as resolved.
Show resolved Hide resolved
part3 = _get_partition(C3, npart=3, part_method=part_method)

# 1-b) Select representant in each partition using the Pagerank algorithm
# implementation from networkx.

rep_indices2 = _get_representants(C2, part2, rep_method=rep_method)
rep_indices3 = _get_representants(C3, part3, rep_method=rep_method)

# 1-c) Formate partitions. CR (2, 2) relations between representants in each space.
# list_R contain relations between samples and representants within each partition.
# list_h contain samples relative importance within each partition.

CR2, list_R2, list_h2 = _formate_partitioned_graph(spC2, h2, part2, rep_indices2)
CR3, list_R3, list_h3 = _formate_partitioned_graph(spC3, h3, part3, rep_indices3)
cedricvincentcuaz marked this conversation as resolved.
Show resolved Hide resolved

# 1-d) call to partitioned quantized gromov-wasserstein solver

OT_global_, OTs_local_, OT_, log_ = quantized_gromov_wasserstein_partitioned(
CR2, CR3, list_R2, list_R3, list_h2, list_h3, build_OT=True, log=True)


#############################################################################
#
# Visualization of the quantized Gromov-Wasserstein matching
# --------------------------------------------------------------
#
# We color nodes of the graph based on the respective partition of each graph.
# On the first plot we illustrate the qGW matching between both shortest path matrices.
# While the GW matching across representants of each space is illustrated on the right.


def draw_graph(G, C, nodes_color_part, rep_indices, pos=None,
edge_color='black', alpha_edge=0.7, node_size=None,
shiftx=0, seed=0, highlight_rep=False):

if (pos is None):
pos = networkx.spring_layout(G, scale=1., seed=seed)

if shiftx != 0:
for k, v in pos.items():
v[0] = v[0] + shiftx

width_edge = 1.5

if not highlight_rep:
networkx.draw_networkx_edges(
G, pos, width=width_edge, alpha=alpha_edge, edge_color=edge_color)
else:
for edge in G.edges:
if (edge[0] in rep_indices) and (edge[1] in rep_indices):
networkx.draw_networkx_edges(
G, pos, edgelist=[edge], width=width_edge, alpha=alpha_edge,
edge_color=edge_color)
else:
networkx.draw_networkx_edges(
G, pos, edgelist=[edge], width=width_edge, alpha=0.2,
edge_color=edge_color)

for node, node_color in enumerate(nodes_color_part):
local_node_shape, local_node_size = 'o', node_size
if node in rep_indices:
local_node_shape, local_node_size = '*', 6 * node_size

alpha = 0.9
if highlight_rep:
alpha = 0.9 if node in rep_indices else 0.2

networkx.draw_networkx_nodes(G, pos, nodelist=[node], alpha=alpha,
node_shape=local_node_shape,
node_size=local_node_size,
node_color=node_color)

return pos


def draw_transp_colored_qGW(
G1, C1, G2, C2, part1, part2, rep_indices1, rep_indices2, T,
pos1=None, pos2=None, shiftx=4, switchx=False, node_size=70,
seed_G1=0, seed_G2=0, highlight_rep=False):
starting_color = 0
# get graphs partition and their coloring
unique_colors1 = ['C%s' % (starting_color + i) for i in np.unique(part1)]
nodes_color_part1 = []
for cluster in part1:
nodes_color_part1.append(unique_colors1[cluster])

starting_color = len(unique_colors1) + 1
unique_colors2 = ['C%s' % (starting_color + i) for i in np.unique(part2)]
nodes_color_part2 = []
for cluster in part2:
nodes_color_part2.append(unique_colors2[cluster])

pos1 = draw_graph(
G1, C1, nodes_color_part1, rep_indices1, pos=pos1, node_size=node_size,
shiftx=0, seed=seed_G1, highlight_rep=highlight_rep)
pos2 = draw_graph(
G2, C2, nodes_color_part2, rep_indices2, pos=pos2, node_size=node_size,
shiftx=shiftx, seed=seed_G2, highlight_rep=highlight_rep)

if not highlight_rep:
for k1, v1 in pos1.items():
max_Tk1 = np.max(T[k1, :])
for k2, v2 in pos2.items():
if (T[k1, k2] > 0):
pl.plot([pos1[k1][0], pos2[k2][0]],
[pos1[k1][1], pos2[k2][1]],
'-', lw=0.7, alpha=T[k1, k2] / max_Tk1,
color=nodes_color_part1[k1])

else: # OT is only between representants
for id1, node_id1 in enumerate(rep_indices1):
max_Tk1 = np.max(T[id1, :])
for id2, node_id2 in enumerate(rep_indices2):
if (T[id1, id2] > 0):
pl.plot([pos1[node_id1][0], pos2[node_id2][0]],
[pos1[node_id1][1], pos2[node_id2][1]],
'-', lw=0.8, alpha=T[id1, id2] / max_Tk1,
color=nodes_color_part1[node_id1])
return pos1, pos2


node_size = 40
fontsize = 10
seed_G2 = 0
seed_G3 = 3

part2_ = part2.astype(np.int32)
part3_ = part3.astype(np.int32)

pl.figure(1, figsize=(8, 3))
pl.clf()
pl.axis('off')
pl.subplot(1, 2, 1)
pl.title(r'qGW$(\mathbf{spC_2}, \mathbf{spC_3}) =%s$' % (np.round(qGW_dist, 3)), fontsize=fontsize)

pos1, pos2 = draw_transp_colored_qGW(
weightedG2, C2, weightedG3, C3, part2_, part3_, rep_indices2, rep_indices3,
T=OT_, shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3)

pl.tight_layout()

pl.subplot(1, 2, 2)
pl.title(r' GW$(\mathbf{CR_2}, \mathbf{CR_3}) =%s$' % (np.round(log_['gw_dist_CR'], 3)), fontsize=fontsize)

pos1, pos2 = draw_transp_colored_qGW(
weightedG2, C2, weightedG3, C3, part2_, part3_, rep_indices2, rep_indices3,
T=OT_global, shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3,
highlight_rep=True)

pl.tight_layout()
pl.show()
5 changes: 4 additions & 1 deletion ot/gromov/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
fused_gromov_wasserstein_dictionary_learning,
fused_gromov_wasserstein_linear_unmixing)

from ._quantized import (quantized_gromov_wasserstein,
quantized_gromov_wasserstein_partitioned)

__all__ = ['init_matrix', 'tensor_product', 'gwloss', 'gwggrad', 'update_square_loss',
'update_kl_loss', 'update_feature_matrix', 'init_matrix_semirelaxed',
Expand All @@ -64,4 +66,5 @@
'entropic_semirelaxed_gromov_wasserstein2', 'entropic_semirelaxed_fused_gromov_wasserstein',
'entropic_semirelaxed_fused_gromov_wasserstein2', 'gromov_wasserstein_dictionary_learning',
'gromov_wasserstein_linear_unmixing', 'fused_gromov_wasserstein_dictionary_learning',
'fused_gromov_wasserstein_linear_unmixing']
'fused_gromov_wasserstein_linear_unmixing',
'quantized_gromov_wasserstein', 'quantized_gromov_wasserstein_partitioned']
Loading
Loading