Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def read(fname):
"numpy==1.16",
"pandas==0.25.3",
"pyfftw",
"PyWavelets",
"pillow",
"scipy==1.4.0",
"scikit-learn",
Expand Down
1 change: 1 addition & 0 deletions src/aspire/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@
ScaledFilter,
ZeroFilter,
)
from .wemd import wemd_embed, wemd_norm
81 changes: 81 additions & 0 deletions src/aspire/operators/wemd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""
Wavelet-based approximate Earthmover's distance (EMD) for n-dimensional signals.

This code is based on the following paper:
Sameer Shirdhonkar and David W. Jacobs.
"Approximate earth mover’s distance in linear time."
2008 IEEE Conference on Computer Vision and Pattern Recognition (CVPR).

More details are available in their technical report:
CAR-TR-1025 CS-TR-4908 UMIACS-TR-2008-06.
"""

import warnings

import numpy as np
import pywt


def wemd_embed(arr, wavelet="coif3", level=None):
    """
    Embed an array so that L1 distances between embeddings approximate
    the Earthmover's distance (EMD).

    For non-negative arrays that sum to one, the L1 distance between the
    embeddings of two arrays is strongly equivalent to the Earthmover's
    distance between those arrays.

    :param arr: Numpy array
    :param wavelet: Either the name of a wavelet supported by PyWavelets
        (e.g. 'coif3', 'sym3', 'sym5', etc.) or a pywt.Wavelet object
        See https://pywavelets.readthedocs.io/en/latest/ref/wavelets.html#built-in-wavelets-wavelist
        The default is 'coif3', because it seems to work well empirically.
    :param level: Decomposition level of the wavelets.
        Larger levels yield more coefficients and more accurate results.
        If no level is given, it is set to ceil(log2(n)) + 1,
        where n is the longest side-length of `arr`.
    :returns: One-dimensional numpy array containing weighted details coefficients.
    """
    dimension = arr.ndim

    if level is None:
        longest_side = max(arr.shape)
        level = int(np.ceil(np.log2(longest_side))) + 1

    # Per-level exponent of the power-of-two weights; depends only on the
    # number of array dimensions.
    level_exponent = 1 + (dimension / 2.0)

    # Decomposing at this level triggers a PyWavelets boundary-effects
    # warning. It does not seem to be a cause for concern, so it is
    # silenced for the duration of the call only.
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message="Level value of .* is too high:"
            " all coefficients will experience boundary effects.",
        )
        decomposition = pywt.wavedecn(arr, wavelet, mode="zero", level=level)

    # Drop the leading entry of the decomposition; only the detail
    # coefficient dictionaries that follow it are used.
    detail_levels = decomposition[1:]
    assert len(detail_levels) == level

    # The first detail entry receives the largest weight and the last one
    # receives weight 1 (exponent 0).
    scaled_bands = []
    for index, bands in enumerate(detail_levels):
        weight = 2 ** ((level - 1 - index) * level_exponent)
        scaled_bands.extend(weight * band.flatten() for band in bands.values())

    return np.concatenate(scaled_bands)


def wemd_norm(arr, wavelet="coif3", level=None):
    """
    Wavelet-based norm approximating the Earthmover's distance between two
    mass distributions given as Numpy arrays (typically images or volumes).

    The result is the L1 norm of the embedding produced by `wemd_embed`.

    :param arr: Numpy array of the difference between the two mass distributions.
    :param wavelet: Either the name of a wavelet supported by PyWavelets
        (e.g. 'coif3', 'sym3', 'sym5', etc.) or a pywt.Wavelet object
        See https://pywavelets.readthedocs.io/en/latest/ref/wavelets.html#built-in-wavelets-wavelist
        The default is 'coif3', because it seems to work well empirically.
    :param level: Decomposition level of the wavelets.
        Larger levels yield more coefficients and more accurate results.
        If no level is given, `wemd_embed` chooses one from the array shape.
    :return: Approximated Earthmover's Distance
    """
    embedding = wemd_embed(arr, wavelet=wavelet, level=level)
    return np.linalg.norm(embedding, ord=1)
58 changes: 58 additions & 0 deletions tests/test_wbemd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import logging
from unittest import TestCase

import numpy as np
from numpy import asarray, cos, mgrid, sin

from aspire.operators import wemd_norm

# Module-level logger; the test below logs the computed distances through it.
logger = logging.getLogger(__name__)


def _smoothed_disk_image(x, y, radius, width, height):
    """
    Render a soft-edged disk as a 2D array of shape (height, width).

    :param x: Column index of the disk center.
    :param y: Row index of the disk center.
    :param radius: Length scale controlling how fast values fall off.
    :param width: Number of columns in the image.
    :param height: Number of rows in the image.
    :return: Array equal to 1.0 at the center, decaying towards 0 with distance.
    """
    rows, cols = mgrid[:height, :width]
    squared_distance = (cols - x) ** 2 + (rows - y) ** 2
    ratio = squared_distance / (radius ** 2)
    # Scaled sigmoid function: maps ratio 0 to 1.0 and large ratios towards 0.
    sigmoid_term = 2 / (1 + np.exp(-ratio))
    return 2.0 - sigmoid_term


def _is_monotone(seq):
    """
    Check whether a one-dimensional sequence is non-decreasing.

    :param seq: Sequence convertible to a one-dimensional Numpy array.
    :return: True if every element is >= its predecessor, False otherwise.
    """
    values = asarray(seq)
    assert values.ndim == 1
    earlier, later = values[:-1], values[1:]
    return np.all(earlier <= later)


class WEMDTestCase(TestCase):
    """
    Test that the WEMD distance between smoothed disks of various radii,
    angles and distances is monotone in the Euclidean distance of their centers.
    Note that this monotonicity isn't strictly required by the theory,
    but holds empirically.
    """

    def test_wemd_norm(self):
        """
        For every (radius, angle) pair, move a disk along a ray starting at
        the image center and check that its WEMD distance to the centered
        disk never decreases along the ray.
        """
        WIDTH = 64
        HEIGHT = 64
        CENTER_X = WIDTH // 2
        CENTER_Y = HEIGHT // 2

        # A few disk radii and ray angles to test
        RADII = [1, 2, 3, 4, 5, 6, 7]
        ANGLES = [0.0, 0.4755, 0.6538, 1.9818, 3.0991, 4.4689, 4.9859, 5.5752]

        for radius in RADII:
            for angle in ANGLES:
                # subTest reports the failing (radius, angle) combination and
                # lets the remaining combinations run after a failure.
                with self.subTest(radius=radius, angle=angle):
                    # Disks centered at increasing offsets k along the ray;
                    # disks[0] (k == 0) is the reference disk at the center.
                    disks = [
                        _smoothed_disk_image(
                            CENTER_X + int(k * cos(angle)),
                            CENTER_Y + int(k * sin(angle)),
                            radius,
                            WIDTH,
                            HEIGHT,
                        )
                        for k in range(0, 16, 2)
                    ]
                    wemd_distances_along_ray = [
                        wemd_norm(disks[0] - disk) for disk in disks
                    ]
                    logger.info(
                        "wemd distances along ray: %s", wemd_distances_along_ray
                    )
                    self.assertTrue(
                        _is_monotone(wemd_distances_along_ray),
                        msg=f"WEMD distances are not monotone along the ray: "
                        f"{wemd_distances_along_ray}",
                    )