-
Notifications
You must be signed in to change notification settings - Fork 529
[MRG] add the sparsity-constrained optimal transport funtionality and example #459
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
rflamary
merged 14 commits into
PythonOT:master
from
liutianlin0121:sparsity_constrained
Apr 25, 2023
Merged
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
eeaca57
add sparsity-constrained ot funtionality and example
liutianlin0121 e8bb4e0
correct typos; add projection_sparse_simplex
liutianlin0121 cbebf62
Merge branch 'master' into sparsity_constrained
rflamary 99daa66
add gradcheck; merge ot.sparse into ot.smooth.
liutianlin0121 1afb005
Merge branch 'sparsity_constrained' of https://github.com/liutianlin0…
liutianlin0121 ba16167
Merge branch 'master' into sparsity_constrained
rflamary bd34757
reuse existing ot.smooth functions with a new 'sparsity_constrained' …
liutianlin0121 d734fbb
address pep8 error
liutianlin0121 a96ff77
Merge branch 'master' into sparsity_constrained
rflamary 5484ed0
add backends for
liutianlin0121 3fb22c2
Merge branch 'sparsity_constrained' of https://github.com/liutianlin0…
liutianlin0121 6107526
Merge branch 'master' into sparsity_constrained
rflamary 9206c00
update releases
liutianlin0121 11e07aa
Merge branch 'sparsity_constrained' of https://github.com/liutianlin0…
liutianlin0121 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,27 +24,42 @@ | |
|
||
# Author: Mathieu Blondel | ||
# Remi Flamary <[email protected]> | ||
liutianlin0121 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# Tianlin Liu <[email protected]> | ||
|
||
""" | ||
Smooth and Sparse Optimal Transport solvers (KL an L2 reg.) | ||
Smooth and Sparse (KL an L2 reg.) and sparsity-constrained OT solvers. | ||
|
||
Implementation of : | ||
Smooth and Sparse Optimal Transport. | ||
Mathieu Blondel, Vivien Seguy, Antoine Rolet. | ||
In Proc. of AISTATS 2018. | ||
https://arxiv.org/abs/1710.06276 | ||
|
||
(Original code from https://github.com/mblondel/smooth-ot/) | ||
|
||
Sparsity-Constrained Optimal Transport. | ||
Liu, T., Puigcerver, J., & Blondel, M. (2023). | ||
Sparsity-constrained optimal transport. | ||
Proceedings of the Eleventh International Conference on | ||
Learning Representations (ICLR). | ||
https://arxiv.org/abs/2209.15466 | ||
|
||
|
||
[17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal | ||
Transport. Proceedings of the Twenty-First International Conference on | ||
Artificial Intelligence and Statistics (AISTATS). | ||
|
||
Original code from https://github.com/mblondel/smooth-ot/ | ||
[50] Liu, T., Puigcerver, J., & Blondel, M. (2023). | ||
Sparsity-constrained optimal transport. | ||
Proceedings of the Eleventh International Conference on | ||
Learning Representations (ICLR). | ||
|
||
""" | ||
|
||
import numpy as np | ||
from scipy.optimize import minimize | ||
from .backend import get_backend | ||
import ot | ||
|
||
|
||
def projection_simplex(V, z=1, axis=None): | ||
|
@@ -209,6 +224,39 @@ def Omega(self, T): | |
return 0.5 * self.gamma * np.sum(T ** 2) | ||
|
||
|
||
class SparsityConstrained(Regularization): | ||
""" Squared L2 regularization with sparsity constraints """ | ||
|
||
def __init__(self, max_nz, gamma=1.0): | ||
self.max_nz = max_nz | ||
self.gamma = gamma | ||
|
||
def delta_Omega(self, X): | ||
# For each column of X, find entries that are not among the top max_nz. | ||
non_top_indices = np.argpartition( | ||
-X, self.max_nz, axis=0)[self.max_nz:] | ||
# Set these entries to -inf. | ||
if X.ndim == 1: | ||
X[non_top_indices] = 0.0 | ||
else: | ||
X[non_top_indices, np.arange(X.shape[1])] = 0.0 | ||
max_X = np.maximum(X, 0) | ||
val = np.sum(max_X ** 2, axis=0) / (2 * self.gamma) | ||
G = max_X / self.gamma | ||
return val, G | ||
|
||
def max_Omega(self, X, b): | ||
# Project the scaled X onto the simplex with sparsity constraint. | ||
G = ot.utils.projection_sparse_simplex( | ||
X / (b * self.gamma), self.max_nz, axis=0) | ||
val = np.sum(X * G, axis=0) | ||
val -= 0.5 * self.gamma * b * np.sum(G * G, axis=0) | ||
return val, G | ||
|
||
def Omega(self, T): | ||
return 0.5 * self.gamma * np.sum(T ** 2) | ||
|
||
|
||
def dual_obj_grad(alpha, beta, a, b, C, regul): | ||
r""" | ||
Compute objective value and gradients of dual objective. | ||
|
@@ -435,8 +483,9 @@ def get_plan_from_semi_dual(alpha, b, C, regul): | |
return regul.max_Omega(X, b)[1] * b | ||
|
||
|
||
def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, | ||
numItermax=500, verbose=False, log=False): | ||
def smooth_ot_dual(a, b, M, reg, reg_type='l2', | ||
method="L-BFGS-B", stopThr=1e-9, | ||
numItermax=500, verbose=False, log=False, max_nz=None): | ||
r""" | ||
Solve the regularized OT problem in the dual and return the OT matrix | ||
|
||
|
@@ -477,6 +526,9 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, | |
:ref:`[2] <references-smooth-ot-dual>`) | ||
|
||
- 'l2' : Squared Euclidean regularization | ||
- 'sparsity_constrained' : Sparsity-constrained regularization [50] | ||
max_nz : int or None, optional. Used only in the case of reg_type = 'sparsity_constrained' to specify the maximum number of nonzeros per column of the optimal plan; | ||
not used for other regularization types. | ||
method : str | ||
Solver to use for scipy.optimize.minimize | ||
numItermax : int, optional | ||
|
@@ -504,6 +556,8 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, | |
|
||
.. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS). | ||
|
||
.. [50] Liu, T., Puigcerver, J., & Blondel, M. (2023). Sparsity-constrained optimal transport. Proceedings of the Eleventh International Conference on Learning Representations (ICLR). | ||
|
||
See Also | ||
-------- | ||
ot.lp.emd : Unregularized OT | ||
|
@@ -518,6 +572,11 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, | |
regul = SquaredL2(gamma=reg) | ||
elif reg_type.lower() in ['entropic', 'negentropy', 'kl']: | ||
regul = NegEntropy(gamma=reg) | ||
elif reg_type.lower() in ['sparsity_constrained', 'sparsity-constrained']: | ||
if not isinstance(max_nz, int): | ||
raise ValueError( | ||
f'max_nz {max_nz} must be an integer') | ||
regul = SparsityConstrained(gamma=reg, max_nz=max_nz) | ||
else: | ||
raise NotImplementedError('Unknown regularization') | ||
|
||
|
@@ -539,7 +598,8 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, | |
return G | ||
|
||
|
||
def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, | ||
def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', max_nz=None, | ||
method="L-BFGS-B", stopThr=1e-9, | ||
numItermax=500, verbose=False, log=False): | ||
r""" | ||
Solve the regularized OT problem in the semi-dual and return the OT matrix | ||
|
@@ -583,6 +643,9 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr= | |
:ref:`[2] <references-smooth-ot-semi-dual>`) | ||
|
||
- 'l2' : Squared Euclidean regularization | ||
- 'sparsity_constrained' : Sparsity-constrained regularization [50] | ||
max_nz : int or None, optional. Used only in the case of reg_type = 'sparsity_constrained' to specify the maximum number of nonzeros per column of the optimal plan; | ||
not used for other regularization types. | ||
method : str | ||
Solver to use for scipy.optimize.minimize | ||
numItermax : int, optional | ||
|
@@ -610,6 +673,8 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr= | |
|
||
.. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS). | ||
|
||
.. [50] Liu, T., Puigcerver, J., & Blondel, M. (2023). Sparsity-constrained optimal transport. Proceedings of the Eleventh International Conference on Learning Representations (ICLR). | ||
|
||
See Also | ||
-------- | ||
ot.lp.emd : Unregularized OT | ||
|
@@ -621,6 +686,11 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr= | |
regul = SquaredL2(gamma=reg) | ||
elif reg_type.lower() in ['entropic', 'negentropy', 'kl']: | ||
regul = NegEntropy(gamma=reg) | ||
elif reg_type.lower() in ['sparsity_constrained', 'sparsity-constrained']: | ||
if not isinstance(max_nz, int): | ||
raise ValueError( | ||
f'max_nz {max_nz} must be an integer') | ||
regul = SparsityConstrained(gamma=reg, max_nz=max_nz) | ||
else: | ||
raise NotImplementedError('Unknown regularization') | ||
|
||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.