 
 import warnings
 
-from ..utils import dist, list_to_array, unif
+from ..utils import dist, list_to_array, unif, LazyTensor
 from ..backend import get_backend
 
 from ._sinkhorn import sinkhorn, sinkhorn2
 
 
+def get_sinkhorn_lazytensor(X_a, X_b, f, g, metric='sqeuclidean', reg=1e-1, nx=None):
+    r""" Get a LazyTensor of Sinkhorn solution from the dual potentials
+
+    The returned LazyTensor is
+    :math:`\mathbf{T} = \exp(\mathbf{f} \mathbf{1}_b^\top + \mathbf{1}_a \mathbf{g}^\top - \mathbf{C} / reg)`, where :math:`\mathbf{C}` is the pairwise metric matrix between samples :math:`\mathbf{X}_a` and :math:`\mathbf{X}_b`.
+
+    Parameters
+    ----------
+    X_a : array-like, shape (n_samples_a, dim)
+        samples in the source domain
+    X_b : array-like, shape (n_samples_b, dim)
+        samples in the target domain
+    f : array-like, shape (n_samples_a,)
+        First dual potentials (log space)
+    g : array-like, shape (n_samples_b,)
+        Second dual potentials (log space)
+    metric : str, default='sqeuclidean'
+        Metric used for the cost matrix computation
+    reg : float, default=1e-1
+        Regularization term > 0
+    nx : Backend(), default=None
+        Numerical backend used
+
+
+    Returns
+    -------
+    T : LazyTensor
+        Sinkhorn solution tensor
+    """
+
+    if nx is None:
+        nx = get_backend(X_a, X_b, f, g)
+
+    shape = (X_a.shape[0], X_b.shape[0])
+
+    def func(i, j, X_a, X_b, f, g, metric, reg):
+        C = dist(X_a[i], X_b[j], metric=metric)
+        return nx.exp(f[i, None] + g[None, j] - C / reg)
+
+    T = LazyTensor(shape, func, X_a=X_a, X_b=X_b, f=f, g=g, metric=metric, reg=reg)
+
+    return T
+
+
 def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
                        numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False,
                        log=False, warn=True, warmstart=None, **kwargs):
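For illustration, a minimal sketch of how the helper added above might be called. The random samples and the zero-valued potentials are placeholders rather than the output of a real Sinkhorn run, and the public import path `ot.bregman.get_sinkhorn_lazytensor` is assumed:

    import numpy as np
    from ot.bregman import get_sinkhorn_lazytensor  # assumed public import path

    X_a = np.random.randn(50, 2)   # illustrative source samples
    X_b = np.random.randn(40, 2)   # illustrative target samples
    f = np.zeros(50)               # placeholder first dual potential (log space)
    g = np.zeros(40)               # placeholder second dual potential (log space)

    T = get_sinkhorn_lazytensor(X_a, X_b, f, g, metric='sqeuclidean', reg=1e-1)

    print(T.shape)      # (50, 40), yet no 50 x 40 matrix has been materialized
    block = T[:10, :5]  # computes only this 10 x 5 block of exp(f_i + g_j - C_ij / reg)

Indexing the returned tensor calls the stored closure on the requested rows and columns only, which is what keeps the full transport plan out of memory.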
@@ -198,6 +242,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
         if log:
             dict_log["u"] = f
             dict_log["v"] = g
+            dict_log["niter"] = i_ot
+            dict_log["lazy_plan"] = get_sinkhorn_lazytensor(X_s, X_t, f, g, metric, reg)
             return (f, g, dict_log)
         else:
             return (f, g)
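In the same spirit, a hedged sketch of how the new `niter` and `lazy_plan` log entries could be read back from `empirical_sinkhorn` in lazy mode; the data and keyword values below are illustrative only:

    import numpy as np
    import ot

    X_s = np.random.randn(200, 2)
    X_t = np.random.randn(150, 2)

    # With isLazy=True and log=True, the dual potentials and the log dict are returned.
    f, g, log = ot.bregman.empirical_sinkhorn(X_s, X_t, reg=1e-1,
                                              isLazy=True, log=True, batchSize=64)

    print(log["niter"])           # number of Sinkhorn iterations performed
    lazy_plan = log["lazy_plan"]  # LazyTensor built by get_sinkhorn_lazytensor
    rows = lazy_plan[:10]         # materializes only the first 10 rows of the plan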