Skip to content

Commit 5d2fe6a

Browse files
committed
draft extension of grid_2d and image.filter to rectangle
1 parent 19a3a04 commit 5d2fe6a

File tree

4 files changed

+1346
-1279
lines changed

4 files changed

+1346
-1279
lines changed

src/aspire/image/image.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -588,10 +588,6 @@ def filter(self, filter):
588588
:param filter: An object of type `Filter`.
589589
:return: A new filtered `Image` object.
590590
"""
591-
if not self._is_square:
592-
raise NotImplementedError(
593-
"`Image.filter` is not currently implemented for non-square images."
594-
)
595591

596592
original_stack_shape = self.stack_shape
597593

@@ -603,9 +599,14 @@ def filter(self, filter):
603599
#
604600
# Second note, filter dtype may not match image dtype.
605601
filter_values = xp.asarray(
606-
filter.evaluate_grid(self.resolution), dtype=self.dtype
602+
filter.evaluate_grid(self.shape[-2:]), dtype=self.dtype
607603
)
608604

605+
# sanity check
606+
assert (
607+
filter_values.shape == im._data.shape[-2:]
608+
), f"{filter_values.shape} != {im._data.shape[:-2]}"
609+
609610
# Convolve
610611
im_f = fft.centered_fft2(xp.asarray(im._data))
611612
im_f = filter_values * im_f

src/aspire/utils/coor_trans.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,15 +98,22 @@ def grid_2d(n, shifted=False, normalized=True, indexing="yx", dtype=np.float32):
9898
Generate two dimensional grid.
9999
100100
:param n: the number of grid points in each dimension.
101+
May be a single integer value or 2-tuple of integers.
101102
:param shifted: shifted by half of grid or not when n is even.
102103
:param normalized: normalize the grid in the range of (-1, 1) or not.
103104
:param indexing: 'yx' (C) or 'xy' (F), defaulting to 'yx'.
104105
See https://numpy.org/doc/stable/reference/generated/numpy.meshgrid.html
105106
:return: the rectangular and polar coordinates of all grid points.
106107
"""
107108

108-
grid = _mgrid_slice(n, shifted, normalized)
109-
y, x = np.mgrid[grid, grid].astype(dtype)
109+
if isinstance(n, int):
110+
n = (n, n)
111+
rows, cols = n
112+
113+
rows = _mgrid_slice(rows, shifted, normalized)
114+
cols = _mgrid_slice(cols, shifted, normalized)
115+
116+
y, x = np.mgrid[rows, cols].astype(dtype)
110117
if indexing == "xy":
111118
x, y = y, x
112119
elif indexing != "yx":

tests/test_filters.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66
import pytest
77

8+
from aspire.image import Image
89
from aspire.operators import (
910
ArrayFilter,
1011
CTFFilter,
@@ -16,7 +17,7 @@
1617
ScaledFilter,
1718
ZeroFilter,
1819
)
19-
from aspire.utils import utest_tolerance
20+
from aspire.utils import gaussian_2d, grid_2d, utest_tolerance
2021

2122
DATA_DIR = os.path.join(os.path.dirname(__file__), "saved_test_data")
2223

@@ -432,3 +433,40 @@ def test_ctf_reference():
432433
# Test we're within 1%.
433434
# There are minor differences in the formulas for wavelength and grids.
434435
np.testing.assert_allclose(h, ref_h, rtol=0.01)
436+
437+
438+
def test_rectangular_ctf():
439+
"""
440+
Compare a truncated rectangular CTF application with the
441+
application of CTF to a full square image.
442+
"""
443+
# Configure square and truncated rectangle size
444+
L = 128
445+
rows, cols = 96, L
446+
assert rows <= L and cols <= L and min(rows, cols) < L
447+
448+
# Create a test image of a disk
449+
# A = gaussian_2d(size=L, mu=(0,-L//10), sigma=L//8, dtype=np.float64)
450+
A = gaussian_2d(size=L, mu=(0, -L // 10), sigma=L // 32, dtype=np.float64)
451+
452+
full_img = Image(A)
453+
truncated_img = Image(A[:rows, :cols])
454+
455+
# Create a CTFFilter
456+
ctf_filter = CTFFilter(pixel_size=2)
457+
458+
# Apply to both Image instances
459+
full_img_with_ctf = full_img.filter(ctf_filter)
460+
truncated_img_with_ctf = truncated_img.filter(ctf_filter)
461+
462+
# Truncate the full square result
463+
full_img_with_ctf_truncated = full_img_with_ctf.asnumpy()[:, :rows, :cols]
464+
465+
# Create mask for circular convolution effects
466+
mask = (grid_2d(L, normalized=True)["r"] < 0.5)[:rows, :cols]
467+
468+
# Compare, we should be the same up to masked off differences in
469+
# circular convolution wrap around.
470+
np.testing.assert_allclose(
471+
truncated_img_with_ctf * mask, full_img_with_ctf_truncated * mask, atol=1e-6
472+
)

0 commit comments

Comments
 (0)