-
Notifications
You must be signed in to change notification settings - Fork 26
Add wbemd 418 #419
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Add wbemd 418 #419
Changes from all commits
Commits
Show all changes
21 commits
Select commit
Hold shift + click to select a range
19eb2b5
Add PyWavelets dependency
garrettwrong 64882e4
initial add of Amit M code
garrettwrong edb5224
add wemd function and change docstring format
garrettwrong d94d702
formatting wemd code
garrettwrong 96a2c74
add wemd to operators __init__ and change name embed~>wembed
garrettwrong 324dd04
change assert to ValueError
garrettwrong ca6d1e8
add wemd unit test stub
garrettwrong 016a0bb
First version of wavelet-based approximate Earthmover's distance, wit…
mosco 723fb8e
apply auto style/linter
garrettwrong e3fe655
minor tweaks, unused import, int div, conver print to logger
garrettwrong 05bcc10
line length tweaks
garrettwrong e279c3d
Reponse to janden's review.
mosco 0c02981
Minor comment
mosco aea4d7b
Added default values for the wavelet and level parameters of WEMD.
mosco 29982cf
Minor performance improvement of wemd_embed
mosco 5ef2679
wemd_embed: suppress boundary effect warning when calling pywt.wavede…
mosco 6cec52a
isort/black formatting
garrettwrong 8e5de52
remove unused imports
garrettwrong 37d39dd
Disambiguate built-in `all` with `np.all`
garrettwrong 57512e7
use context block for warnings filter
garrettwrong 5d4ac32
rename testfile to match methods
garrettwrong File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,3 +14,4 @@ | |
| ScaledFilter, | ||
| ZeroFilter, | ||
| ) | ||
| from .wemd import wemd_embed, wemd_norm | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,81 @@ | ||
| """ | ||
| Wavelet-based approximate Earthmover's distance (EMD) for n-dimensional signals. | ||
|
|
||
| This code is based on the following paper: | ||
| Sameer Shirdhonkar and David W. Jacobs. | ||
| "Approximate earth mover’s distance in linear time." | ||
| 2008 IEEE Conference on Computer Vision and Pattern Recognition (CVPR). | ||
|
|
||
| More details are available in their technical report: | ||
| CAR-TR-1025 CS-TR-4908 UMIACS-TR-2008-06. | ||
| """ | ||
|
|
||
| import warnings | ||
|
|
||
| import numpy as np | ||
| import pywt | ||
|
|
||
|
|
||
| def wemd_embed(arr, wavelet="coif3", level=None): | ||
| """ | ||
| This function computes an embedding of Numpy arrays such that | ||
| for non-negative arrays that sum to one, the L1 distance between the | ||
| resulting embeddings is strongly equivalent to the Earthmover distance | ||
| of the arrays. | ||
|
|
||
| :param arr: Numpy array | ||
| :param level: Decomposition level of the wavelets. | ||
| Larger levels yield more coefficients and more accurate results. | ||
| If no level is given, we take the the log2 of the side-length of the domain. | ||
| :param wavelet: Either the name of a wavelet supported by PyWavelets | ||
| (e.g. 'coif3', 'sym3', 'sym5', etc.) or a pywt.Wavelet object | ||
| See https://pywavelets.readthedocs.io/en/latest/ref/wavelets.html#built-in-wavelets-wavelist | ||
| The default is 'coif3', because it seems to work well empirically. | ||
| :returns: One-dimensional numpy array containing weighted details coefficients. | ||
| """ | ||
| dimension = arr.ndim | ||
|
|
||
| if level is None: | ||
| level = int(np.ceil(np.log2(max(arr.shape)))) + 1 | ||
|
|
||
| # Using wavedecn with the default level creates this boundary effects warning. | ||
| # However, this doesn't seem to be a cause for concern. | ||
| with warnings.catch_warnings(): | ||
| warnings.filterwarnings( | ||
| "ignore", | ||
| message="Level value of .* is too high:" | ||
| " all coefficients will experience boundary effects.", | ||
| ) | ||
| arrdwt = pywt.wavedecn(arr, wavelet, mode="zero", level=level) | ||
|
|
||
| detail_coefs = arrdwt[1:] | ||
| assert len(detail_coefs) == level | ||
|
|
||
| weighted_coefs = [] | ||
| for (j, details_level_j) in enumerate(detail_coefs): | ||
| multiplier = 2 ** ((level - 1 - j) * (1 + (dimension / 2.0))) | ||
| for coefs in details_level_j.values(): | ||
| weighted_coefs.append(multiplier * coefs.flatten()) | ||
|
|
||
| return np.concatenate(weighted_coefs) | ||
|
|
||
|
|
||
| def wemd_norm(arr, wavelet="coif3", level=None): | ||
| """ | ||
| Wavelet-based norm used to approximate the Earthmover's distance between | ||
| mass distributions specified as Numpy arrays (typically images or volumes). | ||
|
|
||
| :param arr: Numpy array of the difference between the two mass distributions. | ||
| :param level: Decomposition level of the wavelets. | ||
| Larger levels yield more coefficients and more accurate results. | ||
| If no level is given, we take the the log2 of the side-length of the domain. | ||
| Larger levels yield more coefficients and more accurate results | ||
| :param wavelet: Either the name of a wavelet supported by PyWavelets | ||
| (e.g. 'coif3', 'sym3', 'sym5', etc.) or a pywt.Wavelet object | ||
| See https://pywavelets.readthedocs.io/en/latest/ref/wavelets.html#built-in-wavelets-wavelist | ||
| The default is 'coif3', because it seems to work well empirically. | ||
| :return: Approximated Earthmover's Distance | ||
| """ | ||
|
|
||
| coefs = wemd_embed(arr, wavelet, level) | ||
| return np.linalg.norm(coefs, ord=1) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,58 @@ | ||
| import logging | ||
| from unittest import TestCase | ||
|
|
||
| import numpy as np | ||
| from numpy import asarray, cos, mgrid, sin | ||
|
|
||
| from aspire.operators import wemd_norm | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| def _smoothed_disk_image(x, y, radius, width, height): | ||
| (Y, X) = mgrid[:height, :width] | ||
| ratio = ((X - x) ** 2 + (Y - y) ** 2) / (radius ** 2) | ||
| return 2.0 - 2 / (1 + np.exp(-ratio)) # Scaled sigmoid funciton | ||
|
|
||
|
|
||
| def _is_monotone(seq): | ||
| arr = asarray(seq) | ||
| assert arr.ndim == 1 | ||
| return np.all(arr[1:] >= arr[:-1]) | ||
|
|
||
|
|
||
| class WEMDTestCase(TestCase): | ||
| """ | ||
| Test that the WEMD distance between smoothed disks of various radii, | ||
| angles and distances is monotone in the Euclidean distance of their centers. | ||
| Note that this monotonicity isn't strictly required by the theory, | ||
| but holds empirically. | ||
| """ | ||
|
|
||
| def test_wemd_norm(self): | ||
| WIDTH = 64 | ||
| HEIGHT = 64 | ||
| CENTER_X = WIDTH // 2 | ||
| CENTER_Y = HEIGHT // 2 | ||
|
|
||
| # A few disk radii and ray angles to test | ||
| RADII = [1, 2, 3, 4, 5, 6, 7] | ||
| ANGLES = [0.0, 0.4755, 0.6538, 1.9818, 3.0991, 4.4689, 4.9859, 5.5752] | ||
|
|
||
| for radius in RADII: | ||
| for angle in ANGLES: | ||
| disks = [ | ||
| _smoothed_disk_image( | ||
| CENTER_X + int(k * cos(angle)), | ||
| CENTER_Y + int(k * sin(angle)), | ||
| radius, | ||
| WIDTH, | ||
| HEIGHT, | ||
| ) | ||
| for k in range(0, 16, 2) | ||
| ] | ||
| wemd_distances_along_ray = [ | ||
| wemd_norm(disks[0] - disk) for disk in disks | ||
| ] | ||
| logger.info(f"wemd distances along ray: {wemd_distances_along_ray}") | ||
| self.assertTrue(_is_monotone(wemd_distances_along_ray)) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.