@@ -207,6 +207,9 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
207
207
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
208
208
weights (histograms, both sum to 1)
209
209
210
+ and returns :math:`\langle \gamma^*, \mathbf{M} \rangle_F` (without
211
+ the entropic contribution).
212
+
210
213
.. note:: This function is backend-compatible and will work on arrays
211
214
from all compatible backends.
212
215
@@ -320,15 +323,18 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
320
323
if len (b .shape ) < 2 :
321
324
if method .lower () == 'sinkhorn' :
322
325
res = sinkhorn_knopp (a , b , M , reg , numItermax = numItermax ,
323
- stopThr = stopThr , verbose = verbose , log = log ,
326
+ stopThr = stopThr , verbose = verbose ,
327
+ log = log , warn = warn ,
324
328
** kwargs )
325
329
elif method .lower () == 'sinkhorn_log' :
326
330
res = sinkhorn_log (a , b , M , reg , numItermax = numItermax ,
327
- stopThr = stopThr , verbose = verbose , log = log ,
331
+ stopThr = stopThr , verbose = verbose ,
332
+ log = log , warn = warn ,
328
333
** kwargs )
329
334
elif method .lower () == 'sinkhorn_stabilized' :
330
335
res = sinkhorn_stabilized (a , b , M , reg , numItermax = numItermax ,
331
- stopThr = stopThr , verbose = verbose , log = log ,
336
+ stopThr = stopThr , verbose = verbose ,
337
+ log = log , warn = warn ,
332
338
** kwargs )
333
339
else :
334
340
raise ValueError ("Unknown method '%s'." % method )
@@ -341,15 +347,18 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
341
347
342
348
if method .lower () == 'sinkhorn' :
343
349
return sinkhorn_knopp (a , b , M , reg , numItermax = numItermax ,
344
- stopThr = stopThr , verbose = verbose , log = log ,
350
+ stopThr = stopThr , verbose = verbose ,
351
+ log = log , warn = warn ,
345
352
** kwargs )
346
353
elif method .lower () == 'sinkhorn_log' :
347
354
return sinkhorn_log (a , b , M , reg , numItermax = numItermax ,
348
- stopThr = stopThr , verbose = verbose , log = log ,
355
+ stopThr = stopThr , verbose = verbose ,
356
+ log = log , warn = warn ,
349
357
** kwargs )
350
358
elif method .lower () == 'sinkhorn_stabilized' :
351
359
return sinkhorn_stabilized (a , b , M , reg , numItermax = numItermax ,
352
- stopThr = stopThr , verbose = verbose , log = log ,
360
+ stopThr = stopThr , verbose = verbose ,
361
+ log = log , warn = warn ,
353
362
** kwargs )
354
363
else :
355
364
raise ValueError ("Unknown method '%s'." % method )
@@ -1278,7 +1287,7 @@ def get_reg(n): # exponential decreasing
1278
1287
regi = get_reg (ii )
1279
1288
1280
1289
G , logi = sinkhorn_stabilized (a , b , M , regi ,
1281
- numItermax = numInnerItermax , stopThr = 1e-9 ,
1290
+ numItermax = numInnerItermax , stopThr = stopThr ,
1282
1291
warmstart = (alpha , beta ), verbose = False ,
1283
1292
print_period = 20 , tau = tau , log = True )
1284
1293
@@ -3059,6 +3068,9 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
3059
3068
:math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
3060
3069
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
3061
3070
3071
+ and returns :math:`\langle \gamma^*, \mathbf{M} \rangle_F` (without
3072
+ the entropic contribution).
3073
+
3062
3074
3063
3075
Parameters
3064
3076
----------
@@ -3237,6 +3249,13 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
3237
3249
:math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
3238
3250
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
3239
3251
3252
+ and returns :math:`\langle \gamma^*, \mathbf{M} \rangle_F -(\langle \gamma^*_a, \mathbf{M_a} \rangle_F + \langle
3253
+ \gamma^*_b , \mathbf{M_b} \rangle_F)/2`.
3254
+
3255
+ .. note: The current implementation does not account for the entropic contributions and thus differs from the
3256
+ Sinkhorn divergence as introduced in the literature. The possibility to account for the entropic contributions
3257
+ will be provided in a future release.
3258
+
3240
3259
3241
3260
Parameters
3242
3261
----------
@@ -3293,17 +3312,17 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
3293
3312
if log :
3294
3313
sinkhorn_loss_ab , log_ab = empirical_sinkhorn2 (X_s , X_t , reg , a , b , metric = metric ,
3295
3314
numIterMax = numIterMax ,
3296
- stopThr = 1e-9 , verbose = verbose ,
3315
+ stopThr = stopThr , verbose = verbose ,
3297
3316
log = log , warn = warn , ** kwargs )
3298
3317
3299
3318
sinkhorn_loss_a , log_a = empirical_sinkhorn2 (X_s , X_s , reg , a , a , metric = metric ,
3300
3319
numIterMax = numIterMax ,
3301
- stopThr = 1e-9 , verbose = verbose ,
3320
+ stopThr = stopThr , verbose = verbose ,
3302
3321
log = log , warn = warn , ** kwargs )
3303
3322
3304
3323
sinkhorn_loss_b , log_b = empirical_sinkhorn2 (X_t , X_t , reg , b , b , metric = metric ,
3305
3324
numIterMax = numIterMax ,
3306
- stopThr = 1e-9 , verbose = verbose ,
3325
+ stopThr = stopThr , verbose = verbose ,
3307
3326
log = log , warn = warn , ** kwargs )
3308
3327
3309
3328
sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b )
@@ -3320,17 +3339,17 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
3320
3339
3321
3340
else :
3322
3341
sinkhorn_loss_ab = empirical_sinkhorn2 (X_s , X_t , reg , a , b , metric = metric ,
3323
- numIterMax = numIterMax , stopThr = 1e-9 ,
3342
+ numIterMax = numIterMax , stopThr = stopThr ,
3324
3343
verbose = verbose , log = log ,
3325
3344
warn = warn , ** kwargs )
3326
3345
3327
3346
sinkhorn_loss_a = empirical_sinkhorn2 (X_s , X_s , reg , a , a , metric = metric ,
3328
- numIterMax = numIterMax , stopThr = 1e-9 ,
3347
+ numIterMax = numIterMax , stopThr = stopThr ,
3329
3348
verbose = verbose , log = log ,
3330
3349
warn = warn , ** kwargs )
3331
3350
3332
3351
sinkhorn_loss_b = empirical_sinkhorn2 (X_t , X_t , reg , b , b , metric = metric ,
3333
- numIterMax = numIterMax , stopThr = 1e-9 ,
3352
+ numIterMax = numIterMax , stopThr = stopThr ,
3334
3353
verbose = verbose , log = log ,
3335
3354
warn = warn , ** kwargs )
3336
3355
0 commit comments