[DA] Sinkhorn L1L2 transport to work on JAX (#587)
* Draft sinkhorn_l1l2_transport to work on JAX

* Move label_to_masks in utils

* Move nan_to_num to backend

* Proper test case for semi-supervised DA

* Test case for label to mask computation

* Simplified axis operations for labels

* Allow JAX backend for BaseEstimator

* Label normalization performs copy only when necessary

* Fix comment regarding label transformation

* Update RELEASES

* Additional backend tests for nan_to_num

* min(unique(y)) === min(y)

* Avoid catching all warnings as JAX throws deprecation

* No need to import warnings module

---------

Co-authored-by: Rémi Flamary <[email protected]>
kachayev and rflamary authored Dec 22, 2023
1 parent acd84ed commit 9ddb690
Showing 7 changed files with 131 additions and 70 deletions.
5 changes: 5 additions & 0 deletions RELEASES.md
@@ -1,5 +1,10 @@
# Releases

## Next Release

#### New features
+ Domain adaptation method `SinkhornL1l2Transport` now supports JAX backend (PR #587)

## 0.9.2dev

#### New features
27 changes: 27 additions & 0 deletions ot/backend.py
@@ -1043,6 +1043,14 @@ def matmul(self, a, b):
"""
raise NotImplementedError()

def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None):
r"""
Replace NaN with zero and infinity with large finite numbers or with the numbers defined by the user.
See: https://numpy.org/doc/stable/reference/generated/numpy.nan_to_num.html#numpy.nan_to_num
"""
raise NotImplementedError()


class NumpyBackend(Backend):
"""
@@ -1392,6 +1400,9 @@ def detach(self, *args):
def matmul(self, a, b):
return np.matmul(a, b)

def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None):
return np.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf)


_register_backend_implementation(NumpyBackend)

@@ -1762,6 +1773,9 @@ def detach(self, *args):
def matmul(self, a, b):
return jnp.matmul(a, b)

def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None):
return jnp.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf)


if jax:
# Only register jax backend if it is installed
@@ -2250,6 +2264,10 @@ def detach(self, *args):
def matmul(self, a, b):
return torch.matmul(a, b)

def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None):
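# torch.nan_to_num has no copy flag: copy=False is emulated by writing into x via out=x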
out = None if copy else x
return torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf, out=out)


if torch:
# Only register torch backend if it is installed
@@ -2647,6 +2665,9 @@ def detach(self, *args):
def matmul(self, a, b):
return cp.matmul(a, b)

def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None):
return cp.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf)


if cp:
# Only register cp backend if it is installed
@@ -3070,6 +3091,12 @@ def detach(self, *args):
def matmul(self, a, b):
return tnp.matmul(a, b)

# todo(okachaiev): replace this with a more reasonable implementation
def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None):
x = self.to_numpy(x)
x = np.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf)
return self.from_numpy(x)


if tf:
# Only register tensorflow backend if it is installed
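For orientation, a minimal sketch of how the new primitive is used from user code; the values below are illustrative, and `get_backend` simply dispatches on the input array type, so the same call works on any registered backend:

```python
import numpy as np

from ot.backend import get_backend

# Illustrative input containing the three values nan_to_num handles.
a = np.array([1.0, np.nan, np.inf, -np.inf])
nx = get_backend(a)  # resolves to NumpyBackend here

# Mirrors numpy.nan_to_num semantics on every backend that implements it.
cleaned = nx.nan_to_num(a, nan=0.0, posinf=0.0, neginf=0.0)
print(cleaned)  # [1. 0. 0. 0.]
```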
73 changes: 24 additions & 49 deletions ot/da.py
Expand Up @@ -18,7 +18,7 @@
from .bregman import sinkhorn, jcpot_barycenter
from .lp import emd
from .utils import unif, dist, kernel, cost_normalization, label_normalization, laplacian, dots
from .utils import list_to_array, check_params, BaseEstimator, deprecated
from .utils import BaseEstimator, check_params, deprecated, labels_to_masks, list_to_array
from .unbalanced import sinkhorn_unbalanced
from .gaussian import empirical_bures_wasserstein_mapping, empirical_gaussian_gromov_wasserstein_mapping
from .optim import cg
@@ -499,18 +499,12 @@ class label
if self.limit_max != np.infty:
self.limit_max = self.limit_max * nx.max(self.cost_)

# assumes labeled source samples occupy the first rows
# and labeled target samples occupy the first columns
classes = [c for c in nx.unique(ys) if c != -1]
for c in classes:
idx_s = nx.where((ys != c) & (ys != -1))
idx_t = nx.where(yt == c)

# all the coefficients corresponding to a source sample
# and a target sample with different labels get an infinite cost
for j in idx_t[0]:
self.cost_[idx_s[0], j] = self.limit_max
# zeros where source label is missing (masked with -1)
missing_labels = ys + nx.ones(ys.shape, type_as=ys)
missing_labels = nx.repeat(missing_labels[:, None], yt.shape[0], 1)
# zeros where labels match
label_match = ys[:, None] - yt[None, :]
self.cost_ = nx.maximum(self.cost_, nx.abs(label_match) * nx.abs(missing_labels) * self.limit_max)

# distribution estimation
self.mu_s = self.distribution_estimation(Xs)
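The vectorized update above reproduces the removed double loop: any coefficient pairing a labeled source sample with a target sample of a different class is raised to at least `limit_max`, while rows with a missing source label (`ys == -1`) are left untouched because `missing_labels` vanishes there. A minimal NumPy sketch of that behavior, with made-up labels and a stand-in value for the scaled `limit_max`:

```python
import numpy as np

ys = np.array([0, 1, -1])   # third source sample is unlabeled
yt = np.array([0, 1])
cost = np.ones((3, 2))
limit_max = 100.0           # stands in for the scaled attribute

missing = np.repeat((ys + 1)[:, None], yt.shape[0], axis=1)  # 0 where ys == -1
match = ys[:, None] - yt[None, :]                            # 0 where labels agree
cost_new = np.maximum(cost, np.abs(match) * np.abs(missing) * limit_max)

# Forbidden pairs (labeled source, different target class) are >= limit_max;
# matching pairs and unlabeled-source rows keep their original cost.
assert np.all(cost_new[[0, 1], [1, 0]] >= limit_max)
assert np.all(cost_new[2] == 1.0) and cost_new[0, 0] == 1.0
```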
@@ -581,12 +575,11 @@ class label
if check_params(Xs=Xs):

if nx.array_equal(self.xs_, Xs):

# perform standard barycentric mapping
transp = self.coupling_ / nx.sum(self.coupling_, axis=1)[:, None]

# set nans to 0
transp[~ nx.isfinite(transp)] = 0
transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)

# compute transported samples
transp_Xs = nx.dot(transp, self.xt_)
@@ -604,9 +597,8 @@ class label
idx = nx.argmin(D0, axis=1)

# transport the source samples
transp = self.coupling_ / nx.sum(
self.coupling_, axis=1)[:, None]
transp[~ nx.isfinite(transp)] = 0
transp = self.coupling_ / nx.sum(self.coupling_, axis=1)[:, None]
transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)
transp_Xs_ = nx.dot(transp, self.xt_)

# define the transported points
@@ -645,23 +637,16 @@ def transform_labels(self, ys=None):

# check the necessary inputs parameters are here
if check_params(ys=ys):

ysTemp = label_normalization(nx.copy(ys))
classes = nx.unique(ysTemp)
n = len(classes)
D1 = nx.zeros((n, len(ysTemp)), type_as=self.coupling_)

# perform label propagation
transp = self.coupling_ / nx.sum(self.coupling_, axis=0)[None, :]

# set nans to 0
transp[~ nx.isfinite(transp)] = 0

for c in classes:
D1[int(c), ysTemp == c] = 1
transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)

# compute propagated labels
transp_ys = nx.dot(D1, transp)
labels = label_normalization(ys)
masks = labels_to_masks(labels, nx=nx, type_as=transp)
transp_ys = nx.dot(masks.T, transp)

return transp_ys.T
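The mask-based rewrite is a drop-in replacement for the removed `D1` loop: the transpose of the one-hot matrix built by `labels_to_masks` is exactly the old `D1`, so the propagated label scores are unchanged. A self-contained sketch with made-up labels and a column-normalized coupling:

```python
import numpy as np

ys = np.array([0, 1, 1])                 # normalized source labels
transp = np.array([[0.5, 0.0],
                   [0.5, 0.4],
                   [0.0, 0.6]])          # coupling, columns sum to 1

masks = np.eye(2)[ys]                    # one-hot, shape (n_samples, n_labels)

# Old construction: D1[c, i] = 1 iff ys[i] == c, i.e. D1 == masks.T
D1 = np.zeros((2, 3))
for c in (0, 1):
    D1[c, ys == c] = 1
assert np.array_equal(D1, masks.T)

transp_ys = masks.T @ transp             # (n_labels, n_targets) label scores
print(transp_ys.T)                       # one score vector per target sample
```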

@@ -697,12 +682,11 @@ class label
if check_params(Xt=Xt):

if nx.array_equal(self.xt_, Xt):

# perform standard barycentric mapping
transp_ = self.coupling_.T / nx.sum(self.coupling_, 0)[:, None]

# set nans to 0
transp_[~ nx.isfinite(transp_)] = 0
transp_ = nx.nan_to_num(transp_, nan=0, posinf=0, neginf=0)

# compute transported samples
transp_Xt = nx.dot(transp_, self.xs_)
@@ -719,9 +703,8 @@ class label
idx = nx.argmin(D0, axis=1)

# transport the target samples
transp_ = self.coupling_.T / nx.sum(
self.coupling_, 0)[:, None]
transp_[~ nx.isfinite(transp_)] = 0
transp_ = self.coupling_.T / nx.sum(self.coupling_, 0)[:, None]
transp_ = nx.nan_to_num(transp_, nan=0, posinf=0, neginf=0)
transp_Xt_ = nx.dot(transp_, self.xs_)

# define the transported points
@@ -750,23 +733,15 @@ def inverse_transform_labels(self, yt=None):

# check the necessary inputs parameters are here
if check_params(yt=yt):

ytTemp = label_normalization(nx.copy(yt))
classes = nx.unique(ytTemp)
n = len(classes)
D1 = nx.zeros((n, len(ytTemp)), type_as=self.coupling_)

# perform label propagation
transp = self.coupling_ / nx.sum(self.coupling_, 1)[:, None]

# set nans to 0
transp[~ nx.isfinite(transp)] = 0
transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)

for c in classes:
D1[int(c), ytTemp == c] = 1

# compute propagated samples
transp_ys = nx.dot(D1, transp.T)
# compute propagated labels
labels = label_normalization(yt)
masks = labels_to_masks(labels, nx=nx, type_as=transp)
transp_ys = nx.dot(masks.T, transp.T)

return transp_ys.T

@@ -2151,7 +2126,7 @@ def transform_labels(self, ys=None):
type_as=ys[0]
)
for i in range(len(ys)):
ysTemp = label_normalization(nx.copy(ys[i]))
ysTemp = label_normalization(ys[i])
classes = nx.unique(ysTemp)
n = len(classes)
ns = len(ysTemp)
@@ -2194,7 +2169,7 @@ def inverse_transform_labels(self, yt=None):
# check the necessary inputs parameters are here
if check_params(yt=yt):
transp_ys = []
ytTemp = label_normalization(nx.copy(yt))
ytTemp = label_normalization(yt)
classes = nx.unique(ytTemp)
n = len(classes)
D1 = nx.zeros((n, len(ytTemp)), type_as=self.coupling_[0])
45 changes: 35 additions & 10 deletions ot/utils.py
@@ -390,7 +390,7 @@ def is_all_finite(*args):
return all(not nx.any(~nx.isfinite(arg)) for arg in args)


def label_normalization(y, start=0):
def label_normalization(y, start=0, nx=None):
    r""" Transform labels to start at a given value

    Parameters
@@ -399,18 +399,45 @@ def label_normalization(y, start=0):
        The vector of labels to be normalized.
    start : int
        Desired value for the smallest label in :math:`\mathbf{y}` (default=0)
    nx : Backend, optional
        Backend to perform computations on. If omitted, the backend defaults to that of `y`.

    Returns
    -------
    y : array-like, shape (`n1`, )
        The input vector of labels normalized according to given start value.
    """
    nx = get_backend(y)

    diff = nx.min(nx.unique(y)) - start
    if diff != 0:
        y -= diff
    return y
    if nx is None:
        nx = get_backend(y)
    diff = nx.min(y) - start
    return y if diff == 0 else (y - diff)


def labels_to_masks(y, type_as=None, nx=None):
    r"""Transforms (n_samples,) vector of labels into a (n_samples, n_labels) matrix of masks.

    Parameters
    ----------
    y : array-like, shape (n_samples, )
        The vector of labels.
    type_as : array_like
        Array of the same type of the expected output.
    nx : Backend, optional
        Backend to perform computations on. If omitted, the backend defaults to that of `y`.

    Returns
    -------
    masks : array-like, shape (n_samples, n_labels)
        The (n_samples, n_labels) matrix of label masks.
    """
    if nx is None:
        nx = get_backend(y)
    if type_as is None:
        type_as = y
    labels_u, labels_idx = nx.unique(y, return_inverse=True)
    n_labels = labels_u.shape[0]
    masks = nx.eye(n_labels, type_as=type_as)[labels_idx]
    return masks
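A quick usage sketch of the new helper (NumPy shown; any registered backend behaves the same):

```python
import numpy as np

from ot.utils import labels_to_masks

y = np.array([0, 2, 1, 2])
masks = labels_to_masks(y)
# One row per sample, one column per distinct label:
# [[1, 0, 0],
#  [0, 0, 1],
#  [0, 1, 0],
#  [0, 0, 1]]
```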


def parmap(f, X, nprocs="default"):
@@ -755,10 +782,8 @@ def _get_backend(self, *arrays):
nx = get_backend(
*[input_ for input_ in arrays if input_ is not None]
)
if nx.__name__ in ("jax", "tf"):
raise TypeError(
"""JAX or TF arrays have been received but domain
adaptation does not support those backend.""")
if nx.__name__ in ("tf",):
raise TypeError("Domain adaptation does not support TF backend.")
self.nx = nx
return nx
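With the guard relaxed, JAX arrays now flow through `BaseEstimator._get_backend` and only TensorFlow is rejected. An illustrative check, assuming JAX is installed (`_get_backend` is the private hook exercised by `fit`):

```python
import jax.numpy as jnp

import ot

Xs, Xt = jnp.zeros((5, 2)), jnp.ones((6, 2))
otda = ot.da.SinkhornL1l2Transport()

# Previously raised TypeError for JAX inputs; now resolves the backend.
nx = otda._get_backend(Xs, Xt)
print(nx.__name__)  # "jax"
```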

7 changes: 7 additions & 0 deletions test/test_backend.py
@@ -264,6 +264,8 @@ def test_empty_backend():
nx.detach(M)
with pytest.raises(NotImplementedError):
nx.matmul(M, M.T)
with pytest.raises(NotImplementedError):
nx.nan_to_num(M)


def test_func_backends(nx):
@@ -667,6 +669,11 @@ def test_func_backends(nx):
lst_b.append(nx.to_numpy(A))
lst_name.append("matmul broadcast")

vec = nx.from_numpy(np.array([1, np.nan, -1]))
vec = nx.nan_to_num(vec, nan=0)
lst_b.append(nx.to_numpy(vec))
lst_name.append("nan_to_num")

assert not nx.array_equal(Mb, vb), "array_equal (shape)"
assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true"
assert not nx.array_equal(
19 changes: 8 additions & 11 deletions test/test_da.py
@@ -7,7 +7,6 @@
import numpy as np
from numpy.testing import assert_allclose, assert_equal
import pytest
import warnings

import ot
from ot.datasets import make_data_classif
@@ -158,7 +157,6 @@ def test_sinkhorn_lpl1_transport_class(nx):
assert mass_semi == 0, "semisupervised mode not working"


@pytest.skip_backend("jax")
@pytest.skip_backend("tf")
def test_sinkhorn_l1l2_transport_class(nx):
"""test_sinkhorn_transport
@@ -169,15 +167,16 @@ def test_sinkhorn_l1l2_transport_class(nx):

Xs, ys = make_data_classif('3gauss', ns, random_state=42)
Xt, yt = make_data_classif('3gauss2', nt, random_state=43)
# prepare semi-supervised labels
yt_semi = np.copy(yt)
yt_semi[np.arange(0, nt, 2)] = -1

Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)
Xs, ys, Xt, yt, yt_semi = nx.from_numpy(Xs, ys, Xt, yt, yt_semi)

otda = ot.da.SinkhornL1l2Transport(max_inner_iter=500)
otda.fit(Xs=Xs, ys=ys, Xt=Xt)

# test its computed
with warnings.catch_warnings():
warnings.simplefilter("error")
otda.fit(Xs=Xs, ys=ys, Xt=Xt)
assert hasattr(otda, "cost_")
assert hasattr(otda, "coupling_")
assert hasattr(otda, "log_")
@@ -234,7 +233,7 @@ def test_sinkhorn_l1l2_transport_class(nx):
n_unsup = nx.sum(otda_unsup.cost_)

otda_semi = ot.da.SinkhornL1l2Transport()
otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt_semi)
assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
n_semisup = nx.sum(otda_semi.cost_)

@@ -243,11 +242,9 @@

# check that the coupling forbids mass transport between labeled source
# and labeled target samples
mass_semi = nx.sum(
otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max])
mass_semi = nx.sum(otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max])
mass_semi = otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]
assert_allclose(nx.to_numpy(mass_semi), np.zeros(list(mass_semi.shape)),
rtol=1e-9, atol=1e-9)
assert_allclose(nx.to_numpy(mass_semi), np.zeros_like(mass_semi), rtol=1e-9, atol=1e-9)

# check everything runs well with log=True
otda = ot.da.SinkhornL1l2Transport(log=True)