Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
+ Add KL loss to all semi-relaxed (Fused) Gromov-Wasserstein solvers (PR #559)
+ Further upgraded unbalanced OT solvers for more flexibility and future use (PR #551)
+ New API function `ot.solve_sample` for solving OT problems from empirical samples (PR #563)
+ Add `stop_criterion` feature to (un)regularized (f)gw barycenter solvers (PR #578)
+ Add `fixed_structure` and `fixed_features` to entropic fgw barycenter solver (PR #578)

#### Closed issues
- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)
Expand Down
245 changes: 178 additions & 67 deletions ot/gromov/_bregman.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import warnings

from ..bregman import sinkhorn
from ..utils import dist, list_to_array, check_random_state, unif
from ..utils import dist, UndefinedParameter, list_to_array, check_random_state, unif
from ..backend import get_backend

from ._utils import init_matrix, gwloss, gwggrad
Expand Down Expand Up @@ -345,8 +345,9 @@ def entropic_gromov_wasserstein2(

def entropic_gromov_barycenters(
N, Cs, ps=None, p=None, lambdas=None, loss_fun='square_loss',
epsilon=0.1, symmetric=True, max_iter=1000, tol=1e-9, warmstartT=False,
verbose=False, log=False, init_C=None, random_state=None, **kwargs):
epsilon=0.1, symmetric=True, max_iter=1000, tol=1e-9,
stop_criterion='barycenter', warmstartT=False, verbose=False,
log=False, init_C=None, random_state=None, **kwargs):
r"""
Returns the Gromov-Wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}`
estimated using Gromov-Wasserstein transports from Sinkhorn projections.
Expand Down Expand Up @@ -388,6 +389,10 @@ def entropic_gromov_barycenters(
Max number of iterations
tol : float, optional
Stop threshold on error (>0)
stop_criterion : str, optional. Default is 'barycenter'.
Convergence criterion taking values in ['barycenter', 'loss']. If set to 'barycenter',
the norm of the difference between two successive barycenter estimates is used.
If set to 'loss', the relative variation of the loss between two successive iterations is used.
warmstartT: bool, optional
Whether to warm-start the transport plans across the successive
Gromov-Wasserstein transport problems.
Expand All @@ -407,7 +412,11 @@ def entropic_gromov_barycenters(
C : array-like, shape (`N`, `N`)
Similarity matrix in the barycenter space (permuted arbitrarily)
log : dict
Log dictionary of error during iterations. Return only if `log=True` in parameters.
Only returned when log=True. It contains the keys:

- :math:`\mathbf{T}`: list of (`N`, `ns`) transport matrices
- :math:`\mathbf{p}`: (`N`,) barycenter weights
- values used in convergence evaluation.

References
----------
Expand All @@ -418,6 +427,9 @@ def entropic_gromov_barycenters(
if loss_fun not in ('square_loss', 'kl_loss'):
raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")

if stop_criterion not in ['barycenter', 'loss']:
raise ValueError(f"Unknown `stop_criterion='{stop_criterion}'`. Use one of: {'barycenter', 'loss'}.")

Cs = list_to_array(*Cs)
arr = [*Cs]
if ps is not None:
Expand Down Expand Up @@ -446,45 +458,75 @@ def entropic_gromov_barycenters(
C = init_C

cpt = 0
err = 1

error = []
err = 1e15 # either the error on 'barycenter' or 'loss'

if warmstartT:
T = [None] * S

if stop_criterion == 'barycenter':
inner_log = False
else:
inner_log = True
curr_loss = 1e15

if log:
log_ = {}
log_['err'] = []
if stop_criterion == 'loss':
log_['loss'] = []

while (err > tol) and (cpt < max_iter):
Cprev = C
if stop_criterion == 'barycenter':
Cprev = C
else:
prev_loss = curr_loss

# get transport plans
if warmstartT:
T = [entropic_gromov_wasserstein(
res = [entropic_gromov_wasserstein(
C, Cs[s], p, ps[s], loss_fun, epsilon, symmetric, T[s],
max_iter, 1e-4, verbose=verbose, log=False, **kwargs) for s in range(S)]
max_iter, 1e-4, verbose=verbose, log=inner_log, **kwargs) for s in range(S)]
else:
T = [entropic_gromov_wasserstein(
res = [entropic_gromov_wasserstein(
C, Cs[s], p, ps[s], loss_fun, epsilon, symmetric, None,
max_iter, 1e-4, verbose=verbose, log=False, **kwargs) for s in range(S)]
max_iter, 1e-4, verbose=verbose, log=inner_log, **kwargs) for s in range(S)]
if stop_criterion == 'barycenter':
T = res
else:
T = [output[0] for output in res]
curr_loss = np.sum([output[1]['gw_dist'] for output in res])

# update barycenters
if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs, nx)
elif loss_fun == 'kl_loss':
C = update_kl_loss(p, lambdas, T, Cs, nx)

if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
# update convergence criterion
if stop_criterion == 'barycenter':
err = nx.norm(C - Cprev)
error.append(err)
if log:
log_['err'].append(err)

if verbose:
if cpt % 200 == 0:
print('{:5s}|{:12s}'.format(
'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(cpt, err))
else:
err = abs(curr_loss - prev_loss) / prev_loss if prev_loss != 0. else np.nan
if log:
log_['loss'].append(curr_loss)
log_['err'].append(err)

if verbose:
if cpt % 200 == 0:
print('{:5s}|{:12s}'.format(
'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(cpt, err))

cpt += 1

if log:
return C, {"err": error}
log_['T'] = T
log_['p'] = p

return C, log_
else:
return C

Expand Down Expand Up @@ -838,8 +880,9 @@ def entropic_fused_gromov_wasserstein2(
def entropic_fused_gromov_barycenters(
N, Ys, Cs, ps=None, p=None, lambdas=None, loss_fun='square_loss',
epsilon=0.1, symmetric=True, alpha=0.5, max_iter=1000, tol=1e-9,
warmstartT=False, verbose=False, log=False, init_C=None, init_Y=None,
random_state=None, **kwargs):
stop_criterion='barycenter', warmstartT=False, verbose=False,
log=False, init_C=None, init_Y=None, fixed_structure=False,
fixed_features=False, random_state=None, **kwargs):
r"""
Returns the Fused Gromov-Wasserstein barycenters of `S` measurable networks with node features :math:`(\mathbf{C}_s, \mathbf{Y}_s, \mathbf{p}_s)_{1 \leq s \leq S}`
estimated using Fused Gromov-Wasserstein transports from Sinkhorn projections.
Expand Down Expand Up @@ -886,6 +929,10 @@ def entropic_fused_gromov_barycenters(
Max number of iterations
tol : float, optional
Stop threshold on error (>0)
stop_criterion : str, optional. Default is 'barycenter'.
Stop criterion taking values in ['barycenter', 'loss']. If set to 'barycenter',
the norms of the differences between two successive estimates of the barycenter
features and structure are used. If set to 'loss', the relative variation of the
loss between two successive iterations is used.
warmstartT: bool, optional
Whether to warm-start the transport plans across the successive
Fused Gromov-Wasserstein transport problems.
Expand All @@ -898,6 +945,10 @@ def entropic_fused_gromov_barycenters(
init_Y : array-like, shape (N,d), optional
Initialization for the barycenters' features. If not set a
random init is used.
fixed_structure : bool, optional
Whether to fix the structure of the barycenter during the updates.
If True, `init_C` must be provided.
fixed_features : bool, optional
Whether to fix the features of the barycenter during the updates.
If True, `init_Y` must be provided.
random_state : int or RandomState instance, optional
Fix the seed for reproducibility
**kwargs: dict
Expand All @@ -910,7 +961,12 @@ def entropic_fused_gromov_barycenters(
C : array-like, shape (`N`, `N`)
Similarity matrix in the barycenter space (permuted as Y's rows)
log : dict
Log dictionary of error during iterations. Return only if `log=True` in parameters.
Only returned when log=True. It contains the keys:

- :math:`\mathbf{T}`: list of (`N`, `ns`) transport matrices
- :math:`\mathbf{p}`: (`N`,) barycenter weights
- :math:`(\mathbf{M}_s)_s`: all distance matrices between the features of the barycenter and the other features :math:`(dist(\mathbf{X}, \mathbf{Y}_s))_s`, each of shape (`N`, `ns`)
- values used in convergence evaluation.

References
----------
Expand All @@ -926,6 +982,9 @@ def entropic_fused_gromov_barycenters(
if loss_fun not in ('square_loss', 'kl_loss'):
raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")

if stop_criterion not in ['barycenter', 'loss']:
raise ValueError(f"Unknown `stop_criterion='{stop_criterion}'`. Use one of: {'barycenter', 'loss'}.")

Cs = list_to_array(*Cs)
Ys = list_to_array(*Ys)
arr = [*Cs, *Ys]
Expand All @@ -945,67 +1004,108 @@ def entropic_fused_gromov_barycenters(

d = Ys[0].shape[1] # dimension on the node features

# Initialization of C : random SPD matrix (if not provided by user)
if init_C is None:
generator = check_random_state(random_state)
xalea = generator.randn(N, 2)
C = dist(xalea, xalea)
C /= C.max()
C = nx.from_numpy(C, type_as=p)
# Initialization of C : random euclidean distance matrix (if not provided by user)
if fixed_structure:
if init_C is None:
raise UndefinedParameter('If C is fixed it must be initialized')
else:
C = init_C
else:
C = init_C
if init_C is None:
generator = check_random_state(random_state)
xalea = generator.randn(N, 2)
C = dist(xalea, xalea)
C = nx.from_numpy(C, type_as=ps[0])
else:
C = init_C

# Initialization of Y
if init_Y is None:
Y = nx.zeros((N, d), type_as=ps[0])
if fixed_features:
if init_Y is None:
raise UndefinedParameter('If Y is fixed it must be initialized')
else:
Y = init_Y
else:
Y = init_Y
if init_Y is None:
Y = nx.zeros((N, d), type_as=ps[0])

if warmstartT:
T = [None] * S
else:
Y = init_Y

Ms = [dist(Y, Ys[s]) for s in range(len(Ys))]

if warmstartT:
T = [None] * S

cpt = 0
err = 1

err_feature = 1
err_structure = 1
if stop_criterion == 'barycenter':
inner_log = False
err_feature = 1e15
err_structure = 1e15
err_rel_loss = 0.

else:
inner_log = True
err_feature = 0.
err_structure = 0.
curr_loss = 1e15
err_rel_loss = 1e15

if log:
log_ = {}
log_['err_feature'] = []
log_['err_structure'] = []
log_['Ts_iter'] = []
if stop_criterion == 'barycenter':
log_['err_feature'] = []
log_['err_structure'] = []
log_['Ts_iter'] = []
Copy link
Collaborator Author

@cedricvincentcuaz cedricvincentcuaz Nov 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am really not convinced that Ts_iter is useful to anyone, knowing that it uses a lot of memory relatively ^^'

else:
log_['loss'] = []
log_['err_rel_loss'] = []

while (err > tol) and (cpt < max_iter):
Cprev = C
Yprev = Y
while ((err_feature > tol or err_structure > tol or err_rel_loss > tol) and cpt < max_iter):
if stop_criterion == 'barycenter':
Cprev = C
Yprev = Y
else:
prev_loss = curr_loss

# get transport plans
if warmstartT:
T = [entropic_fused_gromov_wasserstein(
res = [entropic_fused_gromov_wasserstein(
Ms[s], C, Cs[s], p, ps[s], loss_fun, epsilon, symmetric, alpha,
T[s], max_iter, 1e-4, verbose=verbose, log=False, **kwargs) for s in range(S)]
T[s], max_iter, 1e-4, verbose=verbose, log=inner_log, **kwargs) for s in range(S)]

else:
T = [entropic_fused_gromov_wasserstein(
res = [entropic_fused_gromov_wasserstein(
Ms[s], C, Cs[s], p, ps[s], loss_fun, epsilon, symmetric, alpha,
None, max_iter, 1e-4, verbose=verbose, log=False, **kwargs) for s in range(S)]
None, max_iter, 1e-4, verbose=verbose, log=inner_log, **kwargs) for s in range(S)]

if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs, nx)
elif loss_fun == 'kl_loss':
C = update_kl_loss(p, lambdas, T, Cs, nx)

Ys_temp = [y.T for y in Ys]
Y = update_feature_matrix(lambdas, Ys_temp, T, p, nx).T
Ms = [dist(Y, Ys[s]) for s in range(len(Ys))]

if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
err_feature = nx.norm(Y - nx.reshape(Yprev, (N, d)))
err_structure = nx.norm(C - Cprev)
if stop_criterion == 'barycenter':
T = res
else:
T = [output[0] for output in res]
curr_loss = np.sum([output[1]['fgw_dist'] for output in res])

# update barycenters
if not fixed_features:
Ys_temp = [y.T for y in Ys]
X = update_feature_matrix(lambdas, Ys_temp, T, p, nx).T
Ms = [dist(X, Ys[s]) for s in range(len(Ys))]

if not fixed_structure:
if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs, nx)

elif loss_fun == 'kl_loss':
C = update_kl_loss(p, lambdas, T, Cs, nx)

# update convergence criterion
if stop_criterion == 'barycenter':
err_feature, err_structure = 0., 0.
if not fixed_features:
err_feature = nx.norm(Y - Yprev)
if not fixed_structure:
err_structure = nx.norm(C - Cprev)
if log:
log_['err_feature'].append(err_feature)
log_['err_structure'].append(err_structure)
Expand All @@ -1017,14 +1117,25 @@ def entropic_fused_gromov_barycenters(
'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(cpt, err_structure))
print('{:5d}|{:8e}|'.format(cpt, err_feature))
else:
err_rel_loss = abs(curr_loss - prev_loss) / prev_loss if prev_loss != 0. else np.nan
if log:
log_['loss'].append(curr_loss)
log_['err_rel_loss'].append(err_rel_loss)

if verbose:
if cpt % 200 == 0:
print('{:5s}|{:12s}'.format(
'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(cpt, err_rel_loss))

cpt += 1

if log:
log_['T'] = T # from target to Ys
log_['T'] = T
log_['p'] = p
log_['Ms'] = Ms

if log:
return Y, C, log_
else:
return Y, C
Loading