Skip to content

Commit c2b89e2

Browse files
authored
Merge pull request #419 from ComputationalCryoEM/add_wbemd_418
Add wbemd 418
2 parents 9314944 + 5d4ac32 commit c2b89e2

File tree

4 files changed

+141
-0
lines changed

4 files changed

+141
-0
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def read(fname):
3232
"numpy==1.16",
3333
"pandas==0.25.3",
3434
"pyfftw",
35+
"PyWavelets",
3536
"pillow",
3637
"scipy==1.4.0",
3738
"scikit-learn",

src/aspire/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@
1515
ZeroFilter,
1616
voltage_to_wavelength,
1717
)
18+
from .wemd import wemd_embed, wemd_norm

src/aspire/operators/wemd.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
"""
2+
Wavelet-based approximate Earthmover's distance (EMD) for n-dimensional signals.
3+
4+
This code is based on the following paper:
5+
Sameer Shirdhonkar and David W. Jacobs.
6+
"Approximate earth mover’s distance in linear time."
7+
2008 IEEE Conference on Computer Vision and Pattern Recognition (CVPR).
8+
9+
More details are available in their technical report:
10+
CAR-TR-1025 CS-TR-4908 UMIACS-TR-2008-06.
11+
"""
12+
13+
import warnings
14+
15+
import numpy as np
16+
import pywt
17+
18+
19+
def wemd_embed(arr, wavelet="coif3", level=None):
    """
    Embed a Numpy array so that L1 distances between embeddings approximate
    the Earthmover's distance.

    For non-negative arrays that sum to one, the L1 distance between two
    embeddings produced by this function is strongly equivalent to the
    Earthmover distance between the arrays.

    :param arr: Numpy array
    :param wavelet: Either the name of a wavelet supported by PyWavelets
        (e.g. 'coif3', 'sym3', 'sym5', etc.) or a pywt.Wavelet object
        See https://pywavelets.readthedocs.io/en/latest/ref/wavelets.html#built-in-wavelets-wavelist
        The default is 'coif3', because it seems to work well empirically.
    :param level: Decomposition level of the wavelets.
        Larger levels yield more coefficients and more accurate results.
        If no level is given, ``ceil(log2(max(arr.shape))) + 1`` is used,
        i.e. one more than the rounded-up log2 of the longest side.
    :returns: One-dimensional numpy array containing weighted details coefficients.
    """
    ndim = arr.ndim

    if level is None:
        level = int(np.ceil(np.log2(max(arr.shape)))) + 1

    # Decomposing at this depth makes PyWavelets emit a boundary-effects
    # warning; it is silenced here because it does not seem to be a cause
    # for concern. Only that specific message is filtered.
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message="Level value of .* is too high:"
            " all coefficients will experience boundary effects.",
        )
        decomposition = pywt.wavedecn(arr, wavelet, mode="zero", level=level)

    # The first entry holds the approximation coefficients; everything after
    # it is one dict of detail coefficients per level.
    details = decomposition[1:]
    assert len(details) == level

    # Detail dict number j is scaled by 2 ** ((level - 1 - j) * (1 + ndim / 2)),
    # so earlier entries in the list receive the larger weights.
    exponent_step = 1 + (ndim / 2.0)
    return np.concatenate(
        [
            2 ** ((level - 1 - j) * exponent_step) * band.flatten()
            for j, bands in enumerate(details)
            for band in bands.values()
        ]
    )
61+
62+
63+
def wemd_norm(arr, wavelet="coif3", level=None):
    """
    Wavelet-based norm used to approximate the Earthmover's distance between
    mass distributions specified as Numpy arrays (typically images or volumes).

    The norm is the L1 norm of the weighted wavelet embedding returned by
    `wemd_embed` for the given array.

    :param arr: Numpy array of the difference between the two mass distributions.
    :param wavelet: Either the name of a wavelet supported by PyWavelets
        (e.g. 'coif3', 'sym3', 'sym5', etc.) or a pywt.Wavelet object
        See https://pywavelets.readthedocs.io/en/latest/ref/wavelets.html#built-in-wavelets-wavelist
        The default is 'coif3', because it seems to work well empirically.
    :param level: Decomposition level of the wavelets.
        Larger levels yield more coefficients and more accurate results.
        If no level is given, `wemd_embed` chooses one from the array shape.
    :return: Approximated Earthmover's Distance
    """
    embedding = wemd_embed(arr, wavelet=wavelet, level=level)
    return np.linalg.norm(embedding, ord=1)

tests/test_wbemd.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import logging
2+
from unittest import TestCase
3+
4+
import numpy as np
5+
from numpy import asarray, cos, mgrid, sin
6+
7+
from aspire.operators import wemd_norm
8+
9+
# Module-level logger; the test below uses it to report the distances it computes.
logger = logging.getLogger(__name__)
10+
11+
12+
def _smoothed_disk_image(x, y, radius, width, height):
    """
    Build a (height, width) image of a soft-edged disk centered at (x, y).

    Each pixel holds a scaled sigmoid of its squared distance to the center
    divided by ``radius ** 2``: the value is exactly 1 at the center and
    decays towards 0 far away from it.

    :param x: Column coordinate of the disk center.
    :param y: Row coordinate of the disk center.
    :param radius: Length scale of the disk.
    :param width: Number of columns in the image.
    :param height: Number of rows in the image.
    :return: Numpy array of shape (height, width).
    """
    rows, cols = mgrid[:height, :width]
    squared_dist = (cols - x) ** 2 + (rows - y) ** 2
    scaled = squared_dist / (radius ** 2)
    # Scaled sigmoid: 2 - 2 / (1 + exp(-t)) equals 1 at t = 0 and tends to 0.
    return 2.0 - 2 / (1 + np.exp(-scaled))
16+
17+
18+
def _is_monotone(seq):
    """
    Tell whether a one-dimensional sequence is non-decreasing.

    :param seq: Sequence convertible to a one-dimensional Numpy array.
    :return: True when every element is greater than or equal to the one
        before it (vacuously true for empty and single-element input).
    """
    values = asarray(seq)
    assert values.ndim == 1
    previous, following = values[:-1], values[1:]
    return np.all(following >= previous)
22+
23+
24+
class WEMDTestCase(TestCase):
    """
    Check that the WEMD norm grows with the displacement of a smoothed disk.

    For several disk radii and ray angles, a disk is moved step by step away
    from the image center and its WEMD distance to the unmoved disk is
    computed. The resulting sequence is expected to be non-decreasing.
    Note that this monotonicity isn't strictly required by the theory,
    but holds empirically.
    """

    def test_wemd_norm(self):
        width = height = 64
        center_x, center_y = width // 2, height // 2

        # A few disk radii and ray angles to test
        radii = range(1, 8)
        angles = (0.0, 0.4755, 0.6538, 1.9818, 3.0991, 4.4689, 4.9859, 5.5752)
        # Displacements of the disk center along each ray, in pixels.
        offsets = range(0, 16, 2)

        for radius in radii:
            for angle in angles:
                dx, dy = cos(angle), sin(angle)
                disks = [
                    _smoothed_disk_image(
                        center_x + int(k * dx),
                        center_y + int(k * dy),
                        radius,
                        width,
                        height,
                    )
                    for k in offsets
                ]
                # The first disk (zero offset) is the reference image.
                reference = disks[0]
                distances = [wemd_norm(reference - disk) for disk in disks]
                logger.info(f"wemd distances along ray: {distances}")
                self.assertTrue(_is_monotone(distances))

0 commit comments

Comments
 (0)