Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 71 additions & 6 deletions src/aspire/operators/filters.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import inspect
import logging
import math

import numpy as np
Expand All @@ -9,6 +10,8 @@
from aspire.utils.filter_to_fb_mat import filter_to_fb_mat
from aspire.utils.matlab_compat import m_reshape

logger = logging.getLogger(__name__)


def voltage_to_wavelength(voltage):
"""
Expand Down Expand Up @@ -92,6 +95,18 @@ def scale(self, c=1):
return ScaledFilter(self, c)

def evaluate_grid(self, L, dtype=np.float32, *args, **kwargs):
"""
Generates a two dimensional grid with prescribed dtype,
yielding the values (omega) which are then evaluated by
the filter's evaluate method.

Passes arbritrary args and kwargs down to self.evaluate method.

:param L: Number of grid points (L by L).
:param dtype: dtype of grid, defaults np.float32.
:return: Filter values at omega's points.
"""

grid2d = grid_2d(L, dtype=dtype)
omega = np.pi * np.vstack((grid2d["x"].flatten("F"), grid2d["y"].flatten("F")))
h = self.evaluate(omega, *args, **kwargs)
Expand Down Expand Up @@ -153,6 +168,19 @@ def __init__(self, filter, power=1):
def _evaluate(self, omega):
return self._filter.evaluate(omega) ** self._power

def evaluate_grid(self, L, dtype=np.float32, *args, **kwargs):
"""
Calls the provided filter's evaluate_grid method in case there is an optimization.

If no optimized method is provided, falls back to base `evaluate_grid`.

See `Filter.evaluate_grid` for usage.
"""

return (
self._filter.evaluate_grid(L, dtype=dtype, *args, **kwargs) ** self._power
)


class LambdaFilter(Filter):
"""
Expand Down Expand Up @@ -241,27 +269,64 @@ def __init__(self, xfer_fn_array):
self.xfer_fn_array = xfer_fn_array

def _evaluate(self, omega):
sz = self.sz

_input_pts = tuple(np.linspace(1, x, x) for x in self.xfer_fn_array.shape)

# TODO: This part could do with some documentation - not intuitive!
temp = np.array(sz)[:, np.newaxis]
temp = np.array(self.sz)[:, np.newaxis]
omega = (omega / (2 * np.pi)) * temp
omega += np.floor(temp / 2) + 1

# Emulating the behavior of interpn(V,X1q,X2q,X3q,...) in MATLAB
_input_pts = tuple(list(range(1, x + 1)) for x in self.xfer_fn_array.shape)
# The original MATLAB was using 'linear' and zero fill.
# We will use 'linear' but fill_value=None which will extrapolate
# for values slightly outside the interpolation grid bounds.
interpolator = RegularGridInterpolator(
_input_pts, self.xfer_fn_array, bounds_error=False, fill_value=0
_input_pts,
self.xfer_fn_array,
method="linear",
bounds_error=False,
fill_value=None,
)

result = interpolator(
# Split omega into input arrays and stack depth-wise because that's how
# the interpolator wants it
np.dstack(np.split(omega, len(sz)))
np.dstack(np.split(omega, len(self.sz)))
)

# Result is 1 x np.prod(sz) in shape; convert to a 1-d vector
# Result is 1 x np.prod(self.sz) in shape; convert to a 1-d vector
result = np.squeeze(result, 0)

return result

def evaluate_grid(self, L, dtype=np.float32, *args, **kwargs):
"""
Optimized evaluate_grid method for ArrayFilter.

If evaluate_grid is called with a resolution L that matches
the transfer function `xfer_fn_array` resolution,
we do not need to generate a grid, setup interpolation, and
evaluate by interpolation. We can instead use the transfer
function directly.

In the case the grid is not a match, we fall back to the
base `evaluate_grid` implementation.

See Filter.evaluate_grid for usage.
"""

if all(dim == L for dim in self.xfer_fn_array.shape):
logger.debug(
"Size of transfer function matches evaluate_grid size L exactly,"
" skipping grid generation and interpolation."
)
res = self.xfer_fn_array
else:
# Otherwise call parent code to generate a grid then evaluate.
res = super().evaluate_grid(L, dtype=dtype, *args, **kwargs)
return res


class ScalarFilter(Filter):
def __init__(self, dim=None, value=1):
Expand Down
26 changes: 26 additions & 0 deletions tests/test_preprocess_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,32 @@ def testWhiten(self):
# correlation matrix should be close to identity
self.assertTrue(np.allclose(np.eye(2), corr_coef, atol=1e-1))

def testWhiten2(self):
# Excercises missing cases using odd image resolutions with filter.
# Relates to GitHub issue #401.
# Otherwise this is the same as testWhiten, though the accuracy
# (atol) for odd resolutions seems slightly worse.
L = self.L - 1
assert L % 2 == 1, "Test resolution should be odd"

sim = Simulation(
L=L,
n=self.n,
unique_filters=[
RadialCTFFilter(defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7)
],
noise_filter=self.noise_filter,
dtype=self.dtype,
)
noise_estimator = AnisotropicNoiseEstimator(sim)
sim.whiten(noise_estimator.filter)
imgs_wt = sim.images(start=0, num=self.n).asnumpy()

corr_coef = np.corrcoef(imgs_wt[:, L - 1, L - 1], imgs_wt[:, L - 2, L - 1])

# Correlation matrix should be close to identity
self.assertTrue(np.allclose(np.eye(2), corr_coef, atol=2e-1))

def testInvertContrast(self):
sim1 = self.sim
imgs1 = sim1.images(start=0, num=128)
Expand Down