Skip to content

Commit ca34247

Browse files
move eval_filter_grid to operators.filters
remove method from ImageSource fix imports no self import cleanup docstring clarity and style fixes
1 parent e0f2194 commit ca34247

File tree

5 files changed

+29
-33
lines changed

5 files changed

+29
-33
lines changed

src/aspire/covariance/covar.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from aspire import config
1212
from aspire.image import Image
1313
from aspire.nufft import anufft
14+
from aspire.operators import evaluate_grid_src
1415
from aspire.reconstruction import Estimator, FourierKernel, MeanEstimator
1516
from aspire.utils import (
1617
ensure,
@@ -47,7 +48,7 @@ def compute_kernel(self):
4748
_2L = 2 * self.L
4849

4950
kernel = np.zeros((_2L, _2L, _2L, _2L, _2L, _2L), dtype=self.dtype)
50-
sq_filters_f = self.src.eval_filter_grid(self.L, power=2)
51+
sq_filters_f = evaluate_grid_src(self.src, self.L, power=2)
5152

5253
for i in tqdm(range(0, n, self.batch_size)):
5354
_range = np.arange(i, min(n, i + self.batch_size))

src/aspire/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
ScalarFilter,
1414
ScaledFilter,
1515
ZeroFilter,
16+
evaluate_grid_src,
1617
voltage_to_wavelength,
1718
)
1819
from .wemd import wemd_embed, wemd_norm

src/aspire/operators/filters.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,30 @@ def wavelength_to_voltage(wavelength):
3333
) / (2 * 0.978466)
3434

3535

36+
def evaluate_grid_src(src, L, power=1):
37+
"""
38+
Given an ImageSource object, compute the source's unique filters
39+
at the filter_indices specified in its metadata.
40+
:return: an `L x L x len(src.filter_indices)` array containing the evaluated
41+
filters at each gridpoint
42+
"""
43+
grid2d = grid_2d(L, dtype=src.dtype)
44+
omega = np.pi * np.vstack((grid2d["x"].flatten(), grid2d["y"].flatten()))
45+
46+
h = np.empty((omega.shape[-1], len(src.filter_indices)), dtype=src.dtype)
47+
for i, filt in enumerate(src.unique_filters):
48+
idx_k = np.where(src.filter_indices == i)[0]
49+
if len(idx_k) > 0:
50+
filter_values = filt.evaluate(omega)
51+
if power != 1:
52+
filter_values **= power
53+
h[:, idx_k] = np.column_stack((filter_values,) * len(idx_k))
54+
55+
h = np.reshape(h, grid2d["x"].shape + (len(src.filter_indices),))
56+
57+
return h
58+
59+
3660
# TODO: filters should probably be dtyped...
3761
class Filter:
3862
def __init__(self, dim=None, radial=False):
@@ -116,20 +140,6 @@ def evaluate_grid(self, L, dtype=np.float32, *args, **kwargs):
116140

117141
return h
118142

119-
def evaluate_grid_src(self, src, L, power=1):
120-
grid2d = grid_2d(L, dtype=src.dtype)
121-
omega = np.pi * np.vstack((grid2d["x"].flatten(), grid2d["y"].flatten()))
122-
h = np.empty((omega.shape[-1], len(src.filter_indices)), dtype=src.dtype)
123-
for i, filt, in enumerate(src.unique_filters):
124-
idx_k = np.where(src.filter_indices == i)[0]
125-
if len(idx_k) > 0:
126-
filter_values = filt.evaluate(omega)
127-
if power != 1:
128-
filter_values **= power
129-
h[:, idx_k] = np.column_stack((filter_values,)*len(idx_k))
130-
h = np.reshape(h, grid2d["x"].shape + (len(src.filter_indices),))
131-
return h
132-
133143
def dual(self):
134144
return DualFilter(self)
135145

src/aspire/reconstruction/mean.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from scipy.fftpack import fft2
55

66
from aspire.nufft import anufft
7+
from aspire.operators import evaluate_grid_src
78
from aspire.reconstruction import Estimator, FourierKernel
89
from aspire.utils.fft import mdim_ifftshift
910
from aspire.utils.matlab_compat import m_flatten, m_reshape
@@ -16,7 +17,7 @@ class MeanEstimator(Estimator):
1617
def compute_kernel(self):
1718
_2L = 2 * self.L
1819
kernel = np.zeros((_2L, _2L, _2L), dtype=self.dtype)
19-
sq_filters_f = self.src.eval_filter_grid(self.L, power=2)
20+
sq_filters_f = evaluate_grid_src(self.src, self.L, power=2)
2021

2122
for i in range(0, self.n, self.batch_size):
2223
_range = np.arange(i, min(self.n, i + self.batch_size), dtype=int)

src/aspire/source/image.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -358,23 +358,6 @@ def eval_filters(self, im_orig, start=0, num=np.inf, indices=None):
358358

359359
return im
360360

361-
def eval_filter_grid(self, L, power=1):
362-
grid2d = grid_2d(L, dtype=self.dtype)
363-
omega = np.pi * np.vstack((grid2d["x"].flatten(), grid2d["y"].flatten()))
364-
365-
h = np.empty((omega.shape[-1], len(self.filter_indices)), dtype=self.dtype)
366-
for i, filt in enumerate(self.unique_filters):
367-
idx_k = np.where(self.filter_indices == i)[0]
368-
if len(idx_k) > 0:
369-
filter_values = filt.evaluate(omega)
370-
if power != 1:
371-
filter_values **= power
372-
h[:, idx_k] = np.column_stack((filter_values,) * len(idx_k))
373-
374-
h = np.reshape(h, grid2d["x"].shape + (len(self.filter_indices),))
375-
376-
return h
377-
378361
def cache(self):
379362
logger.info("Caching source images")
380363
self._cached_im = self.images(start=0, num=np.inf)

0 commit comments

Comments
 (0)