Skip to content

Commit ce200cc

Browse files
Candidate fix for #533: Whitening applied twice to Simulation objects (#537)
* override ImageSource.whiten() so as to not whiten CTF filters in sim * make _apply_filters accept filters and filter_indices as parameters * formatting imports * removed _apply_unique_filters from sim._images() * create default value for filter_indices * filter_indices must have a default of [0,0,...] * apply_unique_filters -> apply_source_filters * np.zeros() instead of np.array() * requested change and 549 bugfix * formatting * cleaner solution: np.at_least1D * spelling * comments * format
1 parent 902a9c0 commit ce200cc

File tree

2 files changed

+38
-18
lines changed

2 files changed

+38
-18
lines changed

src/aspire/source/image.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def states(self, values):
145145

146146
@property
147147
def filter_indices(self):
148-
return self.get_metadata("__filter_indices")
148+
return np.atleast_1d(self.get_metadata("__filter_indices"))
149149

150150
@filter_indices.setter
151151
def filter_indices(self, indices):
@@ -345,17 +345,18 @@ def _images(self, start=0, num=np.inf, indices=None):
345345
"Subclasses should implement this and return an Image object"
346346
)
347347

348-
def _apply_filters(self, im_orig, start=0, num=np.inf, indices=None):
348+
def _apply_filters(
349+
self,
350+
im_orig,
351+
filters,
352+
indices,
353+
):
349354
"""
350-
For each image in `im_orig` specified by start, num, or indices,
351-
the unique_filters associated with the corresponding index in the
352-
`ImageSource` are applied. The images are then returned as an `Image`
353-
stack.
355+
For images in `im_orig`, `filters` associated with the corresponding
356+
index in the supplied `indices` are applied. The images are then returned as an `Image` stack.
354357
:param im_orig: An `Image` object
355-
:param start: Starting index of images in `im_orig`.
356-
:param num: Number of images to work on, starting at `start`.
357-
:param indices: A numpy array of image indices. If specified,`start` and `num` are ignored.
358-
:return: An `Image` instance with the unique filters of the source applied at the given indices.
358+
:param filters: A list of `Filter` objects
359+
:param indices: A list of indices indicating the corresponding filter in `filters`
359360
"""
360361
if not isinstance(im_orig, Image):
361362
logger.warning(
@@ -366,16 +367,20 @@ def _apply_filters(self, im_orig, start=0, num=np.inf, indices=None):
366367

367368
im = im_orig.copy()
368369

369-
if indices is None:
370-
indices = np.arange(start, min(start + num, self.n))
371-
372-
for i, filt in enumerate(self.unique_filters):
373-
idx_k = np.where(self.filter_indices[indices] == i)[0]
370+
for i, filt in enumerate(filters):
371+
idx_k = np.where(indices == i)[0]
374372
if len(idx_k) > 0:
375373
im[idx_k] = Image(im[idx_k]).filter(filt).asnumpy()
376374

377375
return im
378376

377+
def _apply_source_filters(self, im_orig, indices):
378+
return self._apply_filters(
379+
im_orig,
380+
self.unique_filters,
381+
self.filter_indices[indices],
382+
)
383+
379384
def cache(self):
380385
logger.info("Caching source images")
381386
self._cached_im = self.images(start=0, num=np.inf)
@@ -531,7 +536,7 @@ def im_backward(self, im, start):
531536
all_idx = np.arange(start, min(start + num, self.n))
532537
im *= self.amplitudes[all_idx, np.newaxis, np.newaxis]
533538
im = im.shift(-self.offsets[all_idx, :])
534-
im = self._apply_filters(im, start=start, num=num)
539+
im = self._apply_source_filters(im, all_idx)
535540

536541
vol = im.backproject(self.rots[start : start + num, :, :])[0]
537542

@@ -553,7 +558,7 @@ def vol_forward(self, vol, start, num):
553558
logger.warning(f"Volume.dtype {vol.dtype} inconsistent with {self.dtype}")
554559

555560
im = vol.project(0, self.rots[all_idx, :, :])
556-
im = self._apply_filters(im, start, num)
561+
im = self._apply_source_filters(im, all_idx)
557562
im = im.shift(self.offsets[all_idx, :])
558563
im *= self.amplitudes[all_idx, np.newaxis, np.newaxis]
559564
return im

src/aspire/source/simulation.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import copy
12
import logging
23

34
import numpy as np
@@ -84,12 +85,17 @@ def __init__(
8485
if unique_filters is None:
8586
unique_filters = []
8687
self.unique_filters = unique_filters
88+
# sim_filters must be a deep copy so that it is not changed
89+
# when unique_filters is changed
90+
self.sim_filters = copy.deepcopy(unique_filters)
8791

8892
# Create filter indices and fill the metadata based on unique filters
8993
if unique_filters:
9094
if filter_indices is None:
9195
filter_indices = randi(len(unique_filters), n, seed=seed) - 1
9296
self.filter_indices = filter_indices
97+
else:
98+
self.filter_indices = np.zeros(n)
9399

94100
self.offsets = offsets
95101
self.amplitudes = amplitudes
@@ -134,7 +140,9 @@ def _images(self, start=0, num=np.inf, indices=None, enable_noise=True):
134140

135141
im = self.projections(start=start, num=num, indices=indices)
136142

137-
im = self._apply_filters(im, start=start, num=num, indices=indices)
143+
# apply original CTF distortion to image
144+
im = self._apply_sim_filters(im, indices)
145+
138146
im = im.shift(self.offsets[indices, :])
139147

140148
im *= self.amplitudes[indices].reshape(len(indices), 1, 1).astype(self.dtype)
@@ -144,6 +152,13 @@ def _images(self, start=0, num=np.inf, indices=None, enable_noise=True):
144152

145153
return im
146154

155+
def _apply_sim_filters(self, im, indices):
156+
return self._apply_filters(
157+
im,
158+
self.sim_filters,
159+
self.filter_indices[indices],
160+
)
161+
147162
def vol_coords(self, mean_vol=None, eig_vols=None):
148163
"""
149164
Coordinates of simulation volumes in a given basis

0 commit comments

Comments
 (0)