@@ -107,7 +107,7 @@ def states(self, values):
107107
108108 @property
109109 def filter_indices (self ):
110- return self .get_metadata ("__filter_indices" )
110+ return np . atleast_1d ( self .get_metadata ("__filter_indices" ) )
111111
112112 @filter_indices .setter
113113 def filter_indices (self , indices ):
@@ -279,17 +279,18 @@ def _images(self, start=0, num=np.inf, indices=None):
279279 "Subclasses should implement this and return an Image object"
280280 )
281281
282- def _apply_filters (self , im_orig , start = 0 , num = np .inf , indices = None ):
282+ def _apply_filters (
283+ self ,
284+ im_orig ,
285+ filters ,
286+ indices ,
287+ ):
283288 """
284- For each image in `im_orig` specified by start, num, or indices,
285- the unique_filters associated with the corresponding index in the
286- `ImageSource` are applied. The images are then returned as an `Image`
287- stack.
289+ For images in `im_orig`, `filters` associated with the corresponding
290+ index in the supplied `indices` are applied. The images are then returned as an `Image` stack.
288291 :param im_orig: An `Image` object
289- :param start: Starting index of images in `im_orig`.
290- :param num: Number of images to work on, starting at `start`.
291- :param indices: A numpy array of image indices. If specified,`start` and `num` are ignored.
292- :return: An `Image` instance with the unique filters of the source applied at the given indices.
292+ :param filters: A list of `Filter` objects
293+ :param indices: A list of indices indicating the corresponding filter in `filters`
293294 """
294295 if not isinstance (im_orig , Image ):
295296 logger .warning (
@@ -300,16 +301,20 @@ def _apply_filters(self, im_orig, start=0, num=np.inf, indices=None):
300301
301302 im = im_orig .copy ()
302303
303- if indices is None :
304- indices = np .arange (start , min (start + num , self .n ))
305-
306- for i , filt in enumerate (self .unique_filters ):
307- idx_k = np .where (self .filter_indices [indices ] == i )[0 ]
304+ for i , filt in enumerate (filters ):
305+ idx_k = np .where (indices == i )[0 ]
308306 if len (idx_k ) > 0 :
309307 im [idx_k ] = Image (im [idx_k ]).filter (filt ).asnumpy ()
310308
311309 return im
312310
311+ def _apply_source_filters (self , im_orig , indices ):
312+ return self ._apply_filters (
313+ im_orig ,
314+ self .unique_filters ,
315+ self .filter_indices [indices ],
316+ )
317+
313318 def cache (self ):
314319 logger .info ("Caching source images" )
315320 self ._cached_im = self .images (start = 0 , num = np .inf )
@@ -465,7 +470,7 @@ def im_backward(self, im, start):
465470 all_idx = np .arange (start , min (start + num , self .n ))
466471 im *= self .amplitudes [all_idx , np .newaxis , np .newaxis ]
467472 im = im .shift (- self .offsets [all_idx , :])
468- im = self ._apply_filters (im , start = start , num = num )
473+ im = self ._apply_source_filters (im , all_idx )
469474
470475 vol = im .backproject (self .rots [start : start + num , :, :])[0 ]
471476
@@ -487,7 +492,7 @@ def vol_forward(self, vol, start, num):
487492 logger .warning (f"Volume.dtype { vol .dtype } inconsistent with { self .dtype } " )
488493
489494 im = vol .project (0 , self .rots [all_idx , :, :])
490- im = self ._apply_filters (im , start , num )
495+ im = self ._apply_source_filters (im , all_idx )
491496 im = im .shift (self .offsets [all_idx , :])
492497 im *= self .amplitudes [all_idx , np .newaxis , np .newaxis ]
493498 return im
0 commit comments