@@ -944,16 +944,17 @@ def eigh(self, a):
944
944
"""
945
945
raise NotImplementedError ()
946
946
947
- def kl_div (self , p , q , eps = 1e-16 ):
947
+ def kl_div (self , p , q , mass = False , eps = 1e-16 ):
948
948
r"""
949
- Computes the Kullback-Leibler divergence.
949
+ Computes the (Generalized) Kullback-Leibler divergence.
950
950
951
951
This function follows the api from :any:`scipy.stats.entropy`.
952
952
953
953
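Parameter eps is used to avoid numerical errors and is added inside the logarithm. When mass is True, the term sum(q - p) is added, which gives the generalized Kullback-Leibler divergence.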
954
954
955
955
.. math::
956
- KL(p,q) = \sum_i p(i) \log (\frac{p(i)}{q(i)}+\epsilon)
956
+ KL(p,q) = \langle \mathbf{p}, \log(\mathbf{p} / \mathbf{q} + \epsilon) \rangle
957
+ + \mathbb{1}_{mass=True} \langle \mathbf{q} - \mathbf{p}, \mathbf{1} \rangle
957
958
958
959
See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.entropy.html
959
960
"""
@@ -1352,8 +1353,11 @@ def sqrtm(self, a):
1352
1353
def eigh (self , a ):
1353
1354
return np .linalg .eigh (a )
1354
1355
1355
- def kl_div (self , p , q , eps = 1e-16 ):
1356
- return np .sum (p * np .log (p / q + eps ))
1356
+ def kl_div (self , p , q , mass = False , eps = 1e-16 ):
1357
+ value = np .sum (p * np .log (p / q + eps ))
1358
+ if mass :
1359
+ value = value + np .sum (q - p )
1360
+ return value
1357
1361
1358
1362
def isfinite (self , a ):
1359
1363
return np .isfinite (a )
@@ -1751,8 +1755,11 @@ def sqrtm(self, a):
1751
1755
def eigh (self , a ):
1752
1756
return jnp .linalg .eigh (a )
1753
1757
1754
- def kl_div (self , p , q , eps = 1e-16 ):
1755
- return jnp .sum (p * jnp .log (p / q + eps ))
1758
+ def kl_div (self , p , q , mass = False , eps = 1e-16 ):
1759
+ value = jnp .sum (p * jnp .log (p / q + eps ))
1760
+ if mass :
1761
+ value = value + jnp .sum (q - p )
1762
+ return value
1756
1763
1757
1764
def isfinite (self , a ):
1758
1765
return jnp .isfinite (a )
@@ -2238,8 +2245,11 @@ def sqrtm(self, a):
2238
2245
def eigh (self , a ):
2239
2246
return torch .linalg .eigh (a )
2240
2247
2241
- def kl_div (self , p , q , eps = 1e-16 ):
2242
- return torch .sum (p * torch .log (p / q + eps ))
2248
+ def kl_div (self , p , q , mass = False , eps = 1e-16 ):
2249
+ value = torch .sum (p * torch .log (p / q + eps ))
2250
+ if mass :
2251
+ value = value + torch .sum (q - p )
2252
+ return value
2243
2253
2244
2254
def isfinite (self , a ):
2245
2255
return torch .isfinite (a )
@@ -2639,8 +2649,11 @@ def sqrtm(self, a):
2639
2649
def eigh (self , a ):
2640
2650
return cp .linalg .eigh (a )
2641
2651
2642
- def kl_div (self , p , q , eps = 1e-16 ):
2643
- return cp .sum (p * cp .log (p / q + eps ))
2652
+ def kl_div (self , p , q , mass = False , eps = 1e-16 ):
2653
+ value = cp .sum (p * cp .log (p / q + eps ))
2654
+ if mass :
2655
+ value = value + cp .sum (q - p )
2656
+ return value
2644
2657
2645
2658
def isfinite (self , a ):
2646
2659
return cp .isfinite (a )
@@ -3063,8 +3076,11 @@ def sqrtm(self, a):
3063
3076
def eigh (self , a ):
3064
3077
return tf .linalg .eigh (a )
3065
3078
3066
- def kl_div (self , p , q , eps = 1e-16 ):
3067
- return tnp .sum (p * tnp .log (p / q + eps ))
3079
+ def kl_div (self , p , q , mass = False , eps = 1e-16 ):
3080
+ value = tnp .sum (p * tnp .log (p / q + eps ))
3081
+ if mass :
3082
+ value = value + tnp .sum (q - p )
3083
+ return value
3068
3084
3069
3085
def isfinite (self , a ):
3070
3086
return tnp .isfinite (a )
0 commit comments