Skip to content

Commit 24ad25c

Browse files
[WIP] add mass feature to nx.kl_div and harmonize kl computation in the toolbox (#654)
* add mass feature to nx.kl_div * test * test * fix tipo doc * fix jax
1 parent e530985 commit 24ad25c

File tree

6 files changed

+54
-25
lines changed

6 files changed

+54
-25
lines changed

RELEASES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
## 0.9.5dev
44

55
#### New features
6+
- Add feature `mass=True` for `nx.kl_div` (PR #654)
67

78
#### Closed issues
89

ot/backend.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -944,16 +944,17 @@ def eigh(self, a):
944944
"""
945945
raise NotImplementedError()
946946

947-
def kl_div(self, p, q, eps=1e-16):
947+
def kl_div(self, p, q, mass=False, eps=1e-16):
948948
r"""
949-
Computes the Kullback-Leibler divergence.
949+
Computes the (Generalized) Kullback-Leibler divergence.
950950
951951
This function follows the api from :any:`scipy.stats.entropy`.
952952
953953
Parameter eps is used to avoid numerical errors and is added in the log.
954954
955955
.. math::
956-
KL(p,q) = \sum_i p(i) \log (\frac{p(i)}{q(i)}+\epsilon)
956+
KL(p, q) = \langle \mathbf{p}, \log(\mathbf{p} / \mathbf{q} + \epsilon) \rangle
957+
+ \mathbb{1}_{mass=True} \langle \mathbf{q} - \mathbf{p}, \mathbf{1} \rangle
957958
958959
See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.entropy.html
959960
"""
@@ -1352,8 +1353,11 @@ def sqrtm(self, a):
13521353
def eigh(self, a):
13531354
return np.linalg.eigh(a)
13541355

1355-
def kl_div(self, p, q, eps=1e-16):
1356-
return np.sum(p * np.log(p / q + eps))
1356+
def kl_div(self, p, q, mass=False, eps=1e-16):
1357+
value = np.sum(p * np.log(p / q + eps))
1358+
if mass:
1359+
value = value + np.sum(q - p)
1360+
return value
13571361

13581362
def isfinite(self, a):
13591363
return np.isfinite(a)
@@ -1751,8 +1755,11 @@ def sqrtm(self, a):
17511755
def eigh(self, a):
17521756
return jnp.linalg.eigh(a)
17531757

1754-
def kl_div(self, p, q, eps=1e-16):
1755-
return jnp.sum(p * jnp.log(p / q + eps))
1758+
def kl_div(self, p, q, mass=False, eps=1e-16):
1759+
value = jnp.sum(p * jnp.log(p / q + eps))
1760+
if mass:
1761+
value = value + jnp.sum(q - p)
1762+
return value
17561763

17571764
def isfinite(self, a):
17581765
return jnp.isfinite(a)
@@ -2238,8 +2245,11 @@ def sqrtm(self, a):
22382245
def eigh(self, a):
22392246
return torch.linalg.eigh(a)
22402247

2241-
def kl_div(self, p, q, eps=1e-16):
2242-
return torch.sum(p * torch.log(p / q + eps))
2248+
def kl_div(self, p, q, mass=False, eps=1e-16):
2249+
value = torch.sum(p * torch.log(p / q + eps))
2250+
if mass:
2251+
value = value + torch.sum(q - p)
2252+
return value
22432253

22442254
def isfinite(self, a):
22452255
return torch.isfinite(a)
@@ -2639,8 +2649,11 @@ def sqrtm(self, a):
26392649
def eigh(self, a):
26402650
return cp.linalg.eigh(a)
26412651

2642-
def kl_div(self, p, q, eps=1e-16):
2643-
return cp.sum(p * cp.log(p / q + eps))
2652+
def kl_div(self, p, q, mass=False, eps=1e-16):
2653+
value = cp.sum(p * cp.log(p / q + eps))
2654+
if mass:
2655+
value = value + cp.sum(q - p)
2656+
return value
26442657

26452658
def isfinite(self, a):
26462659
return cp.isfinite(a)
@@ -3063,8 +3076,11 @@ def sqrtm(self, a):
30633076
def eigh(self, a):
30643077
return tf.linalg.eigh(a)
30653078

3066-
def kl_div(self, p, q, eps=1e-16):
3067-
return tnp.sum(p * tnp.log(p / q + eps))
3079+
def kl_div(self, p, q, mass=False, eps=1e-16):
3080+
value = tnp.sum(p * tnp.log(p / q + eps))
3081+
if mass:
3082+
value = value + tnp.sum(q - p)
3083+
return value
30683084

30693085
def isfinite(self, a):
30703086
return tnp.isfinite(a)

ot/bregman/_barycenter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000,
364364
log = {'err': []}
365365

366366
M = - M / reg
367-
logA = nx.log(A + 1e-15)
367+
logA = nx.log(A + 1e-16)
368368
log_KU, G = nx.zeros((2, *logA.shape), type_as=A)
369369
err = 1
370370
for ii in range(numItermax):
@@ -702,7 +702,7 @@ def _barycenter_debiased_log(A, M, reg, weights=None, numItermax=1000,
702702
log = {'err': []}
703703

704704
M = - M / reg
705-
logA = nx.log(A + 1e-15)
705+
logA = nx.log(A + 1e-16)
706706
log_KU, G = nx.zeros((2, *logA.shape), type_as=A)
707707
c = nx.zeros(dim, type_as=A)
708708
err = 1

ot/coot.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -139,10 +139,6 @@ def co_optimal_transport(X, Y, wx_samp=None, wx_feat=None, wy_samp=None, wy_feat
139139
Advances in Neural Information Processing ny_sampstems, 33 (2020).
140140
"""
141141

142-
def compute_kl(p, q):
143-
kl = nx.sum(p * nx.log(p + 1.0 * (p == 0))) - nx.sum(p * nx.log(q))
144-
return kl
145-
146142
# Main function
147143

148144
if method_sinkhorn not in ["sinkhorn", "sinkhorn_log"]:
@@ -245,9 +241,9 @@ def compute_kl(p, q):
245241
coot = coot + alpha_samp * nx.sum(M_samp * pi_samp)
246242
# Entropic part
247243
if eps_samp != 0:
248-
coot = coot + eps_samp * compute_kl(pi_samp, wxy_samp)
244+
coot = coot + eps_samp * nx.kl_div(pi_samp, wxy_samp)
249245
if eps_feat != 0:
250-
coot = coot + eps_feat * compute_kl(pi_feat, wxy_feat)
246+
coot = coot + eps_feat * nx.kl_div(pi_feat, wxy_feat)
251247
list_coot.append(coot)
252248

253249
if err < tol_bcd or abs(list_coot[-2] - list_coot[-1]) < early_stopping_tol:

ot/gromov/_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def h2(b):
109109
return 2 * b
110110
elif loss_fun == 'kl_loss':
111111
def f1(a):
112-
return a * nx.log(a + 1e-15) - a
112+
return a * nx.log(a + 1e-16) - a
113113

114114
def f2(b):
115115
return b
@@ -118,7 +118,7 @@ def h1(a):
118118
return a
119119

120120
def h2(b):
121-
return nx.log(b + 1e-15)
121+
return nx.log(b + 1e-16)
122122
else:
123123
raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")
124124

@@ -502,7 +502,7 @@ def h2(b):
502502
return 2 * b
503503
elif loss_fun == 'kl_loss':
504504
def f1(a):
505-
return a * nx.log(a + 1e-15) - a
505+
return a * nx.log(a + 1e-16) - a
506506

507507
def f2(b):
508508
return b
@@ -511,7 +511,7 @@ def h1(a):
511511
return a
512512

513513
def h2(b):
514-
return nx.log(b + 1e-15)
514+
return nx.log(b + 1e-16)
515515
else:
516516
raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")
517517

test/test_utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -620,3 +620,19 @@ def test_label_normalization(nx):
620620
# labels are shifted but the shift is expected
621621
y_normalized_start = ot.utils.label_normalization(y, start=1)
622622
np.testing.assert_array_equal(y, y_normalized_start)
623+
624+
625+
def test_kl_div(nx):
626+
n = 10
627+
rng = np.random.RandomState(0)
628+
# test on non-negative tensors
629+
x = rng.randn(n)
630+
x = x - x.min() + 1e-5
631+
y = rng.randn(n)
632+
y = y - y.min() + 1e-5
633+
xb = nx.from_numpy(x)
634+
yb = nx.from_numpy(y)
635+
kl = nx.kl_div(xb, yb)
636+
kl_mass = nx.kl_div(xb, yb, True)
637+
recovered_kl = kl_mass - nx.sum(yb - xb)
638+
np.testing.assert_allclose(kl, recovered_kl)

0 commit comments

Comments
 (0)