Skip to content

Commit 19a3a04

Browse files
committed
initial MicrographSource/Image extension to rect shapes
1 parent 8ad361b commit 19a3a04

File tree

6 files changed

+87
-55
lines changed

6 files changed

+87
-55
lines changed

src/aspire/image/image.py

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -202,20 +202,19 @@ def __init__(self, data, pixel_size=None, dtype=None):
202202
else:
203203
self.dtype = np.dtype(dtype)
204204

205-
if not data.shape[-1] == data.shape[-2]:
206-
raise ValueError("Only square ndarrays are supported.")
207-
208205
self._data = data.astype(self.dtype, copy=False)
209206
self.ndim = self._data.ndim
210207
self.shape = self._data.shape
211208
self.stack_ndim = self._data.ndim - 2
212209
self.stack_shape = self._data.shape[:-2]
213210
self.n_images = np.prod(self.stack_shape)
214-
self.resolution = self._data.shape[-1]
215211
self.pixel_size = None
216212
if pixel_size is not None:
217213
self.pixel_size = float(pixel_size)
218214

215+
self._is_square = data.shape[-1] == data.shape[-2]
216+
self.resolution = self._data.shape[-1] # XXXXX
217+
219218
# Numpy interop
220219
# https://numpy.org/devdocs/user/basics.interoperability.html#the-array-interface-protocol
221220
self.__array_interface__ = self._data.__array_interface__
@@ -233,6 +232,12 @@ def project(self, angles):
233232
:return: Radon transform of the Image Stack.
234233
:rtype: Ndarray (stack size, number of angles, image resolution)
235234
"""
235+
236+
if not self._is_square:
237+
raise NotImplementedError(
238+
"`Image.project` is not currently implemented for non-square images."
239+
)
240+
236241
# number of points to sample on radial line in polar grid
237242
n_points = self.resolution
238243
original_stack = self.stack_shape
@@ -308,19 +313,19 @@ def stack_reshape(self, *args):
308313
)
309314

310315
def __add__(self, other):
311-
if isinstance(other, Image):
316+
if isinstance(other, self.__class__):
312317
other = other._data
313318

314319
return self.__class__(self._data + other, pixel_size=self.pixel_size)
315320

316321
def __sub__(self, other):
317-
if isinstance(other, Image):
322+
if isinstance(other, self.__class__):
318323
other = other._data
319324

320325
return self.__class__(self._data - other, pixel_size=self.pixel_size)
321326

322327
def __mul__(self, other):
323-
if isinstance(other, Image):
328+
if isinstance(other, self.__class__):
324329
other = other._data
325330

326331
return self.__class__(self._data * other, pixel_size=self.pixel_size)
@@ -384,7 +389,7 @@ def __repr__(self):
384389
px_msg = f" with pixel_size={self.pixel_size} angstroms."
385390

386391
msg = f"{self.n_images} {self.dtype} images arranged as a {self.stack_shape} stack"
387-
msg += f" each of size {self.resolution}x{self.resolution}{px_msg}"
392+
msg += f" each of size {self.shape[-2:]}{px_msg}"
388393
return msg
389394

390395
def asnumpy(self):
@@ -441,6 +446,12 @@ def legacy_whiten(self, psd, delta):
441446
and which to set to zero. By default all `sqrt(psd)` values
442447
less than `delta` are zeroed out in the whitening filter.
443448
"""
449+
450+
if not self._is_square:
451+
raise NotImplementedError(
452+
"`Image.legacy_whiten` is not currently implemented for non-square images."
453+
)
454+
444455
n = self.n_images
445456
L = self.resolution
446457
L_half = L // 2
@@ -577,6 +588,11 @@ def filter(self, filter):
577588
:param filter: An object of type `Filter`.
578589
:return: A new filtered `Image` object.
579590
"""
591+
if not self._is_square:
592+
raise NotImplementedError(
593+
"`Image.filter` is not currently implemented for non-square images."
594+
)
595+
580596
original_stack_shape = self.stack_shape
581597

582598
im = self.stack_reshape(-1)
@@ -667,8 +683,8 @@ def _load_raw(filepath, dtype=None):
667683

668684
return im, pixel_size
669685

670-
@staticmethod
671-
def load(filepath, dtype=None):
686+
@classmethod
687+
def load(cls, filepath, dtype=None):
672688
"""
673689
Load raw data from supported files.
674690
@@ -683,7 +699,7 @@ def load(filepath, dtype=None):
683699
im, pixel_size = Image._load_raw(filepath, dtype=dtype)
684700

685701
# Return as Image instance
686-
return Image(im, pixel_size=pixel_size)
702+
return cls(im, pixel_size=pixel_size)
687703

688704
def _im_translate(self, shifts):
689705
"""
@@ -699,6 +715,10 @@ def _im_translate(self, shifts):
699715
Alternatively, it can be a row vector of length 2, in which case the same shifts is applied to each image.
700716
:return: The images translated by the shifts, with periodic boundaries.
701717
"""
718+
if not self._is_square:
719+
raise NotImplementedError(
720+
"`Image._im_translate` is not currently implemented for non-square images."
721+
)
702722

703723
if shifts.ndim == 1:
704724
shifts = shifts[np.newaxis, :]
@@ -769,6 +789,10 @@ def backproject(self, rot_matrices, symmetry_group=None, zero_nyquist=True):
769789
770790
:return: Volume instance corresonding to the backprojected images.
771791
"""
792+
if not self._is_square:
793+
raise NotImplementedError(
794+
"`Image.legacy_whiten` is not currently implemented for non-square images."
795+
)
772796

773797
if self.stack_ndim > 1:
774798
raise NotImplementedError(
@@ -886,6 +910,10 @@ def frc(self, other, cutoff=None, method="fft", plot=False):
886910
where `estimated_resolution` is in angstrom
887911
and FRC is a Numpy array of correlations.
888912
"""
913+
if not self._is_square:
914+
raise NotImplementedError(
915+
"`Image.frc` is not currently implemented for non-square images."
916+
)
889917

890918
if not isinstance(other, Image):
891919
raise TypeError(

src/aspire/source/micrograph.py

Lines changed: 37 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
import os
3+
import warnings
34
from abc import ABC, abstractmethod
45
from glob import glob
56
from pathlib import Path
@@ -20,7 +21,12 @@ class MicrographSource(ABC):
2021
def __init__(self, micrograph_count, micrograph_size, dtype, pixel_size=None):
2122
""" """
2223
self.micrograph_count = int(micrograph_count)
23-
self.micrograph_size = int(micrograph_size)
24+
# Expand single integer to 2-tuple
25+
if isinstance(micrograph_size, int):
26+
micrograph_size = (micrograph_size,) * 2
27+
if len(micrograph_size) != 2:
28+
raise ValueError("`micrograph_size` should be a integer or 2-tuple")
29+
self.micrograph_size = tuple(micrograph_size)
2430
self.dtype = np.dtype(dtype)
2531
if pixel_size is not None:
2632
pixel_size = float(pixel_size)
@@ -34,7 +40,7 @@ def __repr__(self):
3440
3541
:return: Returns a string description of instance.
3642
"""
37-
return f"{self.__class__.__name__} with {self.micrograph_count} {self.dtype.name} micrographs of size {self.micrograph_size}x{self.micrograph_size}"
43+
return f"{self.__class__.__name__} with {self.micrograph_count} {self.dtype.name} micrographs of size {self.micrograph_size}"
3844

3945
def __len__(self):
4046
"""
@@ -142,14 +148,14 @@ def __init__(self, micrographs, dtype=None, pixel_size=None):
142148
if micrographs.ndim == 2:
143149
micrographs = micrographs[None, :, :]
144150

145-
if micrographs.ndim != 3 or (micrographs.shape[-2] != micrographs.shape[-1]):
151+
if micrographs.ndim != 3:
146152
raise NotImplementedError(
147-
f"Incompatible `micrographs` shape {micrographs.shape}, expects (count, L, L)"
153+
f"Incompatible `micrographs` shape {micrographs.shape}, expects 2D or 3D array."
148154
)
149155

150156
super().__init__(
151157
micrograph_count=micrographs.shape[0],
152-
micrograph_size=micrographs.shape[-1],
158+
micrograph_size=micrographs.shape[-2:],
153159
dtype=dtype or micrographs.dtype,
154160
pixel_size=pixel_size,
155161
)
@@ -201,15 +207,15 @@ def __init__(self, micrographs_path, dtype=None, pixel_size=None):
201207

202208
# Load the first micrograph to infer shape/type
203209
# Size will be checked during on-the-fly loading of subsequent micrographs.
204-
micrograph0 = Image.load(self.micrograph_files[0])
205-
if micrograph0.pixel_size is not None and micrograph0.pixel_size != pixel_size:
206-
raise ValueError(
207-
f"Mismatched pixel size. {micrograph0.pixel_size} angstroms defined in {self.micrograph_files[0]}, but provided {pixel_size} angstroms."
208-
)
210+
micrograph0, _pixel_size = Image._load_raw(self.micrograph_files[0])
211+
# Compare with user provided pixel size
212+
if pixel_size is not None and _pixel_size != pixel_size:
213+
msg = f"Mismatched pixel size. {_pixel_size} angstroms defined in {self.micrograph_files[0]}, but provided {pixel_size} angstroms."
214+
warnings.warn(msg, UserWarning, stacklevel=2)
209215

210216
super().__init__(
211217
micrograph_count=len(self.micrograph_files),
212-
micrograph_size=micrograph0.resolution,
218+
micrograph_size=micrograph0.shape[-2:],
213219
dtype=dtype or micrograph0.dtype,
214220
pixel_size=pixel_size,
215221
)
@@ -265,28 +271,27 @@ def _images(self, indices):
265271
# Initialize empty result
266272
n_micrographs = len(indices)
267273
micrographs = np.empty(
268-
(n_micrographs, self.micrograph_size, self.micrograph_size),
274+
(n_micrographs, *self.micrograph_size),
269275
dtype=self.dtype,
270276
)
271277
for i, ind in enumerate(indices):
272278
# Load the micrograph image from file
273-
micrograph = Image.load(self.micrograph_files[ind])
279+
micrograph, _pixel_size = Image._load_raw(self.micrograph_files[ind])
280+
274281
# Assert size
275-
if micrograph.resolution != self.micrograph_size:
282+
if micrograph.shape != self.micrograph_size:
276283
raise NotImplementedError(
277284
f"Micrograph {ind} has inconsistent shape {micrograph.shape},"
278-
f" expected {(self.micrograph_size, self.micrograph_size)}."
285+
f" expected {self.micrograph_size}."
279286
)
287+
288+
# Continually compare with initial pixel_size
289+
if _pixel_size is not None and _pixel_size != self.pixel_size:
290+
msg = f"Mismatched pixel size. {micrograph.pixel_size} angstroms defined in {self.micrograph_files[ind]}, but provided {self.pixel_size} angstroms."
291+
warnings.warn(msg, UserWarning, stacklevel=2)
292+
280293
# Assign to array, implicitly performs casting to dtype
281-
micrographs[i] = micrograph.asnumpy()
282-
# Assert pixel_size
283-
if (
284-
micrograph.pixel_size is not None
285-
and micrograph.pixel_size != self.pixel_size
286-
):
287-
raise ValueError(
288-
f"Mismatched pixel size. {micrograph.pixel_size} angstroms defined in {self.micrograph_files[ind]}, but provided {self.pixel_size} angstroms."
289-
)
294+
micrographs[i] = micrograph
290295

291296
return Image(micrographs, pixel_size=self.pixel_size)
292297

@@ -313,7 +318,7 @@ def __init__(
313318
314319
:param volume: `Volume` instance to be used in `Simulation`.
315320
An `(L,L,L)` `Volume` will generate `(L,L)` particle images.
316-
:param micrograph_size: Size of micrograph in pixels, defaults to 4096.
321+
:param micrograph_size: Size of micrograph in pixels as integer or 2-tuple. Defaults to 4096.
317322
:param micrograph_count: Number of micrographs to generate (integer). Defaults to 1.
318323
:param particles_per_micrograph: The amount of particles generated for each micrograph. Defaults to 10.
319324
:param particle_amplitudes: Optional, amplitudes to pass to `Simulation`.
@@ -354,7 +359,7 @@ def __init__(
354359

355360
self.noise_adder = noise_adder
356361

357-
if self.particle_box_size > micrograph_size:
362+
if self.particle_box_size > max(self.micrograph_size):
358363
raise ValueError(
359364
"The micrograph size must be larger or equal to the `particle_box_size`."
360365
)
@@ -415,7 +420,7 @@ def __init__(
415420
else:
416421
if (
417422
boundary < (-self.particle_box_size // 2)
418-
or boundary > self.micrograph_size // 2
423+
or boundary > max(self.micrograph_size) // 2
419424
):
420425
raise ValueError("Illegal boundary value.")
421426
self.boundary = boundary
@@ -505,8 +510,8 @@ def _set_mask(self):
505510
"""
506511
self._mask = np.full(
507512
(
508-
int(self.micrograph_size + 2 * self.pad),
509-
int(self.micrograph_size + 2 * self.pad),
513+
int(self.micrograph_size[0] + 2 * self.pad),
514+
int(self.micrograph_size[1] + 2 * self.pad),
510515
),
511516
False,
512517
dtype=bool,
@@ -547,7 +552,7 @@ def _clean_images(self, indices):
547552
# Initialize empty micrograph
548553
n_micrographs = len(indices)
549554
clean_micrograph = np.zeros(
550-
(n_micrographs, self.micrograph_size, self.micrograph_size),
555+
(n_micrographs, *self.micrograph_size),
551556
dtype=self.dtype,
552557
)
553558
# Pad the micrograph
@@ -579,8 +584,8 @@ def _clean_images(self, indices):
579584
)
580585
clean_micrograph = clean_micrograph[
581586
:,
582-
self.pad : self.micrograph_size + self.pad,
583-
self.pad : self.micrograph_size + self.pad,
587+
self.pad : self.micrograph_size[0] + self.pad,
588+
self.pad : self.micrograph_size[1] + self.pad,
584589
]
585590
return Image(clean_micrograph, pixel_size=self.pixel_size)
586591

tests/test_array_image_source.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def testArrayImageSourceNumpyError(self):
8181

8282
# Test we raise with expected message from getter.
8383
with raises(RuntimeError, match=r"Creating Image object from Numpy.*"):
84-
_ = ArrayImageSource(np.empty((3, 2, 1)))
84+
_ = ArrayImageSource(np.empty((1)))
8585

8686
def testArrayImageSourceAngGetterError(self):
8787
"""

tests/test_image.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,9 @@ def testRepr(get_mdim_images):
8888

8989

9090
def testNonSquare():
91-
"""Test that an irregular Image array raises."""
92-
with raises(ValueError, match=r".* square .*"):
93-
_ = Image(np.empty((4, 5)))
91+
"""Test that an irregular Image array does not raise."""
92+
_ = Image(np.empty((4, 5)))
93+
_ = Image(np.empty((3, 4, 5)))
9494

9595

9696
def testImShift(get_images, dtype):

tests/test_micrograph_simulation.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,11 @@ def test_micrograph_source_has_correct_values(vol_fixture, micrograph_fixture):
8585
assert v.resolution == m.particle_box_size
8686
assert v == m.simulation.vols
8787
assert len(m) == m.micrograph_count
88-
assert m.clean_images[0].shape[1] == m.micrograph_size
89-
assert m.clean_images[0].shape[2] == m.micrograph_size
88+
assert m.clean_images[0].shape[1] == m.micrograph_size[0]
89+
assert m.clean_images[0].shape[2] == m.micrograph_size[1]
9090
assert (
9191
repr(m)
92-
== f"{m.__class__.__name__} with {m.micrograph_count} {m.dtype.name} micrographs of size {m.micrograph_size}x{m.micrograph_size}"
92+
== f"{m.__class__.__name__} with {m.micrograph_count} {m.dtype.name} micrographs of size {m.micrograph_size}"
9393
)
9494
_ = m.clean_images[:]
9595
_ = m.images[:]
@@ -136,9 +136,9 @@ def test_micrograph_centers_match(micrograph_fixture):
136136
for i, center in enumerate(centers):
137137
if (
138138
center[0] >= 0
139-
and center[0] < m.micrograph_size
139+
and center[0] < m.micrograph_size[0]
140140
and center[1] >= 0
141-
and center[1] < m.micrograph_size
141+
and center[1] < m.micrograph_size[1]
142142
):
143143
assert m.clean_images[i // m.particles_per_micrograph].asnumpy()[0][
144144
tuple(center)

tests/test_micrograph_source.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -271,12 +271,11 @@ def test_wrong_dim_micrograph_source():
271271

272272
def test_rectangular_micrograph_source_array():
273273
"""
274-
Test non-square micrograph source raises.
274+
Test non-square micrograph source does not raises.
275275
"""
276276
# Test with Numpy array input
277277
imgs_np = np.empty((3, 7, 8))
278-
with pytest.raises(RuntimeError, match=r"Incompatible.*"):
279-
ArrayMicrographSource(imgs_np)
278+
_ = ArrayMicrographSource(imgs_np)
280279

281280

282281
def test_rectangular_micrograph_source_files():

0 commit comments

Comments
 (0)