From 791137b3f49b1aacadb4062cef97963915a5fa32 Mon Sep 17 00:00:00 2001 From: Huy Tran Date: Mon, 14 Oct 2024 19:29:05 +0200 Subject: [PATCH] [WIP] Implementation of FUGW and UCOOT (#677) * implementation of FUGW and UCOOT * fix pep8 error * fix test_utils error * remove print * add documentation and fix bug * first code review * fix documentation --- README.md | 17 +- RELEASES.md | 25 +- ...etection_with_COOT_and_unbalanced_COOT.py} | 72 +- ot/gromov/__init__.py | 16 +- ot/gromov/_unbalanced.py | 1080 +++++++++++++++++ ot/gromov/_utils.py | 356 ++++++ ot/solvers.py | 129 +- ot/unbalanced/_lbfgs.py | 2 +- ot/unbalanced/_sinkhorn.py | 2 +- ot/utils.py | 52 +- test/gromov/test_fugw.py | 685 +++++++++++ test/gromov/test_utils.py | 37 + test/test_solvers.py | 21 +- test/test_ucoot.py | 795 ++++++++++++ test/test_utils.py | 9 +- 15 files changed, 3180 insertions(+), 118 deletions(-) rename examples/others/{plot_learning_weights_with_COOT.py => plot_outlier_detection_with_COOT_and_unbalanced_COOT.py} (59%) create mode 100644 ot/gromov/_unbalanced.py create mode 100644 test/gromov/test_fugw.py create mode 100644 test/test_ucoot.py diff --git a/README.md b/README.md index fbad3086e..fa9f5789e 100644 --- a/README.md +++ b/README.md @@ -52,6 +52,9 @@ POT provides the following generic OT solvers (links to examples): * [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays. * [Smooth Strongly Convex Nearest Brenier Potentials](https://pythonot.github.io/auto_examples/others/plot_SSNB.html#sphx-glr-auto-examples-others-plot-ssnb-py) [58], with an extension to bounding potentials using [59]. * Gaussian Mixture Model OT [69] +* [Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_COOT.html) [49] and +[unbalanced Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_learning_weights_with_COOT.html) [71]. +* Fused unbalanced Gromov-Wasserstein [70]. POT provides the following Machine Learning related solvers: @@ -62,7 +65,7 @@ POT provides the following Machine Learning related solvers: * [Linear OT mapping](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_linear_mapping.html) [14] and [Joint OT mapping estimation](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_mapping.html) [8]. * [Wasserstein Discriminant Analysis](https://pythonot.github.io/auto_examples/others/plot_WDA.html) [11] (requires autograd + pymanopt). * [JCPOT algorithm for multi-source domain adaptation with target shift](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_jcpot.html) [27]. -* [Graph Neural Network OT layers TFGW](https://pythonot.github.io/auto_examples/gromov/plot_gnn_TFGW.html) [52] and TW (OT-GNN) [53] +* [Graph Neural Network OT layers TFGW](https://pythonot.github.io/auto_examples/gromov/plot_gnn_TFGW.html) [52] and TW (OT-GNN) [53] Some other examples are available in the [documentation](https://pythonot.github.io/auto_examples/index.html). @@ -198,7 +201,7 @@ This toolbox has been created by * [Rémi Flamary](https://remi.flamary.com/) * [Nicolas Courty](http://people.irisa.fr/Nicolas.Courty/) -It is currently maintained by +It is currently maintained by * [Rémi Flamary](https://remi.flamary.com/) * [Cédric Vincent-Cuaz](https://cedricvincentcuaz.github.io/) @@ -370,4 +373,12 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil [68] Chowdhury, S., Miller, D., & Needham, T. (2021). [Quantized gromov-wasserstein](https://link.springer.com/chapter/10.1007/978-3-030-86523-8_49). ECML PKDD 2021. Springer International Publishing. -[69] Delon, J., & Desolneux, A. (2020). [A Wasserstein-type distance in the space of Gaussian mixture models](https://epubs.siam.org/doi/abs/10.1137/19M1301047). SIAM Journal on Imaging Sciences, 13(2), 936-970. \ No newline at end of file +[69] Delon, J., & Desolneux, A. (2020). [A Wasserstein-type distance in the space of Gaussian mixture models](https://epubs.siam.org/doi/abs/10.1137/19M1301047). SIAM Journal on Imaging Sciences, 13(2), 936-970. + +[70] A. Thual, H. Tran, T. Zemskova, N. Courty, R. Flamary, S. Dehaene +& B. Thirion (2022). [Aligning individual brains with Fused Unbalanced Gromov-Wasserstein.](https://proceedings.neurips.cc/paper_files/paper/2022/file/8906cac4ca58dcaf17e97a0486ad57ca-Paper-Conference.pdf). Neural Information Processing Systems (NeurIPS). + +[71] H. Tran, H. Janati, N. Courty, R. Flamary, I. Redko, P. Demetci & R. Singh (2023). [Unbalanced Co-Optimal Transport](https://dl.acm.org/doi/10.1609/aaai.v37i8.26193). AAAI Conference on +Artificial Intelligence. + +[72] Thibault Séjourné, François-Xavier Vialard, and Gabriel Peyré (2021). [The Unbalanced Gromov Wasserstein Distance: Conic Formulation and Relaxation](https://proceedings.neurips.cc/paper/2021/file/4990974d150d0de5e6e15a1454fe6b0f-Paper.pdf). Neural Information Processing Systems (NeurIPS). \ No newline at end of file diff --git a/RELEASES.md b/RELEASES.md index 277af7847..0b9a24452 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -3,15 +3,16 @@ ## 0.9.5dev #### New features -- Add feature `mass=True` for `nx.kl_div` (PR #654) -- Gaussian Mixture Model OT `ot.gmm` (PR #649) -- Add feature `semirelaxed_fgw_barycenters` and generic FGW-related barycenter updates `update_barycenter_structure` and `update_barycenter_feature` (PR #659) -- Add initialization heuristics for sr(F)GW problems via `semirelaxed_init_plan`, integrated in all sr(F)GW solvers (PR #659) +- Added feature `mass=True` for `nx.kl_div` (PR #654) +- Implemented Gaussian Mixture Model OT `ot.gmm` (PR #649) +- Added feature `semirelaxed_fgw_barycenters` and generic FGW-related barycenter updates `update_barycenter_structure` and `update_barycenter_feature` (PR #659) +- Added initialization heuristics for sr(F)GW problems via `semirelaxed_init_plan`, integrated in all sr(F)GW solvers (PR #659) - Improved `ot.plot.plot1D_mat` (PR #649) - Added `nx.det` (PR #649) - `nx.sqrtm` is now broadcastable (takes ..., d, d) inputs (PR #649) -- restructure `ot.unbalanced` module (PR #658) -- add `ot.unbalanced.lbfgsb_unbalanced2` and add flexible reference measure `c` in all unbalanced solvers (PR #658) +- Restructured `ot.unbalanced` module (PR #658) +- Added `ot.unbalanced.lbfgsb_unbalanced2` and add flexible reference measure `c` in all unbalanced solvers (PR #658) +- Implemented Fused unbalanced Gromov-Wasserstein and unbalanced Co-Optimal Transport (PR #677) #### Closed issues - Fixed `ot.gaussian` ignoring weights when computing means (PR #649, Issue #648) @@ -72,7 +73,7 @@ xs, xt = np.random.randn(100, 2), np.random.randn(50, 2) # Solve OT problem with empirical samples sol = ot.solve_sample(xs, xt) # Exact OT betwen smaples with uniform weights -sol = ot.solve_sample(xs, xt, wa, wb) # Exact OT with weights given by user +sol = ot.solve_sample(xs, xt, wa, wb) # Exact OT with weights given by user sol = ot.solve_sample(xs, xt, reg= 1, metric='euclidean') # sinkhorn with euclidean metric @@ -84,7 +85,7 @@ sol = ot.solve_sample(x,x2, method='lowrank', rank=10) # compute lowrank sinkhor value_bw = ot.solve_sample(xs, xt, method='gaussian').value # Bures-Wasserstein distance -# Solve GW problem +# Solve GW problem Cs, Ct = ot.dist(xs, xs), ot.dist(xt, xt) # compute cost matrices sol = ot.solve_gromov(Cs,Ct) # Exact GW between samples with uniform weights @@ -92,7 +93,7 @@ sol = ot.solve_gromov(Cs,Ct) # Exact GW between samples with uniform weights M = ot.dist(xs, xt) # compute cost matrix # Exact FGW between samples with uniform weights -sol = ot.solve_gromov(Cs, Ct, M, loss='KL', alpha=0.7) # FGW with KL data fitting +sol = ot.solve_gromov(Cs, Ct, M, loss='KL', alpha=0.7) # FGW with KL data fitting # recover solutions objects @@ -102,14 +103,14 @@ value = sol.value # OT value # for GW and FGW value_linear = sol.value_linear # linear part of the loss -value_quad = sol.value_quad # quadratic part of the loss +value_quad = sol.value_quad # quadratic part of the loss ``` Users are encouraged to use the new API (it is much simpler) but it might still be subjects to small changes before the release of POT 1.0 . -We also fixed a number of issues, the most pressing being a problem of GPU memory allocation when pytorch is installed that will not happen now thanks to Lazy initialization of the backends. We now also have the possibility to deactivate some backends using environment which prevents POT from importing them and can lead to large import speedup. +We also fixed a number of issues, the most pressing being a problem of GPU memory allocation when pytorch is installed that will not happen now thanks to Lazy initialization of the backends. We now also have the possibility to deactivate some backends using environment which prevents POT from importing them and can lead to large import speedup. #### New features @@ -143,7 +144,7 @@ We also fixed a number of issues, the most pressing being a problem of GPU memor - Correct independence of `fgw_barycenters` to `init_C` and `init_X` (Issue #547, PR #566) - Avoid precision change when computing norm using PyTorch backend (Discussion #570, PR #572) - Create `ot/bregman/`repository (Issue #567, PR #569) -- Fix matrix feature shape in `entropic_fused_gromov_barycenters`(Issue #574, PR #573) +- Fix matrix feature shape in `entropic_fused_gromov_barycenters`(Issue #574, PR #573) - Fix (fused) gromov-wasserstein barycenter solvers to support `kl_loss`(PR #576) diff --git a/examples/others/plot_learning_weights_with_COOT.py b/examples/others/plot_outlier_detection_with_COOT_and_unbalanced_COOT.py similarity index 59% rename from examples/others/plot_learning_weights_with_COOT.py rename to examples/others/plot_outlier_detection_with_COOT_and_unbalanced_COOT.py index cb115c306..e1f48f724 100644 --- a/examples/others/plot_learning_weights_with_COOT.py +++ b/examples/others/plot_outlier_detection_with_COOT_and_unbalanced_COOT.py @@ -1,11 +1,16 @@ # -*- coding: utf-8 -*- r""" -=============================================================== -Learning sample marginal distribution with CO-Optimal Transport -=============================================================== +====================================================================================================================================== +Detecting outliers by learning sample marginal distribution with CO-Optimal Transport and by using unbalanced Co-Optimal Transport +====================================================================================================================================== -In this example, we illustrate how to estimate the sample marginal distribution which minimizes -the CO-Optimal Transport distance [47]_ between two matrices. More precisely, given a source data +In this example, we consider two point clouds living in different Euclidean spaces, where the outliers +are artifically injected into the target data. We illustrate two methods which allow to filter out +these outliers. + +The first method requires learning the sample marginal distribution which minimizes +the CO-Optimal Transport distance [49] between two input spaces. +More precisely, given a source data :math:`(X, \mu_x^{(s)}, \mu_x^{(f)})` and a target matrix :math:`Y` associated with a fixed histogram on features :math:`\mu_y^{(f)}`, we want to solve the following problem @@ -17,9 +22,19 @@ allows us to compute the CO-Optimal Transport distance with :func:`ot.coot.co_optimal_transport2` with differentiable losses. +The second method simply requires direct application of unbalanced Co-Optimal Transport [71]. +More precisely, it is enough to use the sample and feature coupling from solving + +.. math:: + \text{UCOOT}\left( (X, \mu_x^{(s)}, \mu_x^{(f)}), (Y, \mu_y^{(s)}, \mu_y^{(f)}) \right) + +where all the marginal distributions are uniform. + .. [49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020). `CO-Optimal Transport `_. Advances in Neural Information Processing Systems, 33. +.. [71] H. Tran, H. Janati, N. Courty, R. Flamary, I. Redko, P. Demetci & R. Singh (2023). [Unbalanced Co-Optimal Transport](https://dl.acm.org/doi/10.1609/aaai.v37i8.26193). + AAAI Conference on Artificial Intelligence. """ # Author: Remi Flamary @@ -35,6 +50,7 @@ from ot.coot import co_optimal_transport as coot from ot.coot import co_optimal_transport2 as coot2 +from ot.gromov._unbalanced import unbalanced_co_optimal_transport # %% @@ -148,3 +164,49 @@ con = ConnectionPatch( xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData, color="blue") fig.add_artist(con) + +# %% +# Now, let see if we can use unbalanced Co-Optimal Transport to recover the clean OT plans, +# without the need of learning the marginal distribution as in Co-Optimal Transport. +# ----------------------------------------------------------------------------------------- + +pi_sample, pi_feature = unbalanced_co_optimal_transport( + X=X, Y=Y_noisy, reg_marginals=(10, 10), epsilon=0, divergence="kl", + unbalanced_solver="mm", max_iter=1000, tol=1e-6, + max_iter_ot=1000, tol_ot=1e-6, log=False, verbose=False +) + +# %% +# Visualizing the row and column alignments learned by unbalanced Co-Optimal Transport. +# ----------------------------------------------------------------------------------------- +# +# Similar to Co-Optimal Transport, we are also be able to fully recover the clean OT plans. + +fig = pl.figure(4, (9, 7)) +pl.clf() + +ax1 = pl.subplot(2, 2, 3) +pl.imshow(X, vmin=-2, vmax=2) +pl.xlabel('$X$') + +ax2 = pl.subplot(2, 2, 2) +ax2.yaxis.tick_right() +pl.imshow(np.transpose(Y_noisy), vmin=-2, vmax=2) +pl.title("Transpose(Noisy $Y$)") +ax2.xaxis.tick_top() + +for i in range(n1): + j = np.argmax(pi_sample[i, :]) + xyA = (d1 - .5, i) + xyB = (j, d2 - .5) + con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA=ax1.transData, + coordsB=ax2.transData, color="black") + fig.add_artist(con) + +for i in range(d1): + j = np.argmax(pi_feature[i, :]) + xyA = (i, -.5) + xyB = (-.5, j) + con = ConnectionPatch( + xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData, color="blue") + fig.add_artist(con) diff --git a/ot/gromov/__init__.py b/ot/gromov/__init__.py index 5cf19784b..6d3f56d8b 100644 --- a/ot/gromov/__init__.py +++ b/ot/gromov/__init__.py @@ -6,13 +6,16 @@ # Author: Remi Flamary # Cedric Vincent-Cuaz +# Quang Huy Tran # # License: MIT License # All submodules and packages from ._utils import (init_matrix, tensor_product, gwloss, gwggrad, init_matrix_semirelaxed, semirelaxed_init_plan, - update_barycenter_structure, update_barycenter_feature) + update_barycenter_structure, update_barycenter_feature, + div_between_product, div_to_product, fused_unbalanced_across_spaces_cost, + uot_cost_matrix, uot_parameters_and_measures) from ._gw import (gromov_wasserstein, gromov_wasserstein2, fused_gromov_wasserstein, fused_gromov_wasserstein2, @@ -63,9 +66,17 @@ quantized_fused_gromov_wasserstein_samples ) +from ._unbalanced import (fused_unbalanced_gromov_wasserstein, + fused_unbalanced_gromov_wasserstein2, + unbalanced_co_optimal_transport, + unbalanced_co_optimal_transport2, + fused_unbalanced_across_spaces_divergence) + __all__ = ['init_matrix', 'tensor_product', 'gwloss', 'gwggrad', 'init_matrix_semirelaxed', 'semirelaxed_init_plan', 'update_barycenter_structure', 'update_barycenter_feature', + 'div_between_product', 'div_to_product', 'fused_unbalanced_across_spaces_cost', + 'uot_cost_matrix', 'uot_parameters_and_measures', 'gromov_wasserstein', 'gromov_wasserstein2', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2', 'solve_gromov_linesearch', 'gromov_barycenters', 'fgw_barycenters', 'entropic_gromov_wasserstein', 'entropic_gromov_wasserstein2', @@ -87,4 +98,7 @@ 'get_graph_representants', 'format_partitioned_graph', 'quantized_fused_gromov_wasserstein', 'get_partition_and_representants_samples', 'format_partitioned_samples', 'quantized_fused_gromov_wasserstein_samples', + 'fused_unbalanced_gromov_wasserstein', 'fused_unbalanced_gromov_wasserstein2', + 'unbalanced_co_optimal_transport', 'unbalanced_co_optimal_transport2', + 'fused_unbalanced_across_spaces_divergence' ] diff --git a/ot/gromov/_unbalanced.py b/ot/gromov/_unbalanced.py new file mode 100644 index 000000000..cc7b9e53c --- /dev/null +++ b/ot/gromov/_unbalanced.py @@ -0,0 +1,1080 @@ +# -*- coding: utf-8 -*- +""" +Unbalanced Co-Optimal Transport and Fused Unbalanced Gromov-Wasserstein solvers +""" + +# Author: Quang Huy Tran +# Alexis Thual +# +# License: MIT License + +import warnings +from functools import partial +import ot +from ot.backend import get_backend +from ot.utils import list_to_array, get_parameter_pair +from ._utils import fused_unbalanced_across_spaces_cost, uot_cost_matrix, uot_parameters_and_measures + + +def fused_unbalanced_across_spaces_divergence( + X, Y, wx_samp=None, wx_feat=None, wy_samp=None, wy_feat=None, + reg_marginals=10, epsilon=0, reg_type="joint", divergence="kl", + unbalanced_solver="sinkhorn", alpha=0, M_samp=None, M_feat=None, + rescale_plan=True, init_pi=None, init_duals=None, max_iter=100, + tol=1e-7, max_iter_ot=500, tol_ot=1e-7, log=False, verbose=False, + **kwargs_solver): + + r"""Compute the fused unbalanced cross-spaces divergence between two matrices equipped + with the distributions on rows and columns. We consider two cases of matrix: + + - (Squared) similarity matrix in Gromov-Wasserstein setting, + whose rows and columns represent the samples. + + - Arbitrary-size matrix in Co-Optimal Transport setting, + whose rows represent samples, and columns represent corresponding features/dimensions. + + More precisely, this function returns the sample and feature transport plans between + :math:`(\mathbf{X}, \mathbf{w}_{xs}, \mathbf{w}_{xf})` and + :math:`(\mathbf{Y}, \mathbf{w}_{ys}, \mathbf{w}_{yf})`, + by solving the following problem using Block Coordinate Descent algorithm: + + .. math:: + + \mathop{\arg \min}_{\mathbf{P}, \mathbf{Q}} + &\quad \sum_{i,j,k,l} + (\mathbf{X}_{i,k} - \mathbf{Y}_{j,l})^2 \mathbf{P}_{i,j} \mathbf{Q}_{k,l} \\ + &+ \rho_s \mathbf{Div}(\mathbf{P}_{\# 1} \mathbf{Q}_{\# 1}^T | \mathbf{w}_{xs} \mathbf{w}_{ys}^T) + + \rho_f \mathbf{Div}(\mathbf{P}_{\# 2} \mathbf{Q}_{\# 2}^T | \mathbf{w}_{xf} \mathbf{w}_{yf}^T) \\ + &+ \alpha_s \sum_{i,j} \mathbf{P}_{i,j} \mathbf{M^{(s)}}_{i, j} + + \alpha_f \sum_{k, l} \mathbf{Q}_{k,l} \mathbf{M^{(f)}}_{k, l} + + \mathbf{Reg}(\mathbf{P}, \mathbf{Q}) + + Where: + + - :math:`\mathbf{X}`: Source input (arbitrary-size) matrix + - :math:`\mathbf{Y}`: Target input (arbitrary-size) matrix + - :math:`\mathbf{M^{(s)}}`: Additional sample matrix + - :math:`\mathbf{M^{(f)}}`: Additional feature matrix + - :math:`\mathbf{w}_{xs}`: Distribution of the samples in the source space + - :math:`\mathbf{w}_{xf}`: Distribution of the features in the source space + - :math:`\mathbf{w}_{ys}`: Distribution of the samples in the target space + - :math:`\mathbf{w}_{yf}`: Distribution of the features in the target space + - :math:`\mathbf{Div}`: Either Kullback-Leibler divergence or half-squared L2 norm. + - :math:`\mathbf{Reg}`: Regularizer for sample and feature couplings. + + We consider two types of regularizer: + + Independent regularization used in unbalanced Co-Optimal Transport + + .. math:: + \mathbf{Reg}(\mathbf{P}, \mathbf{Q}) = + \varepsilon_s \mathbf{Div}(\mathbf{P} | \mathbf{w}_{xs} \mathbf{w}_{ys}^T) + + \varepsilon_f \mathbf{Div}(\mathbf{Q} | \mathbf{w}_{xf} \mathbf{w}_{yf}^T) + + + Joint regularization used in fused unbalanced Gromov-Wasserstein + + .. math:: + \mathbf{Reg}(\mathbf{P}, \mathbf{Q}) = + \varepsilon \mathbf{Div}(\mathbf{P} \otimes \mathbf{Q} | (\mathbf{w}_{xs} \mathbf{w}_{ys}^T) \otimes (\mathbf{w}_{xf} \mathbf{w}_{yf}^T) ) + + .. note:: This function allows epsilon to be zero. In that case, `unbalanced_method` must be either "mm" or "lbfgsb". + + Parameters + ---------- + X : (n_sample_x, n_feature_x) array-like, float + Source input matrix. + Y : (n_sample_y, n_feature_y) array-like, float + Target input matrix. + wx_samp : (n_sample_x, ) array-like, float, optional (default = None) + Histogram assigned on rows (samples) of matrix X. + Uniform distribution by default. + wx_feat : (n_feature_x, ) array-like, float, optional (default = None) + Histogram assigned on columns (features) of matrix X. + Uniform distribution by default. + wy_samp : (n_sample_y, ) array-like, float, optional (default = None) + Histogram assigned on rows (samples) of matrix Y. + Uniform distribution by default. + wy_feat : (n_feature_y, ) array-like, float, optional (default = None) + Histogram assigned on columns (features) of matrix Y. + Uniform distribution by default. + reg_marginals: float or indexable object of length 1 or 2 + Marginal relaxation terms for sample and feature couplings. + If `reg_marginals` is a scalar or an indexable object of length 1, + then the same value is applied to both marginal relaxations. + epsilon : scalar or indexable object of length 2, float or int, optional (default = 0) + Regularization parameters for entropic approximation of sample and feature couplings. + Allow the case where `epsilon` contains 0. In that case, the MM solver is used by default + instead of Sinkhorn solver. If `epsilon` is scalar, then the same value is applied to + both regularization of sample and feature couplings. + reg_type: string, optional + + - If `reg_type` = "joint": then use joint regularization for couplings. + + - If `reg_type` = "indepedent": then use independent regularization for couplings. + divergence : string, optional (default = "kl") + + - If `divergence` = "kl", then Div is the Kullback-Leibler divergence. + + - If `divergence` = "l2", then Div is the half squared Euclidean norm. + unbalanced_solver : string, optional (default = "sinkhorn") + Solver for the unbalanced OT subroutine. + + - If `divergence` = "kl", then `unbalanced_solver` can be: "sinkhorn", "sinkhorn_log", "mm", "lbfgsb" + + - If `divergence` = "l2", then `unbalanced_solver` can be "mm", "lbfgsb" + alpha : scalar or indexable object of length 2, float or int, optional (default = 0) + Coeffficient parameter of linear terms with respect to the sample and feature couplings. + If alpha is scalar, then the same alpha is applied to both linear terms. + M_samp : (n_sample_x, n_sample_y), float, optional (default = None) + Sample matrix associated to the Wasserstein linear term on sample coupling. + M_feat : (n_feature_x, n_feature_y), float, optional (default = None) + Feature matrix associated to the Wasserstein linear term on feature coupling. + rescale_plan : boolean, optional (default = True) + If True, then rescale the sample and feature transport plans within each BCD iteration, + so that they always have equal mass. + init_pi : tuple of two matrices of size (n_sample_x, n_sample_y) and + (n_feature_x, n_feature_y), optional (default = None). + Initialization of sample and feature couplings. + Uniform distributions by default. + init_duals : tuple of two tuples ((n_sample_x, ), (n_sample_y, )) and ((n_feature_x, ), (n_feature_y, )), optional (default = None). + Initialization of sample and feature dual vectors + if using Sinkhorn algorithm. Zero vectors by default. + max_iter : int, optional (default = 100) + Number of Block Coordinate Descent (BCD) iterations. + tol : float, optional (default = 1e-7) + Tolerance of BCD scheme. If the L1-norm between the current and previous + sample couplings is under this threshold, then stop BCD scheme. + max_iter_ot : int, optional (default = 100) + Number of iterations to solve each of the + two unbalanced optimal transport problems in each BCD iteration. + tol_ot : float, optional (default = 1e-7) + Tolerance of unbalanced solver for each of the + two unbalanced optimal transport problems in each BCD iteration. + log : bool, optional (default = False) + If True then the cost and four dual vectors, including + two from sample and two from feature couplings, are recorded. + verbose : bool, optional (default = False) + If True then print the COOT cost at every multiplier of `eval_bcd`-th iteration. + + Returns + ------- + pi_samp : (n_sample_x, n_sample_y) array-like, float + Sample coupling matrix. + pi_feat : (n_feature_x, n_feature_y) array-like, float + Feature coupling matrix. + log : dictionary, optional + Returned if `log` is True. The keys are: + + error : array-like, float + list of L1 norms between the current and previous sample coupling. + duals_sample : (n_sample_x, n_sample_y) tuple, float + Pair of dual vectors when solving OT problem w.r.t the sample coupling. + duals_feature : (n_feature_x, n_feature_y) tuple, float + Pair of dual vectors when solving OT problem w.r.t the feature coupling. + linear : float + Linear part of the cost. + ucoot : float + Total cost. + backend + The proper backend for all input arrays + """ + + # MAIN FUNCTION + + if reg_type not in ["joint", "independent"]: + raise (NotImplementedError('Unknown reg_type="{}"'.format(reg_type))) + if divergence not in ["kl", "l2"]: + raise (NotImplementedError('Unknown divergence="{}"'.format(divergence))) + if unbalanced_solver not in ["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"]: + raise (NotImplementedError('Unknown method="{}"'.format(unbalanced_solver))) + + # hyperparameters + alpha_samp, alpha_feat = get_parameter_pair(alpha) + rho_x, rho_y = get_parameter_pair(reg_marginals) + eps_samp, eps_feat = get_parameter_pair(epsilon) + + if reg_type == "joint": # same regularization + eps_feat = eps_samp + if unbalanced_solver in ["sinkhorn", "sinkhorn_log"] and divergence == "l2": + warnings.warn("Sinkhorn algorithm does not support L2 norm. \ + Divergence is set to 'kl'.") + divergence = "kl" + if unbalanced_solver in ["sinkhorn", "sinkhorn_log"] and (eps_samp == 0 or eps_feat == 0): + warnings.warn("Sinkhorn algorithm does not support unregularized problem. \ + Solver is set to 'mm'.") + unbalanced_solver = "mm" + + if init_pi is None: + pi_samp, pi_feat = None, None + else: + pi_samp, pi_feat = init_pi + + if init_duals is None: + init_duals = (None, None) + duals_samp, duals_feat = init_duals + + arr = [X, Y] + + for tuple in [duals_samp, duals_feat]: + if tuple is not None: + d1, d2 = duals_feat + if d1 is not None: + arr.append(list_to_array(d1)) + if d2 is not None: + arr.append(list_to_array(d2)) + + nx = get_backend(*arr, wx_samp, wx_feat, wy_samp, wy_feat, M_samp, M_feat, pi_samp, pi_feat) + + # constant input variables + if M_samp is None: + if alpha_samp > 0: + warnings.warn("M_samp is None but alpha_samp = {} > 0. \ + The algo will treat as if alpha_samp = 0.".format(alpha_samp)) + else: + M_samp = alpha_samp * M_samp + + if M_feat is None: + if alpha_feat > 0: + warnings.warn("M_feat is None but alpha_feat = {} > 0. \ + The algo will treat as if alpha_feat = 0.".format(alpha_feat)) + else: + M_feat = alpha_feat * M_feat + + nx_samp, nx_feat = X.shape + ny_samp, ny_feat = Y.shape + + # measures on rows and columns + if wx_samp is None: + wx_samp = nx.ones(nx_samp, type_as=X) / nx_samp + if wx_feat is None: + wx_feat = nx.ones(nx_feat, type_as=X) / nx_feat + if wy_samp is None: + wy_samp = nx.ones(ny_samp, type_as=Y) / ny_samp + if wy_feat is None: + wy_feat = nx.ones(ny_feat, type_as=Y) / ny_feat + wxy_samp = wx_samp[:, None] * wy_samp[None, :] + wxy_feat = wx_feat[:, None] * wy_feat[None, :] + + # initialize coupling and dual vectors + pi_samp = wxy_samp if pi_samp is None else pi_samp + pi_feat = wxy_feat if pi_feat is None else pi_feat + + if unbalanced_solver in ["sinkhorn", "sinkhorn_log"]: + if duals_samp is None: + duals_samp = (nx.zeros(nx_samp, type_as=X), + nx.zeros(ny_samp, type_as=Y)) + if duals_feat is None: + duals_feat = (nx.zeros(nx_feat, type_as=X), + nx.zeros(ny_feat, type_as=Y)) + + # shortcut functions + X_sqr, Y_sqr = X**2, Y**2 + local_cost_samp = partial(uot_cost_matrix, + data=(X_sqr, Y_sqr, X, Y, M_samp), + tuple_p=(wx_feat, wy_feat), + hyperparams=(rho_x, rho_y, eps_feat), + divergence=divergence, + reg_type=reg_type, + nx=nx) + + local_cost_feat = partial(uot_cost_matrix, + data=(X_sqr.T, Y_sqr.T, X.T, Y.T, M_feat), + tuple_p=(wx_samp, wy_samp), + hyperparams=(rho_x, rho_y, eps_samp), + divergence=divergence, + reg_type=reg_type, + nx=nx) + + parameters_uot_l2_samp = partial( + uot_parameters_and_measures, + tuple_weights=(wx_samp, wy_samp, wxy_samp), + hyperparams=(rho_x, rho_y, eps_samp), + reg_type=reg_type, + divergence=divergence, + nx=nx + ) + + parameters_uot_l2_feat = partial( + uot_parameters_and_measures, + tuple_weights=(wx_feat, wy_feat, wxy_feat), + hyperparams=(rho_x, rho_y, eps_feat), + reg_type=reg_type, + divergence=divergence, + nx=nx + ) + + solver = partial( + ot.solve, + reg_type=divergence, + unbalanced_type=divergence, + method=unbalanced_solver, + max_iter=max_iter_ot, + tol=tol_ot, + verbose=False + ) + + # initialize log + if log: + dict_log = {"backend": nx, + "error": []} + + for idx in range(max_iter): + pi_samp_prev = nx.copy(pi_samp) + + # Update feature coupling + mass = nx.sum(pi_samp) + uot_cost = local_cost_feat(pi=pi_samp) + + if divergence == "kl": + new_rho = (rho_x * mass, rho_y * mass) + new_eps = mass * eps_feat if reg_type == "joint" else eps_feat + new_wx, new_wy, new_wxy = wx_feat, wy_feat, wxy_feat + else: # divergence == "l2" + new_w, new_rho, new_eps = parameters_uot_l2_feat(pi_feat) + new_wx, new_wy, new_wxy = new_w + + res = solver(M=uot_cost, a=new_wx, b=new_wy, + reg=new_eps, c=new_wxy, unbalanced=new_rho, + plan_init=pi_feat, potentials_init=duals_feat) + pi_feat, duals_feat = res.plan, res.potentials + + if rescale_plan: + pi_feat = nx.sqrt(mass / nx.sum(pi_feat)) * pi_feat + + # Update sample coupling + mass = nx.sum(pi_feat) + uot_cost = local_cost_samp(pi=pi_feat) + + if divergence == "kl": + new_rho = (rho_x * mass, rho_y * mass) + new_eps = mass * eps_feat if reg_type == "joint" else eps_feat + new_wx, new_wy, new_wxy = wx_samp, wy_samp, wxy_samp + else: # divergence == "l2" + new_w, new_rho, new_eps = parameters_uot_l2_samp(pi_samp) + new_wx, new_wy, new_wxy = new_w + + res = solver(M=uot_cost, a=new_wx, b=new_wy, + reg=new_eps, c=new_wxy, unbalanced=new_rho, + plan_init=pi_samp, potentials_init=duals_samp) + pi_samp, duals_samp = res.plan, res.potentials + + if rescale_plan: + pi_samp = nx.sqrt(mass / nx.sum(pi_samp)) * pi_samp # shape nx x ny + + # get L1 error + err = nx.sum(nx.abs(pi_samp - pi_samp_prev)) + if log: + dict_log["error"].append(err) + if verbose: + print('{:5d}|{:8e}|'.format(idx + 1, err)) + if err < tol: + break + + # sanity check + if nx.sum(nx.isnan(pi_samp)) > 0 or nx.sum(nx.isnan(pi_feat)) > 0: + raise (ValueError("There is NaN in coupling. \ + Adjust the relaxation or regularization parameters.")) + + if log: + linear_cost, ucoot_cost = fused_unbalanced_across_spaces_cost( + M_linear=(M_samp, M_feat), + data=(X_sqr, Y_sqr, X, Y), + tuple_pxy_samp=(wx_samp, wy_samp, wxy_samp), + tuple_pxy_feat=(wx_feat, wy_feat, wxy_feat), + pi_samp=pi_samp, pi_feat=pi_feat, + hyperparams=(rho_x, rho_y, eps_samp, eps_feat), + divergence=divergence, + reg_type=reg_type, + nx=nx + ) + + dict_log["duals_sample"] = duals_samp + dict_log["duals_feature"] = duals_feat + dict_log["linear_cost"] = linear_cost + dict_log["ucoot_cost"] = ucoot_cost + + return pi_samp, pi_feat, dict_log + + else: + return pi_samp, pi_feat + + +def unbalanced_co_optimal_transport( + X, Y, wx_samp=None, wx_feat=None, wy_samp=None, wy_feat=None, + reg_marginals=10, epsilon=0, divergence="kl", + unbalanced_solver="mm", alpha=0, M_samp=None, M_feat=None, + rescale_plan=True, init_pi=None, init_duals=None, + max_iter=100, tol=1e-7, max_iter_ot=500, tol_ot=1e-7, + log=False, verbose=False, **kwargs_solve): + + r"""Compute the unbalanced Co-Optimal Transport between two Euclidean point clouds + (represented as matrices whose rows are samples and columns are the features/dimensions). + + More precisely, this function returns the sample and feature transport plans between + :math:`(\mathbf{X}, \mathbf{w}_{xs}, \mathbf{w}_{xf})` and + :math:`(\mathbf{Y}, \mathbf{w}_{ys}, \mathbf{w}_{yf})`, + by solving the following problem using Block Coordinate Descent algorithm: + + .. math:: + \mathop{\arg \min}_{\mathbf{P}, \mathbf{Q}} &\quad \sum_{i,j,k,l} + (\mathbf{X}_{i,k} - \mathbf{Y}_{j,l})^2 \mathbf{P}_{i,j} \mathbf{Q}_{k,l} \\ + &+ \rho_s \mathbf{Div}(\mathbf{P}_{\# 1} \mathbf{Q}_{\# 1}^T | \mathbf{w}_{xs} \mathbf{w}_{ys}^T) + + \rho_f \mathbf{Div}(\mathbf{P}_{\# 2} \mathbf{Q}_{\# 2}^T | \mathbf{w}_{xf} \mathbf{w}_{yf}^T) \\ + &+ \alpha_s \sum_{i,j} \mathbf{P}_{i,j} \mathbf{M^{(s)}}_{i, j} + + \alpha_f \sum_{k, l} \mathbf{Q}_{k,l} \mathbf{M^{(f)}}_{k, l} \\ + &+ \varepsilon_s \mathbf{Div}(\mathbf{P} | \mathbf{w}_{xs} \mathbf{w}_{ys}^T) + + \varepsilon_f \mathbf{Div}(\mathbf{Q} | \mathbf{w}_{xf} \mathbf{w}_{yf}^T) + + Where: + + - :math:`\mathbf{X}`: Source input (arbitrary-size) matrix + - :math:`\mathbf{Y}`: Target input (arbitrary-size) matrix + - :math:`\mathbf{M^{(s)}}`: Additional sample matrix + - :math:`\mathbf{M^{(f)}}`: Additional feature matrix + - :math:`\mathbf{w}_{xs}`: Distribution of the samples in the source space + - :math:`\mathbf{w}_{xf}`: Distribution of the features in the source space + - :math:`\mathbf{w}_{ys}`: Distribution of the samples in the target space + - :math:`\mathbf{w}_{yf}`: Distribution of the features in the target space + - :math:`\mathbf{Div}`: Either Kullback-Leibler divergence or half-squared L2 norm. + + .. note:: This function allows `epsilon` to be zero. In that case, `unbalanced_method` must be either "mm" or "lbfgsb". + + Parameters + ---------- + X : (n_sample_x, n_feature_x) array-like, float + Source input matrix. + Y : (n_sample_y, n_feature_y) array-like, float + Target input matrix. + wx_samp : (n_sample_x, ) array-like, float, optional (default = None) + Histogram assigned on rows (samples) of matrix X. + Uniform distribution by default. + wx_feat : (n_feature_x, ) array-like, float, optional (default = None) + Histogram assigned on columns (features) of matrix X. + Uniform distribution by default. + wy_samp : (n_sample_y, ) array-like, float, optional (default = None) + Histogram assigned on rows (samples) of matrix Y. + Uniform distribution by default. + wy_feat : (n_feature_y, ) array-like, float, optional (default = None) + Histogram assigned on columns (features) of matrix Y. + Uniform distribution by default. + reg_marginals: float or indexable object of length 1 or 2 + Marginal relaxation terms for sample and feature couplings. + If `reg_marginals is a scalar` or an indexable object of length 1, + then the same value is applied to both marginal relaxations. + epsilon : scalar or indexable object of length 2, float or int, optional (default = 0) + Regularization parameters for entropic approximation of sample and feature couplings. + Allow the case where `epsilon` contains 0. In that case, the MM solver is used by default + instead of Sinkhorn solver. If `epsilon` is scalar, then the same value is applied to + both regularization of sample and feature couplings. + divergence : string, optional (default = "kl") + + - If `divergence` = "kl", then Div is the Kullback-Leibler divergence. + + - If `divergence` = "l2", then Div is the half squared Euclidean norm. + unbalanced_solver : string, optional (default = "sinkhorn") + Solver for the unbalanced OT subroutine. + + - If `divergence` = "kl", then `unbalanced_solver` can be: "sinkhorn", "sinkhorn_log", "mm", "lbfgsb" + + - If `divergence` = "l2", then `unbalanced_solver` can be "mm", "lbfgsb" + alpha : scalar or indexable object of length 2, float or int, optional (default = 0) + Coeffficient parameter of linear terms with respect to the sample and feature couplings. + If alpha is scalar, then the same alpha is applied to both linear terms. + M_samp : (n_sample_x, n_sample_y), float, optional (default = None) + Sample matrix associated to the Wasserstein linear term on sample coupling. + M_feat : (n_feature_x, n_feature_y), float, optional (default = None) + Feature matrix associated to the Wasserstein linear term on feature coupling. + rescale_plan : boolean, optional (default = True) + If True, then rescale the sample and feature transport plans within each BCD iteration, + so that they always have equal mass. + init_pi : tuple of two matrices of size (n_sample_x, n_sample_y) and + (n_feature_x, n_feature_y), optional (default = None). + Initialization of sample and feature couplings. + Uniform distributions by default. + init_duals : tuple of two tuples ((n_sample_x, ), (n_sample_y, )) and ((n_feature_x, ), (n_feature_y, )), optional (default = None). + Initialization of sample and feature dual vectors + if using Sinkhorn algorithm. Zero vectors by default. + max_iter : int, optional (default = 100) + Number of Block Coordinate Descent (BCD) iterations. + tol : float, optional (default = 1e-7) + Tolerance of BCD scheme. If the L1-norm between the current and previous + sample couplings is under this threshold, then stop BCD scheme. + max_iter_ot : int, optional (default = 100) + Number of iterations to solve each of the + two unbalanced optimal transport problems in each BCD iteration. + tol_ot : float, optional (default = 1e-7) + Tolerance of unbalanced solver for each of the + two unbalanced optimal transport problems in each BCD iteration. + log : bool, optional (default = False) + If True then the cost and four dual vectors, including + two from sample and two from feature couplings, are recorded. + verbose : bool, optional (default = False) + If True then print the COOT cost at every multiplier of `eval_bcd`-th iteration. + + Returns + ------- + pi_samp : (n_sample_x, n_sample_y) array-like, float + Sample coupling matrix. + pi_feat : (n_feature_x, n_feature_y) array-like, float + Feature coupling matrix. + log : dictionary, optional + Returned if `log` is True. The keys are: + + error : array-like, float + list of L1 norms between the current and previous sample coupling. + duals_sample : (n_sample_x, n_sample_y)-tuple, float + Pair of dual vectors when solving OT problem w.r.t the sample coupling. + duals_feature : (n_feature_x, n_feature_y)-tuple, float + Pair of dual vectors when solving OT problem w.r.t the feature coupling. + linear : float + Linear part of the cost. + ucoot : float + Total cost. + + References + ---------- + .. [71] Tran, H., Janati, H., Courty, N., Flamary, R., Redko, I., Demetci, P., & Singh, R. + Unbalanced Co-Optimal Transport. AAAI Conference on Artificial Intelligence, 2023. + """ + + return fused_unbalanced_across_spaces_divergence( + X=X, Y=Y, wx_samp=wx_samp, wx_feat=wx_feat, + wy_samp=wy_samp, wy_feat=wy_feat, reg_marginals=reg_marginals, + epsilon=epsilon, reg_type="independent", + divergence=divergence, unbalanced_solver=unbalanced_solver, + alpha=alpha, M_samp=M_samp, M_feat=M_feat, rescale_plan=rescale_plan, + init_pi=init_pi, init_duals=init_duals, max_iter=max_iter, tol=tol, + max_iter_ot=max_iter_ot, tol_ot=tol_ot, log=log, verbose=verbose, + **kwargs_solve) + + +def unbalanced_co_optimal_transport2( + X, Y, wx_samp=None, wx_feat=None, wy_samp=None, wy_feat=None, + reg_marginals=10, epsilon=0, divergence="kl", + unbalanced_solver="sinkhorn", alpha=0, M_samp=None, M_feat=None, + rescale_plan=True, init_pi=None, init_duals=None, + max_iter=100, tol=1e-7, max_iter_ot=500, tol_ot=1e-7, + log=False, verbose=False, **kwargs_solve): + + r"""Compute the unbalanced Co-Optimal Transport between two Euclidean point clouds + (represented as matrices whose rows are samples and columns are the features/dimensions). + + More precisely, this function returns the unbalanced Co-Optimal Transport cost between + :math:`(\mathbf{X}, \mathbf{w}_{xs}, \mathbf{w}_{xf})` and + :math:`(\mathbf{Y}, \mathbf{w}_{ys}, \mathbf{w}_{yf})`, + by solving the following problem using Block Coordinate Descent algorithm: + + .. math:: + \mathop{\min}_{\mathbf{P}, \mathbf{Q}} &\quad \sum_{i,j,k,l} + (\mathbf{X}_{i,k} - \mathbf{Y}_{j,l})^2 \mathbf{P}_{i,j} \mathbf{Q}_{k,l} \\ + &+ \rho_s \mathbf{Div}(\mathbf{P}_{\# 1} \mathbf{Q}_{\# 1}^T | \mathbf{w}_{xs} \mathbf{w}_{ys}^T) + + \rho_f \mathbf{Div}(\mathbf{P}_{\# 2} \mathbf{Q}_{\# 2}^T | \mathbf{w}_{xf} \mathbf{w}_{yf}^T) \\ + &+ \alpha_s \sum_{i,j} \mathbf{P}_{i,j} \mathbf{M^{(s)}}_{i, j} + + \alpha_f \sum_{k, l} \mathbf{Q}_{k,l} \mathbf{M^{(f)}}_{k, l} \\ + &+ \varepsilon_s \mathbf{Div}(\mathbf{P} | \mathbf{w}_{xs} \mathbf{w}_{ys}^T) + + \varepsilon_f \mathbf{Div}(\mathbf{Q} | \mathbf{w}_{xf} \mathbf{w}_{yf}^T) + + Where: + + - :math:`\mathbf{X}`: Source input (arbitrary-size) matrix + - :math:`\mathbf{Y}`: Target input (arbitrary-size) matrix + - :math:`\mathbf{M^{(s)}}`: Additional sample matrix + - :math:`\mathbf{M^{(f)}}`: Additional feature matrix + - :math:`\mathbf{w}_{xs}`: Distribution of the samples in the source space + - :math:`\mathbf{w}_{xf}`: Distribution of the features in the source space + - :math:`\mathbf{w}_{ys}`: Distribution of the samples in the target space + - :math:`\mathbf{w}_{yf}`: Distribution of the features in the target space + - :math:`\mathbf{Div}`: Either Kullback-Leibler divergence or half-squared L2 norm. + + .. note:: This function allows `epsilon` to be zero. In that case, `unbalanced_method` must be either "mm" or "lbfgsb". + Also the computation of gradients is only supported for KL divergence. The case of half squared-L2 norm uses those of KL divergence. + + Parameters + ---------- + X : (n_sample_x, n_feature_x) array-like, float + Source input matrix. + Y : (n_sample_y, n_feature_y) array-like, float + Target input matrix. + wx_samp : (n_sample_x, ) array-like, float, optional (default = None) + Histogram assigned on rows (samples) of matrix X. + Uniform distribution by default. + wx_feat : (n_feature_x, ) array-like, float, optional (default = None) + Histogram assigned on columns (features) of matrix X. + Uniform distribution by default. + wy_samp : (n_sample_y, ) array-like, float, optional (default = None) + Histogram assigned on rows (samples) of matrix Y. + Uniform distribution by default. + wy_feat : (n_feature_y, ) array-like, float, optional (default = None) + Histogram assigned on columns (features) of matrix Y. + Uniform distribution by default. + reg_marginals: float or indexable object of length 1 or 2 + Marginal relaxation terms for sample and feature couplings. + If `reg_marginals` is a scalar or an indexable object of length 1, + then the same value is applied to both marginal relaxations. + epsilon : scalar or indexable object of length 2, float or int, optional (default = 0) + Regularization parameters for entropic approximation of sample and feature couplings. + Allow the case where `epsilon` contains 0. In that case, the MM solver is used by default + instead of Sinkhorn solver. If `epsilon` is scalar, then the same value is applied to + both regularization of sample and feature couplings. + divergence : string, optional (default = "kl") + + - If `divergence` = "kl", then Div is the Kullback-Leibler divergence. + + - If `divergence` = "l2", then Div is the half squared Euclidean norm. + unbalanced_solver : string, optional (default = "sinkhorn") + Solver for the unbalanced OT subroutine. + + - If `divergence` = "kl", then `unbalanced_solver` can be: "sinkhorn", "sinkhorn_log", "mm", "lbfgsb" + + - If `divergence` = "l2", then `unbalanced_solver` can be "mm", "lbfgsb" + alpha : scalar or indexable object of length 2, float or int, optional (default = 0) + Coeffficient parameter of linear terms with respect to the sample and feature couplings. + If alpha is scalar, then the same alpha is applied to both linear terms. + M_samp : (n_sample_x, n_sample_y), float, optional (default = None) + Sample matrix associated to the Wasserstein linear term on sample coupling. + M_feat : (n_feature_x, n_feature_y), float, optional (default = None) + Feature matrix associated to the Wasserstein linear term on feature coupling. + rescale_plan : boolean, optional (default = True) + If True, then rescale the transport plans in each BCD iteration, + so that they always have equal mass. + init_pi : tuple of two matrices of size (n_sample_x, n_sample_y) and + (n_feature_x, n_feature_y), optional (default = None). + Initialization of sample and feature couplings. + Uniform distributions by default. + init_duals : tuple of two tuples ((n_sample_x, ), (n_sample_y, )) and ((n_feature_x, ), (n_feature_y, )), optional (default = None). + Initialization of sample and feature dual vectors + if using Sinkhorn algorithm. Zero vectors by default. + max_iter : int, optional (default = 100) + Number of Block Coordinate Descent (BCD) iterations. + tol : float, optional (default = 1e-7) + Tolerance of BCD scheme. If the L1-norm between the current and previous + sample couplings is under this threshold, then stop BCD scheme. + max_iter_ot : int, optional (default = 100) + Number of iterations to solve each of the + two unbalanced optimal transport problems in each BCD iteration. + tol_ot : float, optional (default = 1e-7) + Tolerance of unbalanced solver for each of the + two unbalanced optimal transport problems in each BCD iteration. + log : bool, optional (default = False) + If True then the cost and four dual vectors, including + two from sample and two from feature couplings, are recorded. + verbose : bool, optional (default = False) + If True then print the COOT cost at every multiplier of `eval_bcd`-th iteration. + + Returns + ------- + ucoot : float + UCOOT cost. + log : dictionary, optional + Returned if `log` is True. The keys are: + + error : array-like, float + list of L1 norms between the current and previous sample coupling. + duals_sample : (n_sample_x, n_sample_y)-tuple, float + Pair of dual vectors when solving OT problem w.r.t the sample coupling. + duals_feature : (n_feature_x, n_feature_y)-tuple, float + Pair of dual vectors when solving OT problem w.r.t the feature coupling. + linear : float + Linear part of UCOOT cost. + ucoot : float + UCOOT cost. + backend + The proper backend for all input arrays + + References + ---------- + .. [71] Tran, H., Janati, H., Courty, N., Flamary, R., Redko, I., Demetci, P., & Singh, R. + Unbalanced Co-Optimal Transport. AAAI Conference on Artificial Intelligence, 2023. + """ + + if divergence != "kl": + warnings.warn("The computation of gradients is only supported for KL divergence, not \ + for {} divergence".format(divergence)) + + pi_samp, pi_feat, log_ucoot = unbalanced_co_optimal_transport( + X=X, Y=Y, wx_samp=wx_samp, wx_feat=wx_feat, wy_samp=wy_samp, wy_feat=wy_feat, + reg_marginals=reg_marginals, epsilon=epsilon, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=alpha, M_samp=M_samp, M_feat=M_feat, + rescale_plan=rescale_plan, init_pi=init_pi, init_duals=init_duals, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=True, verbose=verbose, **kwargs_solve) + + nx = log_ucoot["backend"] + + nx_samp, nx_feat = X.shape + ny_samp, ny_feat = Y.shape + + # measures on rows and columns + if wx_samp is None: + wx_samp = nx.ones(nx_samp, type_as=X) / nx_samp + if wx_feat is None: + wx_feat = nx.ones(nx_feat, type_as=X) / nx_feat + if wy_samp is None: + wy_samp = nx.ones(ny_samp, type_as=Y) / ny_samp + if wy_feat is None: + wy_feat = nx.ones(ny_feat, type_as=Y) / ny_feat + + # extract parameters + rho_x, rho_y = get_parameter_pair(reg_marginals) + eps_samp, eps_feat = get_parameter_pair(epsilon) + + # calculate marginals + pi1_samp, pi2_samp = nx.sum(pi_samp, 1), nx.sum(pi_samp, 0) + pi1_feat, pi2_feat = nx.sum(pi_feat, 1), nx.sum(pi_feat, 0) + m_samp, m_feat = nx.sum(pi1_samp), nx.sum(pi1_feat) + m_wx_feat, m_wx_samp = nx.sum(wx_feat), nx.sum(wx_samp) + m_wy_feat, m_wy_samp = nx.sum(wy_feat), nx.sum(wy_samp) + + # calculate subgradients + gradX = 2 * X * (pi1_samp[:, None] * pi1_feat[None, :]) - \ + 2 * nx.dot(nx.dot(pi_samp, Y), pi_feat.T) # shape (nx_samp, nx_feat) + gradY = 2 * Y * (pi2_samp[:, None] * pi2_feat[None, :]) - \ + 2 * nx.dot(nx.dot(pi_samp.T, X), pi_feat) # shape (ny_samp, ny_feat) + + grad_wx_samp = rho_x * (m_wx_feat - m_feat * pi1_samp / wx_samp) + \ + eps_samp * (m_wy_samp - pi1_samp / wx_samp) + grad_wx_feat = rho_x * (m_wx_samp - m_samp * pi1_feat / wx_feat) + \ + eps_feat * (m_wy_feat - pi1_feat / wx_feat) + grad_wy_samp = rho_y * (m_wy_feat - m_feat * pi2_samp / wy_samp) + \ + eps_samp * (m_wx_samp - pi2_samp / wy_samp) + grad_wy_feat = rho_y * (m_wy_samp - m_samp * pi2_feat / wy_feat) + \ + eps_feat * (m_wx_feat - pi2_feat / wy_feat) + + # set gradients + ucoot = log_ucoot["ucoot_cost"] + ucoot = nx.set_gradients(ucoot, + (X, Y, wx_samp, wx_feat, wy_samp, wy_feat), + (gradX, gradY, grad_wx_samp, grad_wx_feat, grad_wy_samp, grad_wy_feat) + ) + + if log: + return ucoot, log_ucoot + + else: + return ucoot + + +def fused_unbalanced_gromov_wasserstein( + Cx, Cy, wx=None, wy=None, reg_marginals=10, epsilon=0, + divergence="kl", unbalanced_solver="mm", alpha=0, + M=None, init_duals=None, init_pi=None, max_iter=100, + tol=1e-7, max_iter_ot=500, tol_ot=1e-7, + log=False, verbose=False, **kwargs_solve): + + r"""Compute the lower bound of the fused unbalanced Gromov-Wasserstein (FUGW) between two similarity matrices. + In practice, this lower bound is used interchangeably with the true FUGW. + + More precisely, this function returns the transport plan between + :math:`(\mathbf{C^X}, \mathbf{w_X})` and :math:`(\mathbf{C^Y}, \mathbf{w_Y})`, + by solving the following problem using Block Coordinate Descent algorithm: + + .. math:: + \mathop{\arg \min}_{\substack{\mathbf{P}, \mathbf{Q}: \\ mass(P) = mass(Q)}} + &\quad \sum_{i,j,k,l} (\mathbf{C^X}_{i,k} - \mathbf{C^Y}_{j,l})^2 \mathbf{P}_{i,j} \mathbf{Q}_{k,l} + + \frac{\alpha}{2} \sum_{i,j} (\mathbf{P}_{i,j} + \mathbf{Q}_{i,j}) \mathbf{M}_{i, j} \\ + &+ \rho_1 \mathbf{Div}(\mathbf{P}_{\# 1} \mathbf{Q}_{\# 1}^T | \mathbf{w_X} \mathbf{w_X}^T) + + \rho_2 \mathbf{Div}(\mathbf{P}_{\# 2} \mathbf{Q}_{\# 2}^T | \mathbf{w_Y} \mathbf{w_Y}^T) \\ + &+ \varepsilon \mathbf{Div}(\mathbf{P} \otimes \mathbf{Q} | (\mathbf{w_X} \mathbf{w_Y}^T) \otimes (\mathbf{w_X} \mathbf{w_Y}^T)) + + Where: + + - :math:`\mathbf{C^X}`: Source similarity matrix + - :math:`\mathbf{C^Y}`: Target similarity matrix + - :math:`\mathbf{M}`: Sample matrix corresponding to the Wasserstein term + - :math:`\mathbf{w_X}`: Distribution of the samples in the source space + - :math:`\mathbf{w_Y}`: Distribution of the samples in the target space + - :math:`\mathbf{Div}`: Either Kullback-Leibler divergence or half-squared L2 norm. + + .. note:: This function allows epsilon to be zero. In that case, `unbalanced_method` must be either "mm" or "lbfgsb". + + Parameters + ---------- + Cx : (n_sample_x, n_feature_x) array-like, float + Source similarity matrix. + Cy : (n_sample_y, n_feature_y) array-like, float + Target similarity matrix. + wx : (n_sample_x, ) array-like, float, optional (default = None) + Histogram assigned on rows (samples) of matrix Cx. + Uniform distribution by default. + wy : (n_sample_y, ) array-like, float, optional (default = None) + Histogram assigned on rows (samples) of matrix Cy. + Uniform distribution by default. + reg_marginals: float or indexable object of length 1 or 2 + Marginal relaxation terms for sample and feature couplings. + If `reg_marginals` is a scalar or an indexable object of length 1, + then the same value is applied to both marginal relaxations. + epsilon : scalar, float or int, optional (default = 0) + Regularization parameters for entropic approximation of sample and feature couplings. + Allow the case where `epsilon` contains 0. In that case, the MM solver is used by default + instead of Sinkhorn solver. If `epsilon` is scalar, then the same value is applied to + both regularization of sample and feature couplings. + divergence : string, optional (default = "kl") + + - If `divergence` = "kl", then Div is the Kullback-Leibler divergence. + + - If `divergence` = "l2", then Div is the half squared Euclidean norm. + unbalanced_solver : string, optional (default = "sinkhorn") + Solver for the unbalanced OT subroutine. + + - If `divergence` = "kl", then `unbalanced_solver` can be: "sinkhorn", "sinkhorn_log", "mm", "lbfgsb" + + - If `divergence` = "l2", then `unbalanced_solver` can be "mm", "lbfgsb" + alpha : scalar, float or int, optional (default = 0) + Coeffficient parameter of linear terms with respect to the sample and feature couplings. + If alpha is scalar, then the same alpha is applied to both linear terms. + M : (n_sample_x, n_sample_y), float, optional (default = None) + Sample matrix associated to the Wasserstein linear term on sample coupling. + init_pi :(n_sample_x, n_sample_y) array-like, optional (default = None) + Initialization of sample coupling. By default = :math:`w_X w_Y^T`. + init_duals : tuple of vectors ((n_sample_x, ), (n_sample_y, )), optional (default = None). + Initialization of sample and feature dual vectors + if using Sinkhorn algorithm. Zero vectors by default. + max_iter : int, optional (default = 100) + Number of Block Coordinate Descent (BCD) iterations. + tol : float, optional (default = 1e-7) + Tolerance of BCD scheme. If the L1-norm between the current and previous + sample couplings is under this threshold, then stop BCD scheme. + max_iter_ot : int, optional (default = 100) + Number of iterations to solve each of the + two unbalanced optimal transport problems in each BCD iteration. + tol_ot : float, optional (default = 1e-7) + Tolerance of unbalanced solver for each of the + two unbalanced optimal transport problems in each BCD iteration. + log : bool, optional (default = False) + If True then the cost and four dual vectors, including + two from sample and two from feature couplings, are recorded. + verbose : bool, optional (default = False) + If True then print the COOT cost at every multiplier of `eval_bcd`-th iteration. + + Returns + ------- + pi_samp : (n_sample_x, n_sample_y) array-like, float + Sample coupling matrix. + In practice, we use this matrix as solution of FUGW. + pi_samp2 : (n_sample_x, n_sample_y) array-like, float + Second sample coupling matrix. + In practice, we usually ignore this output. + log : dictionary, optional + Returned if `log` is True. The keys are: + + error : array-like, float + list of L1 norms between the current and previous sample couplings. + duals : (n_sample_x, n_sample_y)-tuple, float + Pair of dual vectors when solving OT problem w.r.t the sample coupling. + linear : float + Linear part of FUGW cost. + fugw_cost : float + Total FUGW cost. + backend + The proper backend for all input arrays + + References + ---------- + .. [70] Thual, A., Tran, H., Zemskova, T., Courty, N., Flamary, R., Dehaene, S., & Thirion, B. + Aligning individual brains with Fused Unbalanced Gromov-Wasserstein. + Advances in Neural Information Systems, 35 (2022). + + .. [72] Thibault Séjourné, François-Xavier Vialard, & Gabriel Peyré. + The Unbalanced Gromov Wasserstein Distance: Conic Formulation and Relaxation. + Neural Information Processing Systems, 34 (2021). + """ + + alpha = (alpha / 2, alpha / 2) + + pi_samp, pi_feat, dict_log = fused_unbalanced_across_spaces_divergence( + X=Cx, Y=Cy, wx_samp=wx, wx_feat=wx, wy_samp=wy, wy_feat=wy, + reg_marginals=reg_marginals, epsilon=epsilon, reg_type="joint", + divergence=divergence, unbalanced_solver=unbalanced_solver, + alpha=alpha, M_samp=M, M_feat=M, rescale_plan=True, + init_pi=(init_pi, init_pi), + init_duals=(init_duals, init_duals), max_iter=max_iter, tol=tol, + max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=True, verbose=verbose, **kwargs_solve + ) + + if log: + log_fugw = {"error": dict_log["error"], + "duals": dict_log["duals_sample"], + "linear_cost": dict_log["linear_cost"], + "fugw_cost": dict_log["ucoot_cost"], + "backend": dict_log["backend"]} + + return pi_samp, pi_feat, log_fugw + + else: + return pi_samp, pi_feat + + +def fused_unbalanced_gromov_wasserstein2( + Cx, Cy, wx=None, wy=None, reg_marginals=10, epsilon=0, + divergence="kl", unbalanced_solver="mm", alpha=0, + M=None, init_duals=None, init_pi=None, max_iter=100, + tol=1e-7, max_iter_ot=500, tol_ot=1e-7, + log=False, verbose=False, **kwargs_solve): + + r"""Compute the lower bound of the fused unbalanced Gromov-Wasserstein (FUGW) between two similarity matrices. + In practice, this lower bound is used interchangeably with the true FUGW. + + More precisely, this function returns the lower bound of the fused unbalanced Gromov-Wasserstein cost between + :math:`(\mathbf{C^X}, \mathbf{w_X})` and :math:`(\mathbf{C^Y}, \mathbf{w_Y})`, + by solving the following problem using Block Coordinate Descent algorithm: + + .. math:: + \mathop{\min}_{\substack{\mathbf{P}, \mathbf{Q}: \\ mass(P) = mass(Q)}} + &\quad \sum_{i,j,k,l} (\mathbf{C^X}_{i,k} - \mathbf{C^Y}_{j,l})^2 \mathbf{P}_{i,j} \mathbf{Q}_{k,l} + + \frac{\alpha}{2} \sum_{i,j} (\mathbf{P}_{i,j} + \mathbf{Q}_{i,j}) \mathbf{M}_{i, j} \\ + &+ \rho_1 \mathbf{Div}(\mathbf{P}_{\# 1} \mathbf{Q}_{\# 1}^T | \mathbf{w_X} \mathbf{w_X}^T) + + \rho_2 \mathbf{Div}(\mathbf{P}_{\# 2} \mathbf{Q}_{\# 2}^T | \mathbf{w_Y} \mathbf{w_Y}^T) \\ + &+ \varepsilon \mathbf{Div}(\mathbf{P} \otimes \mathbf{Q} | (\mathbf{w_X} \mathbf{w_Y}^T) \otimes (\mathbf{w_X} \mathbf{w_Y}^T)) + + Where: + + - :math:`\mathbf{C^X}`: Source similarity matrix + - :math:`\mathbf{C^Y}`: Target similarity matrix + - :math:`\mathbf{M}`: Sample matrix corresponding to the Wasserstein term + - :math:`\mathbf{w_X}`: Distribution of the samples in the source space + - :math:`\mathbf{w_Y}`: Distribution of the samples in the target space + - :math:`\mathbf{Div}`: Either Kullback-Leibler divergence or half-squared L2 norm. + + .. note:: This function allows `epsilon` to be zero. In that case, unbalanced_method must be either "mm" or "lbfgsb". + Also the computation of gradients is only supported for KL divergence, but not for half squared-L2 norm. In case of half squared-L2 norm, the calculation of KL divergence will be used. + + Parameters + ---------- + Cx : (n_sample_x, n_feature_x) array-like, float + Source similarity matrix. + Cy : (n_sample_y, n_feature_y) array-like, float + Target similarity matrix. + wx : (n_sample_x, ) array-like, float, optional (default = None) + Histogram assigned on rows (samples) of matrix Cx. + Uniform distribution by default. + wy : (n_sample_y, ) array-like, float, optional (default = None) + Histogram assigned on rows (samples) of matrix Cy. + Uniform distribution by default. + reg_marginals: float or indexable object of length 1 or 2 + Marginal relaxation terms for sample and feature couplings. + If `reg_marginals` is a scalar or an indexable object of length 1, + then the same value is applied to both marginal relaxations. + epsilon : scalar, float or int, optional (default = 0) + Regularization parameters for entropic approximation of sample and feature couplings. + Allow the case where `epsilon` contains 0. In that case, the MM solver is used by default + instead of Sinkhorn solver. If `epsilon` is scalar, then the same value is applied to + both regularization of sample and feature couplings. + divergence : string, optional (default = "kl") + + - If `divergence` = "kl", then Div is the Kullback-Leibler divergence. + + - If `divergence` = "l2", then Div is the half squared Euclidean norm. + unbalanced_solver : string, optional (default = "sinkhorn") + Solver for the unbalanced OT subroutine. + + - If `divergence` = "kl", then `unbalanced_solver` can be: "sinkhorn", "sinkhorn_log", "mm", "lbfgsb" + + - If `divergence` = "l2", then `unbalanced_solver` can be "mm", "lbfgsb" + alpha : scalar, float or int, optional (default = 0) + Coeffficient parameter of linear terms with respect to the sample and feature couplings. + If alpha is scalar, then the same alpha is applied to both linear terms. + M : (n_sample_x, n_sample_y), float, optional (default = None) + Sample matrix associated to the Wasserstein linear term on sample coupling. + init_pi :(n_sample_x, n_sample_y) array-like, optional (default = None) + Initialization of sample coupling. By default = :math:`w_X w_Y^T`. + init_duals : tuple of vectors ((n_sample_x, ), (n_sample_y, )), optional (default = None). + Initialization of sample and feature dual vectors + if using Sinkhorn algorithm. Zero vectors by default. + max_iter : int, optional (default = 100) + Number of Block Coordinate Descent (BCD) iterations. + tol : float, optional (default = 1e-7) + Tolerance of BCD scheme. If the L1-norm between the current and previous + sample couplings is under this threshold, then stop BCD scheme. + max_iter_ot : int, optional (default = 100) + Number of iterations to solve each of the + two unbalanced optimal transport problems in each BCD iteration. + tol_ot : float, optional (default = 1e-7) + Tolerance of unbalanced solver for each of the + two unbalanced optimal transport problems in each BCD iteration. + log : bool, optional (default = False) + If True then the cost and four dual vectors, including + two from sample and two from feature couplings, are recorded. + verbose : bool, optional (default = False) + If True then print the COOT cost at every multiplier of `eval_bcd`-th iteration. + + Returns + ------- + fugw : float + Total FUGW cost + log : dictionary, optional + Returned if `log` is True. The keys are: + + error : array-like, float + list of L1 norms between the current and previous sample couplings. + duals : (n_sample_x, n_sample_y)-tuple, float + Pair of dual vectors when solving OT problem w.r.t the sample coupling. + linear : float + Linear part of FUGW cost. + fugw_cost : float + Total FUGW cost. + backend + The proper backend for all input arrays + + References + ---------- + .. [70] Thual, A., Tran, H., Zemskova, T., Courty, N., Flamary, R., Dehaene, S., & Thirion, B. + Aligning individual brains with Fused Unbalanced Gromov-Wasserstein. + Advances in Neural Information Systems, 35 (2022). + + .. [72] Thibault Séjourné, François-Xavier Vialard, & Gabriel Peyré. + The Unbalanced Gromov Wasserstein Distance: Conic Formulation and Relaxation. + Neural Information Processing Systems, 34 (2021). + """ + + if divergence != "kl": + warnings.warn("The computation of gradients is only supported for KL divergence, \ + but not for {} divergence. The gradient of the KL case will be used.".format(divergence)) + + pi_samp, pi_feat, log_fugw = fused_unbalanced_gromov_wasserstein( + Cx=Cx, Cy=Cy, wx=wx, wy=wy, reg_marginals=reg_marginals, + epsilon=epsilon, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=alpha, M=M, + init_duals=init_duals, init_pi=init_pi, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, + tol_ot=tol_ot, log=True, verbose=verbose, **kwargs_solve + ) + + nx = log_fugw["backend"] + sx, sy = Cx.shape[0], Cy.shape[0] + + # measures on rows and columns + if wx is None: + wx = nx.ones(sx, type_as=Cx) / sx + if wy is None: + wy = nx.ones(sy, type_as=Cy) / sy + + # calculate marginals + pi1_samp, pi2_samp = nx.sum(pi_samp, 1), nx.sum(pi_samp, 0) + pi1_feat, pi2_feat = nx.sum(pi_feat, 1), nx.sum(pi_feat, 0) + m_samp, m_feat = nx.sum(pi1_samp), nx.sum(pi1_feat) + m_wx, m_wy = nx.sum(wx), nx.sum(wy) + + # calculate subgradients + gradX = 2 * Cx * (pi1_samp[:, None] * pi1_feat[None, :]) - \ + 2 * nx.dot(nx.dot(pi_samp, Cy), pi_feat.T) # shape (nx_samp, nx_feat) + gradY = 2 * Cy * (pi2_samp[:, None] * pi2_feat[None, :]) - \ + 2 * nx.dot(nx.dot(pi_samp.T, Cx), pi_feat) # shape (ny_samp, ny_feat) + + gradM = alpha / 2 * (pi_samp + pi_feat) + + rho_x, rho_y = get_parameter_pair(reg_marginals) + grad_wx = 2 * m_wx * (rho_x + epsilon * m_wy**2) - \ + (rho_x + epsilon) * (m_feat * pi1_samp + m_samp * pi1_feat) / wx + grad_wy = 2 * m_wy * (rho_y + epsilon * m_wx**2) - \ + (rho_y + epsilon) * (m_feat * pi2_samp + m_samp * pi2_feat) / wy + + # set gradients + fugw = log_fugw["fugw_cost"] + fugw = nx.set_gradients(fugw, (Cx, Cy, M, wx, wy), + (gradX, gradY, gradM, grad_wx, grad_wy)) + + if log: + return fugw, log_fugw + + else: + return fugw diff --git a/ot/gromov/_utils.py b/ot/gromov/_utils.py index 5c465cba8..fb07bb1ef 100644 --- a/ot/gromov/_utils.py +++ b/ot/gromov/_utils.py @@ -8,6 +8,7 @@ # Rémi Flamary # Titouan Vayer # Cédric Vincent-Cuaz +# Quang Huy Tran # # License: MIT License @@ -797,3 +798,358 @@ def update_barycenter_feature( inv_p = 1. / p_sum return sum(list_features) * inv_p[:, None] + + +############################################################################ +# Methods related to fused unbalanced GW and unbalanced Co-Optimal Transport. +############################################################################ + +def div_to_product(pi, a, b, pi1=None, pi2=None, divergence="kl", mass=True, nx=None): + r"""Fast computation of the Bregman divergence between an arbitrary measure and a product measure. + Only support for Kullback-Leibler and half-squared L2 divergences. + + - For half-squared L2 divergence: + + .. math:: + \frac{1}{2} || \pi - a \otimes b ||^2 + = \frac{1}{2} \Big[ \sum_{i, j} \pi_{ij}^2 + (\sum_i a_i^2) ( \sum_j b_j^2) - 2 \sum_{i, j} a_i \pi_{ij} b_j \Big] + + - For Kullback-Leibler divergence: + + .. math:: + KL(\pi | a \otimes b) + = \langle \pi, \log \pi \rangle - \langle \pi_1, \log a \rangle + - \langle \pi_2, \log b \rangle - m(\pi) + m(a) m(b) + + where : + + - :math:`\pi` is the (`dim_a`, `dim_b`) transport plan + - :math:`\pi_1` and :math:`\pi_2` are the marginal distributions + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions + - :math:`m` denotes the mass of the measure + + Parameters + ---------- + pi : array-like (dim_a, dim_b) + Transport plan + a : array-like (dim_a,) + Unnormalized histogram of dimension `dim_a` + b : array-like (dim_b,) + Unnormalized histogram of dimension `dim_b` + pi1 : array-like (dim_a,), optional (default = None) + Marginal distribution with respect to the first dimension of the transport plan + Only used in case of Kullback-Leibler divergence. + pi2 : array-like (dim_a,), optional (default = None) + Marginal distribution with respect to the second dimension of the transport plan + Only used in case of Kullback-Leibler divergence. + divergence : string, default = "kl" + Bregman divergence, either "kl" (Kullback-Leibler divergence) or "l2" (half-squared L2 divergence) + mass : bool, optional. Default is False. + Only used in case of Kullback-Leibler divergence. + If False, calculate the relative entropy. + If True, calculate the Kullback-Leibler divergence. + nx : backend, optional + If let to its default value None, a backend test will be conducted. + + Returns + ------- + Bregman divergence between an arbitrary measure and a product measure. + """ + + arr = [pi, a, b, pi1, pi2] + + if nx is None: + nx = get_backend(*arr, pi1, pi2) + + if divergence == "kl": + + if pi1 is None: + pi1 = nx.sum(pi, 1) + if pi2 is None: + pi2 = nx.sum(pi, 0) + + res = nx.sum(pi * nx.log(pi + 1.0 * (pi == 0))) \ + - nx.sum(pi1 * nx.log(a)) - nx.sum(pi2 * nx.log(b)) + if mass: + res = res - nx.sum(pi1) + nx.sum(a) * nx.sum(b) + + elif divergence == "l2": + res = (nx.sum(pi**2) + nx.sum(a**2) * nx.sum(b**2) + - 2 * nx.dot(a, nx.dot(pi, b))) / 2 + + return res + + +def div_between_product(mu, nu, alpha, beta, divergence, nx=None): + r"""Fast computation of the Bregman divergence between two product measures. + Only support for Kullback-Leibler and half-squared L2 divergences. + + For half-squared L2 divergence: + + .. math:: + \frac{1}{2} || \mu \otimes \nu, \alpha \otimes \beta ||^2 + = \frac{1}{2} \Big[ ||\alpha||^2 ||\beta||^2 + ||\mu||^2 ||\nu||^2 - 2 \langle \alpha, \mu \rangle \langle \beta, \nu \rangle \Big] + + For Kullback-Leibler divergence: + + .. math:: + KL(\mu \otimes \nu, \alpha \otimes \beta) + = m(\mu) * KL(\nu, \beta) + m(\nu) * KL(\mu, \alpha) + (m(\mu) - m(\alpha)) * (m(\nu) - m(\beta)) + + where: + + - :math:`\mu` and :math:`\alpha` are two measures having the same shape. + - :math:`\nu` and :math:`\beta` are two measures having the same shape. + - :math:`m` denotes the mass of the measure + + Parameters + ---------- + mu : array-like + vector or matrix + nu : array-like + vector or matrix + alpha : array-like + vector or matrix with the same shape as `\mu` + beta : array-like + vector or matrix with the same shape as `\nu` + divergence : string, default = "kl" + Bregman divergence, either "kl" (Kullback-Leibler divergence) or "l2" (half-squared L2 divergence) + nx : backend, optional + If let to its default value None, a backend test will be conducted. + + Returns + ---------- + Bregman divergence between two product measures. + """ + + if nx is None: + nx = get_backend(mu, nu, alpha, beta) + + if divergence == "kl": + m_mu, m_nu = nx.sum(mu), nx.sum(nu) + m_alpha, m_beta = nx.sum(alpha), nx.sum(beta) + const = (m_mu - m_alpha) * (m_nu - m_beta) + res = m_nu * nx.kl_div(mu, alpha, mass=True) + m_mu * nx.kl_div(nu, beta, mass=True) + const + + elif divergence == "l2": + res = (nx.sum(alpha**2) * nx.sum(beta**2) - 2 * nx.sum(alpha * mu) * nx.sum(beta * nu) + + nx.sum(mu**2) * nx.sum(nu**2)) / 2 + + return res + + +# Support functions for BCD schemes +def uot_cost_matrix(data, pi, tuple_p, hyperparams, divergence, reg_type, nx=None): + r"""The Block Coordinate Descent algorithm for FUGW and UCOOT + requires solving an UOT problem in each iteration. + In particular, we need to specify the following inputs: + + - Cost matrix + + - Hyperparameters (marginal-relaxations and regularization) + + - Reference measures in the marginal-relaxation and regularization terms + + This method returns the cost matrix. + The method :any:`ot.gromov.uot_parameters_and_measures` returns the rest of the inputs. + + Parameters + ---------- + data : tuple of arrays + vector or matrix + pi : array-like + vector or matrix + tuple_p : tuple of arrays + Tuple of reference measures in the marginal-relaxation terms + w.r.t the (either sample or feature) coupling + hyperparams : tuple of floats + Hyperparameters of marginal-relaxation and regularization terms + in the fused unbalanced across-domain divergence + divergence : string, default = "kl" + Bregman divergence, either "kl" (Kullback-Leibler divergence) or "l2" (half-squared L2 divergence) + reg_type : string, + Type of regularization term in the fused unbalanced across-domain divergence + + - `reg_type = "joint"` corresponds to FUGW + + - `reg_type = "independent"` corresponds to UCOOT + nx : backend, optional + If let to its default value None, a backend test will be conducted. + + Returns + ---------- + Cost matrix of the UOT subroutine for UCOOT and FUGW + """ + + X_sqr, Y_sqr, X, Y, M = data + rho_x, rho_y, eps = hyperparams + a, b = tuple_p + + if nx is None: + nx = get_backend(X, Y, a, b) + + pi1, pi2 = nx.sum(pi, 1), nx.sum(pi, 0) + A, B = nx.dot(X_sqr, pi1), nx.dot(Y_sqr, pi2) + uot_cost = A[:, None] + B[None, :] - 2 * nx.dot(nx.dot(X, pi), Y.T) + if M is not None: + uot_cost = uot_cost + M + + if divergence == "kl": + if rho_x != float("inf") and rho_x != 0: + uot_cost = uot_cost + rho_x * nx.kl_div(pi1, a, mass=False) + if rho_y != float("inf") and rho_y != 0: + uot_cost = uot_cost + rho_y * nx.kl_div(pi2, b, mass=False) + if reg_type == "joint" and eps > 0: + uot_cost = uot_cost + eps * div_to_product(pi, a, b, pi1, pi2, + divergence, mass=False, nx=nx) + + return uot_cost + + +def uot_parameters_and_measures(pi, tuple_weights, hyperparams, reg_type, divergence, nx): + r"""The Block Coordinate Descent algorithm for FUGW and UCOOT + requires solving an UOT problem in each iteration. + In particular, we need to specify the following inputs: + + - Cost matrix + + - Hyperparameters (marginal-relaxations and regularization) + + - Reference measures in the marginal-relaxation and regularization terms + + The method :any:`ot.gromov.uot_cost_matrix` returns the cost matrix. + This method returns the rest of the inputs. + + Parameters + ---------- + pi : array-like + vector or matrix + tuple_weights : tuple of arrays + Tuple of reference measures in the marginal-relaxation and regularization terms + w.r.t the (either sample or feature) coupling + hyperparams : tuple of floats + Hyperparameters of marginal-relaxation and regularization terms + in the fused unbalanced across-domain divergence + reg_type : string, + Type of regularization term in the fused unbalanced across-domain divergence + + - `reg_type = "joint"` corresponds to FUGW + + - `reg_type = "independent"` corresponds to UCOOT + divergence : string, default = "kl" + Bregman divergence, either "kl" (Kullback-Leibler divergence) or "l2" (half-squared L2 divergence) + nx : backend, optional + If let to its default value None, a backend test will be conducted. + + Returns + ---------- + Tuple of hyperparameters and distributions (weights) + """ + + rho_x, rho_y, eps = hyperparams + wx, wy, wxy = tuple_weights + + if divergence == "l2": + pi1, pi2 = nx.sum(pi, 1), nx.sum(pi, 0) + l2_pi1, l2_pi2, l2_pi = nx.sum(pi1**2), nx.sum(pi2**2), nx.sum(pi**2) + + weighted_wx = wx * nx.sum(pi1 * wx) / l2_pi1 + weighted_wy = wy * nx.sum(pi2 * wy) / l2_pi2 + weighted_wxy = wxy * nx.sum(pi * wxy) / l2_pi if reg_type == "joint" else wxy + weighted_w = (weighted_wx, weighted_wy, weighted_wxy) + + new_rho = (rho_x * l2_pi1, rho_y * l2_pi2) + new_eps = eps * l2_pi if reg_type == "joint" else eps + + elif divergence == "kl": + mass = nx.sum(pi) + new_rho = (rho_x * mass, rho_y * mass) + new_eps = mass * eps if reg_type == "joint" else eps + weighted_w = tuple_weights + + return weighted_w, new_rho, new_eps + + +def fused_unbalanced_across_spaces_cost(M_linear, data, tuple_pxy_samp, tuple_pxy_feat, + pi_samp, pi_feat, hyperparams, divergence, reg_type, nx): + r"""Return the fused unbalanced across-space divergence between two spaces + + Parameters + ---------- + M_linear : tuple of arrays + Pair of cost matrices corresponding to the Wasserstein terms w.r.t sample and feature couplings + data : tuple of arrays + Tuple of input spaces represented as matrices + tuple_pxy_samp : tuple of arrays + Tuple of reference measures in the marginal-relaxation and regularization terms + w.r.t the sample coupling + tuple_pxy_feat : tuple of arrays + Tuple of reference measures in the marginal-relaxation and regularization terms + w.r.t the feature coupling + pi_samp : array-like + Sample coupling + pi_feat : array-like + Feature coupling + hyperparams : tuple of floats + Hyperparameters of marginal-relaxation and regularization terms + in the fused unbalanced across-domain divergence + divergence : string, default = "kl" + Bregman divergence, either "kl" (Kullback-Leibler divergence) or "l2" (half-squared L2 divergence) + reg_type : string, + Type of regularization term in the fused unbalanced across-domain divergence + + - `reg_type = "joint"` corresponds to FUGW + + - `reg_type = "independent"` corresponds to UCOOT + nx : backend, optional + If let to its default value None, a backend test will be conducted. + + Returns + ---------- + Fused unbalanced across-space divergence between two spaces + """ + + rho_x, rho_y, eps_samp, eps_feat = hyperparams + M_samp, M_feat = M_linear + px_samp, py_samp, pxy_samp = tuple_pxy_samp + px_feat, py_feat, pxy_feat = tuple_pxy_feat + X_sqr, Y_sqr, X, Y = data + + pi1_samp, pi2_samp = nx.sum(pi_samp, 1), nx.sum(pi_samp, 0) + pi1_feat, pi2_feat = nx.sum(pi_feat, 1), nx.sum(pi_feat, 0) + + A_sqr = nx.dot(nx.dot(X_sqr, pi1_feat), pi1_samp) + B_sqr = nx.dot(nx.dot(Y_sqr, pi2_feat), pi2_samp) + AB = nx.dot(nx.dot(X, pi_feat), Y.T) * pi_samp + linear_cost = A_sqr + B_sqr - 2 * nx.sum(AB) + + ucoot_cost = linear_cost + if M_samp is not None: + ucoot_cost = ucoot_cost + nx.sum(pi_samp * M_samp) + if M_feat is not None: + ucoot_cost = ucoot_cost + nx.sum(pi_feat * M_feat) + + if rho_x != float("inf") and rho_x != 0: + ucoot_cost = ucoot_cost + \ + rho_x * div_between_product(pi1_samp, pi1_feat, + px_samp, px_feat, divergence, nx) + if rho_y != float("inf") and rho_y != 0: + ucoot_cost = ucoot_cost + \ + rho_y * div_between_product(pi2_samp, pi2_feat, + py_samp, py_feat, divergence, nx) + + if reg_type == "joint" and eps_samp != 0: + div_cost = div_between_product(pi_samp, pi_feat, + pxy_samp, pxy_feat, divergence, nx) + ucoot_cost = ucoot_cost + eps_samp * div_cost + elif reg_type == "independent": + if eps_samp != 0: + div_samp = div_to_product(pi_samp, pi1_samp, pi2_samp, + px_samp, py_samp, divergence, mass=True, nx=nx) + ucoot_cost = ucoot_cost + eps_samp * div_samp + if eps_feat != 0: + div_feat = div_to_product(pi_feat, pi1_feat, pi2_feat, + px_feat, py_feat, divergence, mass=True, nx=nx) + ucoot_cost = ucoot_cost + eps_feat * div_feat + + return linear_cost, ucoot_cost diff --git a/ot/solvers.py b/ot/solvers.py index 4dccdc58c..ac2fcbb88 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -12,7 +12,6 @@ from .backend import get_backend from .unbalanced import mm_unbalanced, sinkhorn_knopp_unbalanced, lbfgsb_unbalanced from .bregman import sinkhorn_log, empirical_sinkhorn2, empirical_sinkhorn2_geomloss -from .partial import partial_wasserstein_lagrange from .smooth import smooth_ot_dual from .gromov import (gromov_wasserstein2, fused_gromov_wasserstein2, entropic_gromov_wasserstein2, entropic_fused_gromov_wasserstein2, @@ -28,7 +27,7 @@ lst_method_lazy = ['1d', 'gaussian', 'lowrank', 'factored', 'geomloss', 'geomloss_auto', 'geomloss_tensorized', 'geomloss_online', 'geomloss_multiscale'] -def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, +def solve(M, a=None, b=None, reg=None, c=None, reg_type="KL", unbalanced=None, unbalanced_type='KL', method=None, n_threads=1, max_iter=None, plan_init=None, potentials_init=None, tol=None, verbose=False, grad='autodiff'): r"""Solve the discrete optimal transport problem and return :any:`OTResult` object @@ -37,12 +36,12 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, .. math:: \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + - \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + - \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) + \lambda_1 U(\mathbf{T}\mathbf{1},\mathbf{a}) + + \lambda_2 U(\mathbf{T}^T\mathbf{1},\mathbf{b}) The regularization is selected with `reg` (:math:`\lambda_r`) and `reg_type`. By default ``reg=None`` and there is no regularization. The unbalanced marginal - penalization can be selected with `unbalanced` (:math:`\lambda_u`) and + penalization can be selected with `unbalanced` (:math:`(\lambda_1, \lambda_2)`) and `unbalanced_type`. By default ``unbalanced=None`` and the function solves the exact optimal transport problem (respecting the marginals). @@ -57,13 +56,24 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, reg : float, optional Regularization weight :math:`\lambda_r`, by default None (no reg., exact OT) + c : array-like (dim_a, dim_b), optional (default=None) + Reference measure for the regularization. + If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`. + If :math:`\texttt{reg_type}='entropy'`, then :math:`\mathbf{c} = 1_{dim_a} 1_{dim_b}^T`. reg_type : str, optional Type of regularization :math:`R` either "KL", "L2", "entropy", by default "KL". a tuple of functions can be provided for general solver (see :any:`cg`). This is only used when ``reg!=None``. - unbalanced : float, optional - Unbalanced penalization weight :math:`\lambda_u`, by default None - (balanced OT) + unbalanced : float or indexable object of length 1 or 2 + Marginal relaxation term. + If it is a scalar or an indexable object of length 1, + then the same relaxation is applied to both marginal relaxations. + The balanced OT can be recovered using :math:`unbalanced=float("inf")`. + For semi-relaxed case, use either + :math:`unbalanced=(float("inf"), scalar)` or + :math:`unbalanced=(scalar, float("inf"))`. + If unbalanced is an array, + it must have the same backend as input arrays `(a, b, M)`. unbalanced_type : str, optional Type of unbalanced penalization function :math:`U` either "KL", "L2", "TV", by default "KL". @@ -173,7 +183,9 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, - **Unbalanced OT [41]** (when ``unbalanced!=None``): .. math:: - \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) + \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + + \lambda_1 U(\mathbf{T}\mathbf{1},\mathbf{a}) + + \lambda_2 U(\mathbf{T}^T\mathbf{1},\mathbf{b}) can be solved with the following code: @@ -190,7 +202,9 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, - **Regularized unbalanced regularized OT [34]** (when ``unbalanced!=None`` and ``reg!=None``): .. math:: - \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) + \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + + \lambda_1 U(\mathbf{T}\mathbf{1},\mathbf{a}) + + \lambda_2 U(\mathbf{T}^T\mathbf{1},\mathbf{b}) can be solved with the following code: @@ -237,18 +251,18 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, """ # detect backend - arr = [M] - if a is not None: - arr.append(a) - if b is not None: - arr.append(b) - nx = get_backend(*arr) + nx = get_backend(M, a, b, c) # create uniform weights if not given if a is None: a = nx.ones(M.shape[0], type_as=M) / M.shape[0] if b is None: b = nx.ones(M.shape[1], type_as=M) / M.shape[1] + if c is None: + c = a[:, None] * b[None, :] + + if reg is None: + reg = 0 # default values for solutions potentials = None @@ -257,7 +271,7 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, plan = None status = None - if reg is None or reg == 0: # exact OT + if reg == 0: # exact OT if unbalanced is None: # Exact balanced OT @@ -280,32 +294,31 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, if tol is None: tol = 1e-12 - plan, log = mm_unbalanced(a, b, M, reg_m=unbalanced, - div=unbalanced_type.lower(), numItermax=max_iter, + plan, log = mm_unbalanced(a, b, M, reg_m=unbalanced, c=c, reg=reg, + div=unbalanced_type, numItermax=max_iter, stopThr=tol, log=True, verbose=verbose, G0=plan_init) value_linear = log['cost'] - - if unbalanced_type.lower() == 'kl': - value = value_linear + unbalanced * (nx.kl_div(nx.sum(plan, 1), a) + nx.kl_div(nx.sum(plan, 0), b)) - else: - err_a = nx.sum(plan, 1) - a - err_b = nx.sum(plan, 0) - b - value = value_linear + unbalanced * nx.sum(err_a**2) + unbalanced * nx.sum(err_b**2) + value = log['total_cost'] elif unbalanced_type.lower() == 'tv': if max_iter is None: - max_iter = 1000000 + max_iter = 1000 + if tol is None: + tol = 1e-12 + if isinstance(reg_type, str): + reg_type = reg_type.lower() - plan, log = partial_wasserstein_lagrange(a, b, M, reg_m=unbalanced**2, log=True, numItermax=max_iter) + plan, log = lbfgsb_unbalanced( + a, b, M, reg=reg, reg_m=unbalanced, c=c, reg_div=reg_type, + regm_div=unbalanced_type, numItermax=max_iter, + stopThr=tol, verbose=verbose, log=True, G0=plan_init + ) - value_linear = nx.sum(M * plan) - err_a = nx.sum(plan, 1) - a - err_b = nx.sum(plan, 0) - b - value = value_linear + nx.sqrt(unbalanced**2 / 2.0 * (nx.sum(nx.abs(err_a)) + - nx.sum(nx.abs(err_b)))) + value_linear = log['cost'] + value = log['total_cost'] else: raise (NotImplementedError('Unknown unbalanced_type="{}"'.format(unbalanced_type))) @@ -316,12 +329,15 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, if isinstance(reg_type, tuple): # general solver + f, df = reg_type + if max_iter is None: max_iter = 1000 if tol is None: tol = 1e-9 - plan, log = cg(a, b, M, reg=reg, f=reg_type[0], df=reg_type[1], numItermax=max_iter, stopThr=tol, log=True, verbose=verbose, G0=plan_init) + plan, log = cg(a, b, M, reg=reg, f=f, df=df, numItermax=max_iter, + stopThr=tol, log=True, verbose=verbose, G0=plan_init) value_linear = nx.sum(M * plan) value = log['loss'][-1] @@ -382,11 +398,16 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, if tol is None: tol = 1e-9 - plan, log = sinkhorn_knopp_unbalanced(a, b, M, reg=reg, reg_m=unbalanced, numItermax=max_iter, stopThr=tol, verbose=verbose, log=True) + plan, log = sinkhorn_knopp_unbalanced( + a, b, M, reg=reg, reg_m=unbalanced, + method=method, reg_type=reg_type, c=c, + warmstart=potentials_init, + numItermax=max_iter, stopThr=tol, + verbose=verbose, log=True + ) - value_linear = nx.sum(M * plan) - - value = value_linear + reg * nx.kl_div(plan, a[:, None] * b[None, :]) + unbalanced * (nx.kl_div(nx.sum(plan, 1), a) + nx.kl_div(nx.sum(plan, 0), b)) + value_linear = log['cost'] + value = log['total_cost'] potentials = (log['logu'], log['logv']) @@ -399,11 +420,14 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, if isinstance(reg_type, str): reg_type = reg_type.lower() - plan, log = lbfgsb_unbalanced(a, b, M, reg=reg, reg_m=unbalanced, reg_div=reg_type, regm_div=unbalanced_type.lower(), numItermax=max_iter, stopThr=tol, verbose=verbose, log=True, G0=plan_init) - - value_linear = nx.sum(M * plan) + plan, log = lbfgsb_unbalanced( + a, b, M, reg=reg, reg_m=unbalanced, c=c, reg_div=reg_type, + regm_div=unbalanced_type, numItermax=max_iter, + stopThr=tol, verbose=verbose, log=True, G0=plan_init + ) - value = log['cost'] + value_linear = log['cost'] + value = log['total_cost'] else: raise (NotImplementedError('Not implemented reg_type="{}" and unbalanced_type="{}"'.format(reg_type, unbalanced_type))) @@ -909,7 +933,7 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, return res -def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL", +def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, c=None, reg_type="KL", unbalanced=None, unbalanced_type='KL', lazy=False, batch_size=None, method=None, n_threads=1, max_iter=None, plan_init=None, rank=100, scaling=0.95, potentials_init=None, X_init=None, tol=None, verbose=False, @@ -946,11 +970,22 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t reg : float, optional Regularization weight :math:`\lambda_r`, by default None (no reg., exact OT) + c : array-like (dim_a, dim_b), optional (default=None) + Reference measure for the regularization. + If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`. + If :math:`\texttt{reg_type}='entropy'`, then :math:`\mathbf{c} = 1_{dim_a} 1_{dim_b}^T`. reg_type : str, optional Type of regularization :math:`R` either "KL", "L2", "entropy", by default "KL" - unbalanced : float, optional - Unbalanced penalization weight :math:`\lambda_u`, by default None - (balanced OT) + unbalanced : float or indexable object of length 1 or 2 + Marginal relaxation term. + If it is a scalar or an indexable object of length 1, + then the same relaxation is applied to both marginal relaxations. + The balanced OT can be recovered using :math:`unbalanced=float("inf")`. + For semi-relaxed case, use either + :math:`unbalanced=(float("inf"), scalar)` or + :math:`unbalanced=(scalar, float("inf"))`. + If unbalanced is an array, + it must have the same backend as input arrays `(a, b, M)`. unbalanced_type : str, optional Type of unbalanced penalization function :math:`U` either "KL", "L2", "TV", by default "KL" lazy : bool, optional @@ -1249,7 +1284,7 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t # compute cost matrix M and use solve function M = dist(X_a, X_b, metric) - res = solve(M, a, b, reg, reg_type, unbalanced, unbalanced_type, method, n_threads, max_iter, plan_init, potentials_init, tol, verbose, grad) + res = solve(M, a, b, reg, c, reg_type, unbalanced, unbalanced_type, method, n_threads, max_iter, plan_init, potentials_init, tol, verbose, grad) return res diff --git a/ot/unbalanced/_lbfgs.py b/ot/unbalanced/_lbfgs.py index 062a472f9..a38c00a5e 100644 --- a/ot/unbalanced/_lbfgs.py +++ b/ot/unbalanced/_lbfgs.py @@ -286,7 +286,7 @@ def df(x): b = nx.ones(dim_b, type_as=M) / dim_b # convert to numpy - a, b, M = nx.to_numpy(a, b, M) + a, b, M, reg_m1, reg_m2, reg = nx.to_numpy(a, b, M, reg_m1, reg_m2, reg) G0 = a[:, None] * b[None, :] if G0 is None else nx.to_numpy(G0) c = a[:, None] * b[None, :] if c is None else nx.to_numpy(c) diff --git a/ot/unbalanced/_sinkhorn.py b/ot/unbalanced/_sinkhorn.py index 37e85253b..ed7fd4d61 100644 --- a/ot/unbalanced/_sinkhorn.py +++ b/ot/unbalanced/_sinkhorn.py @@ -509,7 +509,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="kl", c=None, else: u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) - if reg_type == "entropy": + if reg_type.lower() == "entropy": warnings.warn('If reg_type = entropy, then the matrix c is overwritten by the one matrix.') c = nx.ones((dim_a, dim_b), type_as=M) diff --git a/ot/utils.py b/ot/utils.py index 2ba541ea2..12910c479 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -1043,10 +1043,7 @@ def potentials(self): This pair of arrays has the same shape, numerical type and properties as the input weights "a" and "b". """ - if self._potentials is not None: - return self._potentials - else: - raise NotImplementedError() + return self._potentials @property def potential_a(self): @@ -1054,7 +1051,7 @@ def potential_a(self): if self._potentials is not None: return self._potentials[0] else: - raise NotImplementedError() + return None @property def potential_b(self): @@ -1062,7 +1059,7 @@ def potential_b(self): if self._potentials is not None: return self._potentials[1] else: - raise NotImplementedError() + return None # Transport plan ------------------------------------------- @property @@ -1071,10 +1068,7 @@ def plan(self): # N.B.: We may catch out-of-memory errors and suggest # the use of lazy_plan or sparse_plan when appropriate. - if self._plan is not None: - return self._plan - else: - raise NotImplementedError() + return self._plan @property def sparse_plan(self): @@ -1084,15 +1078,12 @@ def sparse_plan(self): elif self._plan is not None: return self._backend.tocsr(self._plan) else: - raise NotImplementedError() + return None @property def lazy_plan(self): """Transport plan, encoded as a symbolic KeOps LazyTensor.""" - if self._lazy_plan is not None: - return self._lazy_plan - else: - raise NotImplementedError() + return self._lazy_plan # Loss values -------------------------------- @@ -1100,26 +1091,17 @@ def lazy_plan(self): def value(self): """Full transport cost, including possible regularization terms and quadratic term for Gromov Wasserstein solutions.""" - if self._value is not None: - return self._value - else: - raise NotImplementedError() + return self._value @property def value_linear(self): """The "minimal" transport cost, i.e. the product between the transport plan and the cost.""" - if self._value_linear is not None: - return self._value_linear - else: - raise NotImplementedError() + return self._value_linear @property def value_quad(self): """The quadratic part of the transport cost for Gromov-Wasserstein solutions.""" - if self._value_quad is not None: - return self._value_quad - else: - raise NotImplementedError() + return self._value_quad # Marginal constraints ------------------------- @property @@ -1129,7 +1111,7 @@ def marginals(self): if self._plan is not None: return self.marginal_a, self.marginal_b else: - raise NotImplementedError() + return None @property def marginal_a(self): @@ -1142,7 +1124,7 @@ def marginal_a(self): nx = self._backend return reduce_lazytensor(lp, nx.sum, axis=1, nx=nx, batch_size=bs) else: - raise NotImplementedError() + return None @property def marginal_b(self): @@ -1155,23 +1137,17 @@ def marginal_b(self): nx = self._backend return reduce_lazytensor(lp, nx.sum, axis=0, nx=nx, batch_size=bs) else: - raise NotImplementedError() + return None @property def status(self): """Optimization status of the solver.""" - if self._status is not None: - return self._status - else: - raise NotImplementedError() + return self._status @property def log(self): """Dictionary containing potential information about the solver.""" - if self._log is not None: - return self._log - else: - raise NotImplementedError() + return self._log # Barycentric mappings ------------------------- # Return the displacement vectors as an array diff --git a/test/gromov/test_fugw.py b/test/gromov/test_fugw.py new file mode 100644 index 000000000..894da1c3b --- /dev/null +++ b/test/gromov/test_fugw.py @@ -0,0 +1,685 @@ +"""Tests for module Fused Unbalanced Gromov-Wasserstein""" + +# Author: Quang Huy Tran +# +# License: MIT License + + +import itertools +import numpy as np +import ot +import pytest +from ot.gromov._unbalanced import fused_unbalanced_gromov_wasserstein, fused_unbalanced_gromov_wasserstein2, fused_unbalanced_across_spaces_divergence + + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +@pytest.skip_backend("tf", reason="test very slow with tensorflow backend") +@pytest.mark.parametrize("unbalanced_solver, divergence", itertools.product(["mm", "lbfgsb"], ["kl", "l2"])) +def test_sanity(nx, unbalanced_solver, divergence): + n_samples = 5 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss( + n_samples, mu_s, cov_s, random_state=4) + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + # linear part + M_samp = np.ones((n_samples, n_samples)) + np.fill_diagonal(np.fliplr(M_samp), 0) + M_samp_nx = nx.from_numpy(M_samp) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + C1 /= C1.max() + C2 /= C2.max() + + C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) + + reg_m = (100, 50) + eps = 0 + alpha = 0.5 + max_iter_ot = 100 + max_iter = 100 + tol = 1e-7 + tol_ot = 1e-7 + + # test couplings + anti_id_sample = np.flipud(np.eye(n_samples, n_samples)) / n_samples + + pi_sample, pi_feature = fused_unbalanced_gromov_wasserstein( + C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, + divergence=divergence, unbalanced_solver=unbalanced_solver, + alpha=alpha, M=M_samp, init_duals=None, init_pi=G0, max_iter=max_iter, + tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + pi_sample_nx, pi_feature_nx = fused_unbalanced_gromov_wasserstein( + C1b, C2b, wx=pb, wy=qb, reg_marginals=reg_m, epsilon=eps, + divergence=divergence, unbalanced_solver=unbalanced_solver, + alpha=alpha, M=M_samp_nx, init_duals=None, init_pi=G0b, max_iter=max_iter, + tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + pi_sample_nx = nx.to_numpy(pi_sample_nx) + pi_feature_nx = nx.to_numpy(pi_feature_nx) + + np.testing.assert_allclose(pi_sample, anti_id_sample, atol=1e-03) + np.testing.assert_allclose(pi_sample_nx, pi_sample, atol=1e-06) + np.testing.assert_allclose(pi_feature, pi_feature_nx, atol=1e-06) + + # test divergence + + fugw = fused_unbalanced_gromov_wasserstein2( + C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, + divergence=divergence, unbalanced_solver=unbalanced_solver, + alpha=alpha, M=M_samp, init_duals=None, init_pi=G0, max_iter=max_iter, + tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + fugw_nx = fused_unbalanced_gromov_wasserstein2( + C1b, C2b, wx=pb, wy=qb, reg_marginals=reg_m, epsilon=eps, + divergence=divergence, unbalanced_solver=unbalanced_solver, + alpha=alpha, M=M_samp_nx, init_duals=None, init_pi=G0b, max_iter=max_iter, + tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + fugw_nx = nx.to_numpy(fugw_nx) + np.testing.assert_allclose(fugw, fugw_nx, atol=1e-08) + np.testing.assert_allclose(fugw, 0, atol=1e-02) + + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +@pytest.skip_backend("tf", reason="test very slow with tensorflow backend") +@pytest.mark.parametrize("unbalanced_solver, divergence, eps", itertools.product(["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1e-2])) +def test_init_plans(nx, unbalanced_solver, divergence, eps): + n_samples = 5 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss( + n_samples, mu_s, cov_s, random_state=4) + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + # linear part + M_samp = np.ones((n_samples, n_samples)) + np.fill_diagonal(np.fliplr(M_samp), 0) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + C1 /= C1.max() + C2 /= C2.max() + + reg_m = (100, 50) + alpha = 0.5 + max_iter_ot = 5 + max_iter = 5 + tol = 1e-5 + tol_ot = 1e-5 + + pi_sample, pi_feature = fused_unbalanced_gromov_wasserstein( + C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, + divergence=divergence, unbalanced_solver=unbalanced_solver, + alpha=alpha, M=M_samp, init_duals=None, init_pi=G0, max_iter=max_iter, + tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + pi_sample_nx, pi_feature_nx = fused_unbalanced_gromov_wasserstein( + C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, + divergence=divergence, unbalanced_solver=unbalanced_solver, + alpha=alpha, M=M_samp, init_duals=None, init_pi=None, max_iter=max_iter, + tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + pi_sample_nx = nx.to_numpy(pi_sample_nx) + pi_feature_nx = nx.to_numpy(pi_feature_nx) + + np.testing.assert_allclose(pi_sample, pi_sample_nx, atol=1e-06) + np.testing.assert_allclose(pi_feature, pi_feature_nx, atol=1e-06) + + # test divergence + + fugw = fused_unbalanced_gromov_wasserstein2( + C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, + divergence=divergence, unbalanced_solver=unbalanced_solver, + alpha=alpha, M=M_samp, init_duals=None, init_pi=G0, max_iter=max_iter, + tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + fugw_nx = fused_unbalanced_gromov_wasserstein2( + C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, + divergence=divergence, unbalanced_solver=unbalanced_solver, + alpha=alpha, M=M_samp, init_duals=None, init_pi=None, max_iter=max_iter, + tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + fugw_nx = nx.to_numpy(fugw_nx) + np.testing.assert_allclose(fugw, fugw_nx, atol=1e-08) + + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +@pytest.skip_backend("tf", reason="test very slow with tensorflow backend") +@pytest.mark.parametrize("unbalanced_solver, divergence, eps", itertools.product(["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1e-2])) +def test_init_duals(nx, unbalanced_solver, divergence, eps): + n_samples = 5 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss( + n_samples, mu_s, cov_s, random_state=4) + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + dual1, dual2 = nx.from_numpy(np.zeros_like(p), np.zeros_like(q)) + init_duals = (dual1, dual2) + + # linear part + M_samp = np.ones((n_samples, n_samples)) + np.fill_diagonal(np.fliplr(M_samp), 0) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + C1 /= C1.max() + C2 /= C2.max() + + C1, C2, p, q, M_samp = nx.from_numpy(C1, C2, p, q, M_samp) + + reg_m = (100, 50) + alpha = 0.5 + max_iter_ot = 5 + max_iter = 5 + tol = 1e-5 + tol_ot = 1e-5 + + pi_sample, pi_feature = fused_unbalanced_gromov_wasserstein( + C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, + divergence=divergence, unbalanced_solver=unbalanced_solver, + alpha=alpha, M=M_samp, init_duals=None, init_pi=None, max_iter=max_iter, + tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + pi_sample_nx, pi_feature_nx = fused_unbalanced_gromov_wasserstein( + C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, + divergence=divergence, unbalanced_solver=unbalanced_solver, + alpha=alpha, M=M_samp, init_duals=init_duals, init_pi=None, max_iter=max_iter, + tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + pi_sample_nx = nx.to_numpy(pi_sample_nx) + pi_feature_nx = nx.to_numpy(pi_feature_nx) + + np.testing.assert_allclose(pi_sample, pi_sample_nx, atol=1e-06) + np.testing.assert_allclose(pi_feature, pi_feature_nx, atol=1e-06) + + # test divergence + fugw = fused_unbalanced_gromov_wasserstein2( + C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, + divergence=divergence, unbalanced_solver=unbalanced_solver, + alpha=alpha, M=M_samp, init_duals=None, init_pi=None, max_iter=max_iter, + tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + fugw_nx = fused_unbalanced_gromov_wasserstein2( + C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, + divergence=divergence, unbalanced_solver=unbalanced_solver, + alpha=alpha, M=M_samp, init_duals=init_duals, init_pi=None, max_iter=max_iter, + tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + fugw_nx = nx.to_numpy(fugw_nx) + np.testing.assert_allclose(fugw, fugw_nx, atol=1e-08) + + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +@pytest.skip_backend("tf", reason="test very slow with tensorflow backend") +@pytest.mark.parametrize("unbalanced_solver, divergence, eps", itertools.product(["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1e-2])) +def test_reg_marginals(nx, unbalanced_solver, divergence, eps): + n_samples = 5 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss( + n_samples, mu_s, cov_s, random_state=4) + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + # linear part + M_samp = np.ones((n_samples, n_samples)) + np.fill_diagonal(np.fliplr(M_samp), 0) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + C1 /= C1.max() + C2 /= C2.max() + + alpha = 0.5 + max_iter_ot = 5 + max_iter = 5 + tol = 1e-5 + tol_ot = 1e-5 + + reg_m = 100 + full_list_reg_m = [reg_m, reg_m] + full_tuple_reg_m = (reg_m, reg_m) + tuple_reg_m, list_reg_m = (reg_m), [reg_m] + + list_options = [full_tuple_reg_m, tuple_reg_m, full_list_reg_m, list_reg_m] + + pi_sample, pi_feature = fused_unbalanced_gromov_wasserstein( + C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, + divergence=divergence, unbalanced_solver=unbalanced_solver, + alpha=alpha, M=M_samp, init_duals=None, init_pi=G0, max_iter=max_iter, + tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + fugw = fused_unbalanced_gromov_wasserstein2( + C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, + divergence=divergence, unbalanced_solver=unbalanced_solver, + alpha=alpha, M=M_samp, init_duals=None, init_pi=G0, max_iter=max_iter, + tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + for opt in list_options: + pi_sample_nx, pi_feature_nx = fused_unbalanced_gromov_wasserstein( + C1, C2, wx=p, wy=q, reg_marginals=opt, epsilon=eps, + divergence=divergence, unbalanced_solver=unbalanced_solver, + alpha=alpha, M=M_samp, init_duals=None, init_pi=None, max_iter=max_iter, + tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + pi_sample_nx = nx.to_numpy(pi_sample_nx) + pi_feature_nx = nx.to_numpy(pi_feature_nx) + + np.testing.assert_allclose(pi_sample, pi_sample_nx, atol=1e-06) + np.testing.assert_allclose(pi_feature, pi_feature_nx, atol=1e-06) + + # test divergence + fugw_nx = fused_unbalanced_gromov_wasserstein2( + C1, C2, wx=p, wy=q, reg_marginals=opt, epsilon=eps, + divergence=divergence, unbalanced_solver=unbalanced_solver, + alpha=alpha, M=M_samp, init_duals=None, init_pi=None, max_iter=max_iter, + tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + fugw_nx = nx.to_numpy(fugw_nx) + np.testing.assert_allclose(fugw, fugw_nx, atol=1e-08) + + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +@pytest.skip_backend("tf", reason="test very slow with tensorflow backend") +@pytest.mark.parametrize("unbalanced_solver, divergence, eps", itertools.product(["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1e-2])) +def test_log(nx, unbalanced_solver, divergence, eps): + n_samples = 5 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss( + n_samples, mu_s, cov_s, random_state=4) + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + # linear part + M_samp = np.ones((n_samples, n_samples)) + np.fill_diagonal(np.fliplr(M_samp), 0) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + C1 /= C1.max() + C2 /= C2.max() + + reg_m = (100, 50) + alpha = 0.5 + max_iter_ot = 5 + max_iter = 5 + tol = 1e-5 + tol_ot = 1e-5 + + pi_sample, pi_feature = fused_unbalanced_gromov_wasserstein( + C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, + divergence=divergence, unbalanced_solver=unbalanced_solver, + alpha=alpha, M=M_samp, init_duals=None, init_pi=None, max_iter=max_iter, + tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + pi_sample_nx, pi_feature_nx, log = fused_unbalanced_gromov_wasserstein( + C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, + divergence=divergence, unbalanced_solver=unbalanced_solver, + alpha=alpha, M=M_samp, init_duals=None, init_pi=None, max_iter=max_iter, + tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=True, verbose=False + ) + pi_sample_nx = nx.to_numpy(pi_sample_nx) + pi_feature_nx = nx.to_numpy(pi_feature_nx) + + np.testing.assert_allclose(pi_sample, pi_sample_nx, atol=1e-06) + np.testing.assert_allclose(pi_feature, pi_feature_nx, atol=1e-06) + + # test divergence + + fugw = fused_unbalanced_gromov_wasserstein2( + C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, + divergence=divergence, unbalanced_solver=unbalanced_solver, + alpha=alpha, M=M_samp, init_duals=None, init_pi=None, max_iter=max_iter, + tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + fugw_nx, log = fused_unbalanced_gromov_wasserstein2( + C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, + divergence=divergence, unbalanced_solver=unbalanced_solver, + alpha=alpha, M=M_samp, init_duals=None, init_pi=None, max_iter=max_iter, + tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=True, verbose=False + ) + + fugw_nx = nx.to_numpy(fugw_nx) + np.testing.assert_allclose(fugw, fugw_nx, atol=1e-08) + + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +@pytest.skip_backend("tf", reason="test very slow with tensorflow backend") +@pytest.mark.parametrize("unbalanced_solver, divergence, eps", itertools.product(["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1e-2])) +def test_marginals(nx, unbalanced_solver, divergence, eps): + n_samples = 5 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss( + n_samples, mu_s, cov_s, random_state=4) + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + # linear part + M_samp = np.ones((n_samples, n_samples)) + np.fill_diagonal(np.fliplr(M_samp), 0) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + C1 /= C1.max() + C2 /= C2.max() + + reg_m = (100, 50) + alpha = 0.5 + max_iter_ot = 5 + max_iter = 5 + tol = 1e-5 + tol_ot = 1e-5 + + pi_sample, pi_feature = fused_unbalanced_gromov_wasserstein( + C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, + divergence=divergence, unbalanced_solver=unbalanced_solver, + alpha=alpha, M=M_samp, init_duals=None, init_pi=None, max_iter=max_iter, + tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + pi_sample_nx, pi_feature_nx = fused_unbalanced_gromov_wasserstein( + C1, C2, wx=None, wy=None, reg_marginals=reg_m, epsilon=eps, + divergence=divergence, unbalanced_solver=unbalanced_solver, + alpha=alpha, M=M_samp, init_duals=None, init_pi=None, max_iter=max_iter, + tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + pi_sample_nx = nx.to_numpy(pi_sample_nx) + pi_feature_nx = nx.to_numpy(pi_feature_nx) + + np.testing.assert_allclose(pi_sample, pi_sample_nx, atol=1e-06) + np.testing.assert_allclose(pi_feature, pi_feature_nx, atol=1e-06) + + # test divergence + + fugw = fused_unbalanced_gromov_wasserstein2( + C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, + divergence=divergence, unbalanced_solver=unbalanced_solver, + alpha=alpha, M=M_samp, init_duals=None, init_pi=None, max_iter=max_iter, + tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + fugw_nx = fused_unbalanced_gromov_wasserstein2( + C1, C2, wx=None, wy=None, reg_marginals=reg_m, epsilon=eps, + divergence=divergence, unbalanced_solver=unbalanced_solver, + alpha=alpha, M=M_samp, init_duals=None, init_pi=None, max_iter=max_iter, + tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + fugw_nx = nx.to_numpy(fugw_nx) + np.testing.assert_allclose(fugw, fugw_nx, atol=1e-08) + + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +@pytest.skip_backend("tf", reason="test very slow with tensorflow backend") +def test_raise_value_error(nx): + n_samples = 5 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) + xt = xs[::-1].copy() + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + C1 /= C1.max() + C2 /= C2.max() + + C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) + + eps = 1e-2 + reg_m = (10, 100) + max_iter_ot = 5 + max_iter = 5 + tol = 1e-6 + tol_ot = 1e-6 + + # raise error of divergence + def fugw_div(divergence): + return fused_unbalanced_gromov_wasserstein( + C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, + divergence=divergence, unbalanced_solver="mm", + alpha=0, M=None, init_duals=None, init_pi=G0, max_iter=max_iter, + tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + def fugw_div_nx(divergence): + return fused_unbalanced_gromov_wasserstein( + C1b, C2b, wx=pb, wy=qb, reg_marginals=reg_m, epsilon=eps, + divergence=divergence, unbalanced_solver="mm", + alpha=0, M=None, init_duals=None, init_pi=G0b, max_iter=max_iter, + tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + np.testing.assert_raises(NotImplementedError, fugw_div, "div_not_existed") + np.testing.assert_raises(NotImplementedError, fugw_div_nx, "div_not_existed") + + # raise error of solver + def fugw_solver(unbalanced_solver): + return fused_unbalanced_gromov_wasserstein( + C1, C2, wx=p, wy=q, reg_marginals=reg_m, epsilon=eps, + divergence="kl", unbalanced_solver=unbalanced_solver, + alpha=0, M=None, init_duals=None, init_pi=G0, max_iter=max_iter, + tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + def fugw_solver_nx(unbalanced_solver): + return fused_unbalanced_gromov_wasserstein( + C1b, C2b, wx=pb, wy=qb, reg_marginals=reg_m, epsilon=eps, + divergence="kl", unbalanced_solver=unbalanced_solver, + alpha=0, M=None, init_duals=None, init_pi=G0b, max_iter=max_iter, + tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + np.testing.assert_raises(NotImplementedError, fugw_solver, "solver_not_existed") + np.testing.assert_raises(NotImplementedError, fugw_solver_nx, "solver_not_existed") + + +@pytest.mark.parametrize("unbalanced_solver, divergence, eps", itertools.product(["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1e-2])) +def test_fused_unbalanced_across_spaces_divergence_wrong_reg_type(nx, unbalanced_solver, divergence, eps): + + n = 100 + rng = np.random.RandomState(42) + x = rng.randn(n, 2) + rng = np.random.RandomState(75) + y = rng.randn(n, 2) + x, y = nx.from_numpy(x), nx.from_numpy(y) + + reg_m = 100 + + def reg_type(reg_type): + return fused_unbalanced_across_spaces_divergence( + X=x, Y=y, reg_marginals=reg_m, + epsilon=eps, reg_type=reg_type, + divergence=divergence, unbalanced_solver=unbalanced_solver + ) + + np.testing.assert_raises(NotImplementedError, reg_type, "reg_type_not_existed") + + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +@pytest.skip_backend("tf", reason="test very slow with tensorflow backend") +@pytest.mark.parametrize("unbalanced_solver, divergence, eps, reg_type", itertools.product(["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1e-2], ["independent", "joint"])) +def test_fused_unbalanced_across_spaces_divergence_log(nx, unbalanced_solver, divergence, eps, reg_type): + n_samples = 5 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss( + n_samples, mu_s, cov_s, random_state=4) + xt = xs[::-1].copy() + + px_s, px_f = ot.unif(n_samples), ot.unif(2) + py_s, py_f = ot.unif(n_samples), ot.unif(2) + + xs_nx, xt_nx = nx.from_numpy(xs, xt) + px_s_nx, px_f_nx, py_s_nx, py_f_nx = nx.from_numpy(px_s, px_f, py_s, py_f) + + # linear part + M_samp = np.ones((n_samples, n_samples)) + np.fill_diagonal(np.fliplr(M_samp), 0) + M_feat = np.ones((2, 2)) + np.fill_diagonal(M_feat, 0) + M_samp_nx, M_feat_nx = nx.from_numpy(M_samp, M_feat) + + reg_m = (10, 5) + alpha = (0.1, 0.2) + max_iter_ot = 5 + max_iter = 5 + tol = 1e-7 + tol_ot = 1e-7 + + # test couplings + pi_sample, pi_feature = fused_unbalanced_across_spaces_divergence( + X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, + reg_marginals=reg_m, epsilon=eps, reg_type=reg_type, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=alpha, + M_samp=M_samp_nx, M_feat=M_feat_nx, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + pi_sample_nx, pi_feature_nx, log = fused_unbalanced_across_spaces_divergence( + X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, + reg_marginals=reg_m, epsilon=eps, reg_type=reg_type, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=alpha, + M_samp=M_samp_nx, M_feat=M_feat_nx, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=True, verbose=False + ) + + np.testing.assert_allclose(pi_sample_nx, pi_sample, atol=1e-06) + np.testing.assert_allclose(pi_feature, pi_feature_nx, atol=1e-06) + + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +@pytest.skip_backend("tf", reason="test very slow with tensorflow backend") +@pytest.mark.parametrize("reg_type", ["independent", "joint"]) +def test_fused_unbalanced_across_spaces_divergence_warning(nx, reg_type): + n_samples = 5 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss( + n_samples, mu_s, cov_s, random_state=4) + xt = xs[::-1].copy() + + px_s, px_f = ot.unif(n_samples), ot.unif(2) + py_s, py_f = ot.unif(n_samples), ot.unif(2) + + xs_nx, xt_nx = nx.from_numpy(xs, xt) + px_s_nx, px_f_nx, py_s_nx, py_f_nx = nx.from_numpy(px_s, px_f, py_s, py_f) + + unbalanced_solver = "mm" + divergence = "kl" + + # linear part + M_samp = np.ones((n_samples, n_samples)) + np.fill_diagonal(np.fliplr(M_samp), 0) + M_feat = np.ones((2, 2)) + np.fill_diagonal(M_feat, 0) + M_samp_nx, M_feat_nx = nx.from_numpy(M_samp, M_feat) + + reg_m = (1e6, 1e6) + eps = 1e-2 + alpha = (0.1, 0.2) + max_iter_ot = 5 + max_iter = 5 + tol = 1e-7 + tol_ot = 1e-7 + + def raise_warning(): + return fused_unbalanced_across_spaces_divergence( + X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, + reg_marginals=reg_m, epsilon=eps, reg_type=reg_type, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=alpha, + M_samp=M_samp_nx, M_feat=M_feat_nx, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + np.testing.assert_raises(ValueError, raise_warning) diff --git a/test/gromov/test_utils.py b/test/gromov/test_utils.py index 70894fcfc..b0338c84a 100644 --- a/test/gromov/test_utils.py +++ b/test/gromov/test_utils.py @@ -4,6 +4,7 @@ # # License: MIT License +import itertools import numpy as np import pytest @@ -111,3 +112,39 @@ def test_semirelaxed_init_plan(nx): T = ot.gromov.semirelaxed_init_plan(C1b, C1b, p1b, method='fluid') Tb = ot.gromov.semirelaxed_init_plan(C1b, C1b, p1b, method='fluid') np.testing.assert_allclose(T, Tb) + + +@pytest.mark.parametrize("divergence", ["kl", "l2"]) +def test_div_between_product(nx, divergence): + ns = 5 + nt = 10 + + ps, pt = ot.unif(ns), ot.unif(nt) + ps, pt = nx.from_numpy(ps), nx.from_numpy(pt) + ps1, pt1 = 2 * ps, 2 * pt + + res_nx = ot.gromov.div_between_product(ps, pt, ps1, pt1, divergence, nx=nx) + res = ot.gromov.div_between_product(ps, pt, ps1, pt1, divergence, nx=None) + + np.testing.assert_allclose(res_nx, res, atol=1e-06) + + +@pytest.mark.parametrize("divergence, mass", itertools.product(["kl", "l2"], [True, False])) +def test_div_to_product(nx, divergence, mass): + ns = 5 + nt = 10 + + a, b = ot.unif(ns), ot.unif(nt) + a, b = nx.from_numpy(a), nx.from_numpy(b) + + pi = 2 * a[:, None] * b[None, :] + pi1, pi2 = nx.sum(pi, 1), nx.sum(pi, 0) + + res = ot.gromov.div_to_product(pi, a, b, pi1=None, pi2=None, divergence=divergence, mass=mass, nx=None) + res1 = ot.gromov.div_to_product(pi, a, b, pi1=None, pi2=None, divergence=divergence, mass=mass, nx=nx) + res2 = ot.gromov.div_to_product(pi, a, b, pi1=pi1, pi2=pi2, divergence=divergence, mass=mass, nx=None) + res3 = ot.gromov.div_to_product(pi, a, b, pi1=pi1, pi2=pi2, divergence=divergence, mass=mass, nx=nx) + + np.testing.assert_allclose(res1, res, atol=1e-06) + np.testing.assert_allclose(res2, res, atol=1e-06) + np.testing.assert_allclose(res3, res, atol=1e-06) diff --git a/test/test_solvers.py b/test/test_solvers.py index 16e6df295..61dda87a7 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -59,10 +59,15 @@ def assert_allclose_sol(sol1, sol2): nx2 = sol2._backend if sol2._backend is not None else ot.backend.NumpyBackend() for attr in lst_attr: - try: - np.allclose(nx1.to_numpy(getattr(sol1, attr)), nx2.to_numpy(getattr(sol2, attr))) - except NotImplementedError: - pass + if getattr(sol1, attr) is not None and getattr(sol2, attr) is not None: + try: + np.allclose(nx1.to_numpy(getattr(sol1, attr)), nx2.to_numpy(getattr(sol2, attr)), equal_nan=True) + except NotImplementedError: + pass + elif getattr(sol1, attr) is None and getattr(sol2, attr) is None: + return True + else: + return False def test_solve(nx): @@ -77,14 +82,15 @@ def test_solve(nx): b = ot.utils.unif(n_samples_t) M = ot.dist(x, y) + reg = 1e-1 # solve unif weights - sol0 = ot.solve(M) + sol0 = ot.solve(M, reg=reg) print(sol0) # solve signe weights - sol = ot.solve(M, a, b) + sol = ot.solve(M, a, b, reg=reg) # check some attributes sol.potentials @@ -92,7 +98,8 @@ def test_solve(nx): sol.marginals sol.status - assert_allclose_sol(sol0, sol) + # print("dual = {}".format(sol.potentials)) + # assert_allclose_sol(sol0, sol) # solve in backend ab, bb, Mb = nx.from_numpy(a, b, M) diff --git a/test/test_ucoot.py b/test/test_ucoot.py new file mode 100644 index 000000000..fcace4178 --- /dev/null +++ b/test/test_ucoot.py @@ -0,0 +1,795 @@ +"""Tests for module Unbalanced Co-Optimal Transport""" + +# Author: Quang Huy Tran +# +# License: MIT License + + +import itertools +import numpy as np +import ot +import pytest +from ot.gromov._unbalanced import unbalanced_co_optimal_transport, unbalanced_co_optimal_transport2 + + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +@pytest.skip_backend("tf", reason="test very slow with tensorflow backend") +@pytest.mark.parametrize("unbalanced_solver, divergence", itertools.product(["mm", "lbfgsb"], ["kl", "l2"])) +def test_sanity(nx, unbalanced_solver, divergence): + n_samples = 5 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss( + n_samples, mu_s, cov_s, random_state=4) + xt = xs[::-1].copy() + + px_s, px_f = ot.unif(n_samples), ot.unif(2) + py_s, py_f = ot.unif(n_samples), ot.unif(2) + + xs_nx, xt_nx = nx.from_numpy(xs, xt) + px_s_nx, px_f_nx, py_s_nx, py_f_nx = nx.from_numpy(px_s, px_f, py_s, py_f) + + reg_m = (10, 5) + eps = 0 + max_iter_ot = 200 + max_iter = 200 + tol = 1e-7 + tol_ot = 1e-7 + + # test couplings + anti_id_sample = np.flipud(np.eye(n_samples, n_samples)) / n_samples + id_feature = np.eye(2, 2) / 2 + + pi_sample, pi_feature = unbalanced_co_optimal_transport( + X=xs, Y=xt, wx_samp=px_s, wx_feat=px_f, wy_samp=py_s, wy_feat=py_f, + reg_marginals=reg_m, epsilon=eps, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=0, + M_samp=None, M_feat=None, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + pi_sample_nx, pi_feature_nx = unbalanced_co_optimal_transport( + X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, + reg_marginals=reg_m, epsilon=eps, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=0, + M_samp=None, M_feat=None, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + pi_sample_nx = nx.to_numpy(pi_sample_nx) + pi_feature_nx = nx.to_numpy(pi_feature_nx) + + np.testing.assert_allclose(pi_sample, anti_id_sample, atol=1e-05) + np.testing.assert_allclose(pi_sample_nx, pi_sample, atol=1e-06) + np.testing.assert_allclose(pi_feature, pi_feature_nx, atol=1e-06) + np.testing.assert_allclose(pi_feature, id_feature, atol=1e-05) + + # test divergence + ucoot = unbalanced_co_optimal_transport2( + X=xs, Y=xt, wx_samp=px_s, wx_feat=px_f, wy_samp=py_s, wy_feat=py_f, + reg_marginals=reg_m, epsilon=eps, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=0, + M_samp=None, M_feat=None, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + ucoot_nx = unbalanced_co_optimal_transport2( + X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, + reg_marginals=reg_m, epsilon=eps, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=0, + M_samp=None, M_feat=None, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + ucoot_nx = nx.to_numpy(ucoot_nx) + np.testing.assert_allclose(ucoot, ucoot_nx, atol=1e-08) + np.testing.assert_allclose(ucoot, 0, atol=1e-06) + + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +@pytest.skip_backend("tf", reason="test very slow with tensorflow backend") +@pytest.mark.parametrize("unbalanced_solver, divergence, eps", itertools.product(["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1])) +def test_init_plans(nx, unbalanced_solver, divergence, eps): + n_samples = 5 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss( + n_samples, mu_s, cov_s, random_state=4) + xt = xs[::-1].copy() + + px_s, px_f = ot.unif(n_samples), ot.unif(2) + py_s, py_f = ot.unif(n_samples), ot.unif(2) + G0_samp = px_s[:, None] * py_s[None, :] + G0_feat = px_f[:, None] * py_f[None, :] + + xs_nx, xt_nx, G0_samp_nx, G0_feat_nx = nx.from_numpy(xs, xt, G0_samp, G0_feat) + px_s_nx, px_f_nx, py_s_nx, py_f_nx = nx.from_numpy(px_s, px_f, py_s, py_f) + + reg_m = (1, 5) + alpha = (0.1, 0.2) + max_iter_ot = 5 + max_iter = 5 + tol = 1e-7 + tol_ot = 1e-7 + + # test couplings + pi_sample, pi_feature = unbalanced_co_optimal_transport( + X=xs, Y=xt, wx_samp=px_s, wx_feat=px_f, wy_samp=py_s, wy_feat=py_f, + reg_marginals=reg_m, epsilon=eps, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=alpha, + M_samp=None, M_feat=None, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + pi_sample_nx, pi_feature_nx = unbalanced_co_optimal_transport( + X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, + reg_marginals=reg_m, epsilon=eps, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=alpha, + M_samp=None, M_feat=None, init_pi=(G0_samp_nx, G0_feat_nx), init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + pi_sample_nx = nx.to_numpy(pi_sample_nx) + pi_feature_nx = nx.to_numpy(pi_feature_nx) + + np.testing.assert_allclose(pi_sample_nx, pi_sample, atol=1e-03) + np.testing.assert_allclose(pi_feature, pi_feature_nx, atol=1e-03) + + # test divergence + ucoot = unbalanced_co_optimal_transport2( + X=xs, Y=xt, wx_samp=px_s, wx_feat=px_f, wy_samp=py_s, wy_feat=py_f, + reg_marginals=reg_m, epsilon=eps, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=alpha, + M_samp=None, M_feat=None, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + ucoot_nx = unbalanced_co_optimal_transport2( + X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, + reg_marginals=reg_m, epsilon=eps, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=alpha, + M_samp=None, M_feat=None, init_pi=(G0_samp_nx, G0_feat_nx), init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + ucoot_nx = nx.to_numpy(ucoot_nx) + np.testing.assert_allclose(ucoot, ucoot_nx, atol=1e-08) + + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +@pytest.skip_backend("tf", reason="test very slow with tensorflow backend") +@pytest.mark.parametrize("unbalanced_solver, divergence, eps", itertools.product(["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1])) +def test_init_duals(nx, unbalanced_solver, divergence, eps): + n_samples = 5 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss( + n_samples, mu_s, cov_s, random_state=4) + xt = xs[::-1].copy() + + px_s, px_f = ot.unif(n_samples), ot.unif(2) + py_s, py_f = ot.unif(n_samples), ot.unif(2) + + xs_nx, xt_nx = nx.from_numpy(xs, xt) + px_s_nx, px_f_nx, py_s_nx, py_f_nx = nx.from_numpy(px_s, px_f, py_s, py_f) + + init_duals_samp = nx.from_numpy(np.zeros(n_samples), np.zeros(n_samples)) + init_duals_feat = nx.from_numpy(np.zeros(2), np.zeros(2)) + init_duals = (init_duals_samp, init_duals_feat) + + reg_m = (10, 5) + alpha = (0.1, 0.2) + max_iter_ot = 5 + max_iter = 5 + tol = 1e-7 + tol_ot = 1e-7 + + # test couplings + pi_sample, pi_feature = unbalanced_co_optimal_transport( + X=xs, Y=xt, wx_samp=px_s, wx_feat=px_f, wy_samp=py_s, wy_feat=py_f, + reg_marginals=reg_m, epsilon=eps, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=alpha, + M_samp=None, M_feat=None, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + pi_sample_nx, pi_feature_nx = unbalanced_co_optimal_transport( + X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, + reg_marginals=reg_m, epsilon=eps, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=alpha, + M_samp=None, M_feat=None, init_pi=None, init_duals=init_duals, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + pi_sample_nx = nx.to_numpy(pi_sample_nx) + pi_feature_nx = nx.to_numpy(pi_feature_nx) + + np.testing.assert_allclose(pi_sample_nx, pi_sample, atol=1e-03) + np.testing.assert_allclose(pi_feature, pi_feature_nx, atol=1e-03) + + # test divergence + ucoot = unbalanced_co_optimal_transport2( + X=xs, Y=xt, wx_samp=px_s, wx_feat=px_f, wy_samp=py_s, wy_feat=py_f, + reg_marginals=reg_m, epsilon=eps, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=alpha, + M_samp=None, M_feat=None, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + ucoot_nx = unbalanced_co_optimal_transport2( + X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, + reg_marginals=reg_m, epsilon=eps, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=alpha, + M_samp=None, M_feat=None, init_pi=None, init_duals=init_duals, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + ucoot_nx = nx.to_numpy(ucoot_nx) + np.testing.assert_allclose(ucoot, ucoot_nx, atol=1e-08) + + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +@pytest.skip_backend("tf", reason="test very slow with tensorflow backend") +@pytest.mark.parametrize("unbalanced_solver, divergence, eps", itertools.product(["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1e-2])) +def test_linear_part(nx, unbalanced_solver, divergence, eps): + n_samples = 5 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss( + n_samples, mu_s, cov_s, random_state=4) + xt = xs[::-1].copy() + + px_s, px_f = ot.unif(n_samples), ot.unif(2) + py_s, py_f = ot.unif(n_samples), ot.unif(2) + + xs_nx, xt_nx = nx.from_numpy(xs, xt) + px_s_nx, px_f_nx, py_s_nx, py_f_nx = nx.from_numpy(px_s, px_f, py_s, py_f) + + # linear part + M_samp = np.ones((n_samples, n_samples)) + np.fill_diagonal(np.fliplr(M_samp), 0) + M_feat = np.ones((2, 2)) + np.fill_diagonal(M_feat, 0) + M_samp_nx, M_feat_nx = nx.from_numpy(M_samp, M_feat) + + reg_m = (10, 5) + alpha = (0.1, 0.2) + max_iter_ot = 5 + max_iter = 5 + tol = 1e-7 + tol_ot = 1e-7 + + # test couplings + pi_sample, pi_feature = unbalanced_co_optimal_transport( + X=xs, Y=xt, wx_samp=px_s, wx_feat=px_f, wy_samp=py_s, wy_feat=py_f, + reg_marginals=reg_m, epsilon=eps, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=alpha, + M_samp=M_samp, M_feat=M_feat, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + pi_sample_nx, pi_feature_nx = unbalanced_co_optimal_transport( + X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, + reg_marginals=reg_m, epsilon=eps, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=alpha, + M_samp=M_samp_nx, M_feat=M_feat_nx, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + pi_sample_nx = nx.to_numpy(pi_sample_nx) + pi_feature_nx = nx.to_numpy(pi_feature_nx) + + np.testing.assert_allclose(pi_sample_nx, pi_sample, atol=1e-06) + np.testing.assert_allclose(pi_feature, pi_feature_nx, atol=1e-06) + + # test divergence + ucoot = unbalanced_co_optimal_transport2( + X=xs, Y=xt, wx_samp=px_s, wx_feat=px_f, wy_samp=py_s, wy_feat=py_f, + reg_marginals=reg_m, epsilon=eps, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=alpha, + M_samp=M_samp, M_feat=M_feat, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + ucoot_nx = unbalanced_co_optimal_transport2( + X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, + reg_marginals=reg_m, epsilon=eps, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=alpha, + M_samp=M_samp_nx, M_feat=M_feat_nx, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + ucoot_nx = nx.to_numpy(ucoot_nx) + np.testing.assert_allclose(ucoot, ucoot_nx, atol=1e-08) + + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +@pytest.skip_backend("tf", reason="test very slow with tensorflow backend") +@pytest.mark.parametrize("unbalanced_solver, divergence, eps", itertools.product(["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1])) +def test_reg_marginals(nx, unbalanced_solver, divergence, eps): + n_samples = 5 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss( + n_samples, mu_s, cov_s, random_state=4) + xt = xs[::-1].copy() + + px_s, px_f = ot.unif(n_samples), ot.unif(2) + py_s, py_f = ot.unif(n_samples), ot.unif(2) + + xs_nx, xt_nx = nx.from_numpy(xs, xt) + px_s_nx, px_f_nx, py_s_nx, py_f_nx = nx.from_numpy(px_s, px_f, py_s, py_f) + + alpha = (0.1, 0.2) + max_iter_ot = 5 + max_iter = 5 + tol = 1e-7 + tol_ot = 1e-7 + + reg_m = 100 + full_list_reg_m = [reg_m, reg_m] + full_tuple_reg_m = (reg_m, reg_m) + tuple_reg_m, list_reg_m = (reg_m), [reg_m] + + list_options = [full_tuple_reg_m, tuple_reg_m, full_list_reg_m, list_reg_m] + + # test couplings + pi_sample, pi_feature = unbalanced_co_optimal_transport( + X=xs, Y=xt, wx_samp=px_s, wx_feat=px_f, wy_samp=py_s, wy_feat=py_f, + reg_marginals=reg_m, epsilon=eps, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=alpha, + M_samp=None, M_feat=None, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + # test divergence + ucoot = unbalanced_co_optimal_transport2( + X=xs, Y=xt, wx_samp=px_s, wx_feat=px_f, wy_samp=py_s, wy_feat=py_f, + reg_marginals=reg_m, epsilon=eps, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=alpha, + M_samp=None, M_feat=None, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + for opt in list_options: + + pi_sample_nx, pi_feature_nx = unbalanced_co_optimal_transport( + X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, + reg_marginals=opt, epsilon=eps, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=alpha, + M_samp=None, M_feat=None, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + pi_sample_nx = nx.to_numpy(pi_sample_nx) + pi_feature_nx = nx.to_numpy(pi_feature_nx) + + np.testing.assert_allclose(pi_sample_nx, pi_sample, atol=1e-06) + np.testing.assert_allclose(pi_feature, pi_feature_nx, atol=1e-06) + + # test divergence + ucoot_nx = unbalanced_co_optimal_transport2( + X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, + reg_marginals=opt, epsilon=eps, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=alpha, + M_samp=None, M_feat=None, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + method_sinkhorn="sinkhorn", log=False, verbose=False + ) + + ucoot_nx = nx.to_numpy(ucoot_nx) + np.testing.assert_allclose(ucoot, ucoot_nx, atol=1e-08) + + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +@pytest.skip_backend("tf", reason="test very slow with tensorflow backend") +@pytest.mark.parametrize("unbalanced_solver, divergence, alpha", itertools.product(["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1])) +def test_eps(nx, unbalanced_solver, divergence, alpha): + n_samples = 5 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss( + n_samples, mu_s, cov_s, random_state=4) + xt = xs[::-1].copy() + + px_s, px_f = ot.unif(n_samples), ot.unif(2) + py_s, py_f = ot.unif(n_samples), ot.unif(2) + + xs_nx, xt_nx = nx.from_numpy(xs, xt) + px_s_nx, px_f_nx, py_s_nx, py_f_nx = nx.from_numpy(px_s, px_f, py_s, py_f) + + reg_m = (10, 5) + alpha = (0.1, 0.2) + max_iter_ot = 5 + max_iter = 5 + tol = 1e-7 + tol_ot = 1e-7 + + eps = 1 + full_list_eps = [eps, eps] + full_tuple_eps = (eps, eps) + tuple_eps, list_eps = (eps), [eps] + + list_options = [full_list_eps, full_tuple_eps, tuple_eps, list_eps] + + # test couplings + pi_sample, pi_feature = unbalanced_co_optimal_transport( + X=xs, Y=xt, wx_samp=px_s, wx_feat=px_f, wy_samp=py_s, wy_feat=py_f, + reg_marginals=reg_m, epsilon=eps, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=alpha, + M_samp=None, M_feat=None, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + # test divergence + ucoot = unbalanced_co_optimal_transport2( + X=xs, Y=xt, wx_samp=px_s, wx_feat=px_f, wy_samp=py_s, wy_feat=py_f, + reg_marginals=reg_m, epsilon=eps, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=alpha, + M_samp=None, M_feat=None, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + for opt in list_options: + + pi_sample_nx, pi_feature_nx = unbalanced_co_optimal_transport( + X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, + reg_marginals=reg_m, epsilon=opt, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=alpha, + M_samp=None, M_feat=None, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + pi_sample_nx = nx.to_numpy(pi_sample_nx) + pi_feature_nx = nx.to_numpy(pi_feature_nx) + + np.testing.assert_allclose(pi_sample_nx, pi_sample, atol=1e-06) + np.testing.assert_allclose(pi_feature, pi_feature_nx, atol=1e-06) + + # test divergence + ucoot_nx = unbalanced_co_optimal_transport2( + X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, + reg_marginals=reg_m, epsilon=opt, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=alpha, + M_samp=None, M_feat=None, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + ucoot_nx = nx.to_numpy(ucoot_nx) + np.testing.assert_allclose(ucoot, ucoot_nx, atol=1e-08) + + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +@pytest.skip_backend("tf", reason="test very slow with tensorflow backend") +@pytest.mark.parametrize("unbalanced_solver, divergence, eps", itertools.product(["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1e-2])) +def test_alpha(nx, unbalanced_solver, divergence, eps): + n_samples = 5 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss( + n_samples, mu_s, cov_s, random_state=4) + xt = xs[::-1].copy() + + px_s, px_f = ot.unif(n_samples), ot.unif(2) + py_s, py_f = ot.unif(n_samples), ot.unif(2) + + xs_nx, xt_nx = nx.from_numpy(xs, xt) + px_s_nx, px_f_nx, py_s_nx, py_f_nx = nx.from_numpy(px_s, px_f, py_s, py_f) + + # linear part + M_samp = np.ones((n_samples, n_samples)) + np.fill_diagonal(np.fliplr(M_samp), 0) + M_feat = np.ones((2, 2)) + np.fill_diagonal(M_feat, 0) + M_samp_nx, M_feat_nx = nx.from_numpy(M_samp, M_feat) + + reg_m = (10, 5) + max_iter_ot = 5 + max_iter = 5 + tol = 1e-7 + tol_ot = 1e-7 + + alpha = 1 + full_list_alpha = [alpha, alpha] + full_tuple_alpha = (alpha, alpha) + tuple_alpha, list_alpha = (alpha), [alpha] + + list_options = [full_list_alpha, full_tuple_alpha, tuple_alpha, list_alpha] + + # test couplings + pi_sample, pi_feature = unbalanced_co_optimal_transport( + X=xs, Y=xt, wx_samp=px_s, wx_feat=px_f, wy_samp=py_s, wy_feat=py_f, + reg_marginals=reg_m, epsilon=eps, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=alpha, + M_samp=M_samp, M_feat=M_feat, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + # test divergence + ucoot = unbalanced_co_optimal_transport2( + X=xs, Y=xt, wx_samp=px_s, wx_feat=px_f, wy_samp=py_s, wy_feat=py_f, + reg_marginals=reg_m, epsilon=eps, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=alpha, + M_samp=M_samp, M_feat=M_feat, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + for opt in list_options: + pi_sample_nx, pi_feature_nx = unbalanced_co_optimal_transport( + X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, + reg_marginals=reg_m, epsilon=eps, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=opt, + M_samp=M_samp_nx, M_feat=M_feat_nx, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + pi_sample_nx = nx.to_numpy(pi_sample_nx) + pi_feature_nx = nx.to_numpy(pi_feature_nx) + + np.testing.assert_allclose(pi_sample_nx, pi_sample, atol=1e-06) + np.testing.assert_allclose(pi_feature, pi_feature_nx, atol=1e-06) + + ucoot_nx = unbalanced_co_optimal_transport2( + X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, + reg_marginals=reg_m, epsilon=eps, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=opt, + M_samp=M_samp_nx, M_feat=M_feat_nx, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + ucoot_nx = nx.to_numpy(ucoot_nx) + np.testing.assert_allclose(ucoot, ucoot_nx, atol=1e-08) + + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +@pytest.skip_backend("tf", reason="test very slow with tensorflow backend") +@pytest.mark.parametrize("unbalanced_solver, divergence, eps", itertools.product(["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1])) +def test_log(nx, unbalanced_solver, divergence, eps): + n_samples = 5 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss( + n_samples, mu_s, cov_s, random_state=4) + xt = xs[::-1].copy() + + px_s, px_f = ot.unif(n_samples), ot.unif(2) + py_s, py_f = ot.unif(n_samples), ot.unif(2) + + xs_nx, xt_nx = nx.from_numpy(xs, xt) + px_s_nx, px_f_nx, py_s_nx, py_f_nx = nx.from_numpy(px_s, px_f, py_s, py_f) + + reg_m = (10, 5) + alpha = (0.1, 0.2) + max_iter_ot = 5 + max_iter = 5 + tol = 1e-7 + tol_ot = 1e-7 + + # test couplings + pi_sample, pi_feature = unbalanced_co_optimal_transport( + X=xs, Y=xt, wx_samp=px_s, wx_feat=px_f, wy_samp=py_s, wy_feat=py_f, + reg_marginals=reg_m, epsilon=eps, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=alpha, + M_samp=None, M_feat=None, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + pi_sample_nx, pi_feature_nx, log = unbalanced_co_optimal_transport( + X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, + reg_marginals=reg_m, epsilon=eps, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=alpha, + M_samp=None, M_feat=None, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=True, verbose=False + ) + pi_sample_nx = nx.to_numpy(pi_sample_nx) + pi_feature_nx = nx.to_numpy(pi_feature_nx) + + np.testing.assert_allclose(pi_sample_nx, pi_sample, atol=1e-06) + np.testing.assert_allclose(pi_feature, pi_feature_nx, atol=1e-06) + + # test divergence + ucoot = unbalanced_co_optimal_transport2( + X=xs, Y=xt, wx_samp=px_s, wx_feat=px_f, wy_samp=py_s, wy_feat=py_f, + reg_marginals=reg_m, epsilon=eps, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=alpha, + M_samp=None, M_feat=None, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + ucoot_nx = unbalanced_co_optimal_transport2( + X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, + reg_marginals=reg_m, epsilon=eps, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=alpha, + M_samp=None, M_feat=None, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + ucoot_nx = nx.to_numpy(ucoot_nx) + np.testing.assert_allclose(ucoot, ucoot_nx, atol=1e-08) + + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +@pytest.skip_backend("tf", reason="test very slow with tensorflow backend") +@pytest.mark.parametrize("unbalanced_solver, divergence, eps", itertools.product(["sinkhorn", "sinkhorn_log", "mm", "lbfgsb"], ["kl", "l2"], [0, 1])) +def test_marginals(nx, unbalanced_solver, divergence, eps): + n_samples = 5 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss( + n_samples, mu_s, cov_s, random_state=4) + xt = xs[::-1].copy() + + px_s, px_f = ot.unif(n_samples), ot.unif(2) + py_s, py_f = ot.unif(n_samples), ot.unif(2) + + xs_nx, xt_nx = nx.from_numpy(xs, xt) + px_s_nx, px_f_nx, py_s_nx, py_f_nx = nx.from_numpy(px_s, px_f, py_s, py_f) + + reg_m = (10, 5) + alpha = (0.1, 0.2) + max_iter_ot = 5 + max_iter = 5 + tol = 1e-7 + tol_ot = 1e-7 + + # test couplings + pi_sample, pi_feature = unbalanced_co_optimal_transport( + X=xs, Y=xt, wx_samp=None, wx_feat=None, wy_samp=None, wy_feat=None, + reg_marginals=reg_m, epsilon=eps, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=alpha, + M_samp=None, M_feat=None, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + pi_sample_nx, pi_feature_nx = unbalanced_co_optimal_transport( + X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, + reg_marginals=reg_m, epsilon=eps, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=alpha, + M_samp=None, M_feat=None, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + pi_sample_nx = nx.to_numpy(pi_sample_nx) + pi_feature_nx = nx.to_numpy(pi_feature_nx) + + np.testing.assert_allclose(pi_sample_nx, pi_sample, atol=1e-06) + np.testing.assert_allclose(pi_feature, pi_feature_nx, atol=1e-06) + + # test divergence + ucoot = unbalanced_co_optimal_transport2( + X=xs, Y=xt, wx_samp=None, wx_feat=None, wy_samp=None, wy_feat=None, + reg_marginals=reg_m, epsilon=eps, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=alpha, + M_samp=None, M_feat=None, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + ucoot_nx = unbalanced_co_optimal_transport2( + X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, wy_feat=py_f_nx, + reg_marginals=reg_m, epsilon=eps, divergence=divergence, + unbalanced_solver=unbalanced_solver, alpha=alpha, + M_samp=None, M_feat=None, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + ucoot_nx = nx.to_numpy(ucoot_nx) + np.testing.assert_allclose(ucoot, ucoot_nx, atol=1e-08) + + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +@pytest.skip_backend("tf", reason="test very slow with tensorflow backend") +def test_raise_value_error(nx): + n_samples = 5 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss( + n_samples, mu_s, cov_s, random_state=4) + xt = xs[::-1].copy() + + px_s, px_f = ot.unif(n_samples), ot.unif(2) + py_s, py_f = ot.unif(n_samples), ot.unif(2) + + xs_nx, xt_nx = nx.from_numpy(xs, xt) + px_s_nx, px_f_nx, py_s_nx, py_f_nx = nx.from_numpy(px_s, px_f, py_s, py_f) + + reg_m = (10, 5) + eps = 0 + max_iter_ot = 5 + max_iter = 5 + tol = 1e-7 + tol_ot = 1e-7 + + # raise error of divergence + def ucoot_div(divergence): + return unbalanced_co_optimal_transport( + X=xs, Y=xt, wx_samp=px_s, wx_feat=px_f, wy_samp=py_s, wy_feat=py_f, + reg_marginals=reg_m, epsilon=eps, divergence=divergence, + unbalanced_solver="mm", alpha=0, + M_samp=None, M_feat=None, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + def ucoot_div_nx(divergence): + return unbalanced_co_optimal_transport( + X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, + wy_feat=py_f_nx, reg_marginals=reg_m, epsilon=eps, + divergence=divergence, unbalanced_solver="mm", alpha=0, + M_samp=None, M_feat=None, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + np.testing.assert_raises(NotImplementedError, ucoot_div, "div_not_existed") + np.testing.assert_raises(NotImplementedError, ucoot_div_nx, "div_not_existed") + + # raise error of solver + def ucoot_solver(unbalanced_solver): + return unbalanced_co_optimal_transport( + X=xs, Y=xt, wx_samp=px_s, wx_feat=px_f, wy_samp=py_s, wy_feat=py_f, + reg_marginals=reg_m, epsilon=eps, divergence="kl", + unbalanced_solver=unbalanced_solver, alpha=0, + M_samp=None, M_feat=None, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + def ucoot_solver_nx(unbalanced_solver): + return unbalanced_co_optimal_transport( + X=xs_nx, Y=xt_nx, wx_samp=px_s_nx, wx_feat=px_f_nx, wy_samp=py_s_nx, + wy_feat=py_f_nx, reg_marginals=reg_m, epsilon=eps, + divergence="kl", unbalanced_solver=unbalanced_solver, alpha=0, + M_samp=None, M_feat=None, init_pi=None, init_duals=None, + max_iter=max_iter, tol=tol, max_iter_ot=max_iter_ot, tol_ot=tol_ot, + log=False, verbose=False + ) + + np.testing.assert_raises(NotImplementedError, ucoot_solver, "solver_not_existed") + np.testing.assert_raises(NotImplementedError, ucoot_solver_nx, "solver_not_existed") diff --git a/test/test_utils.py b/test/test_utils.py index 0801337cb..82f514574 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -419,9 +419,7 @@ def test_OTResult(): # tets get citation print(res.citation) - lst_attributes = ['a_to_b', - 'b_to_a', - 'lazy_plan', + lst_attributes = ['lazy_plan', 'marginal_a', 'marginal_b', 'marginals', @@ -436,6 +434,11 @@ def test_OTResult(): 'value_quad', 'log'] for at in lst_attributes: + print(at) + assert getattr(res, at) is None + + list_not_implemented = ['a_to_b', 'b_to_a'] + for at in list_not_implemented: print(at) with pytest.raises(NotImplementedError): getattr(res, at)