@@ -28,7 +28,7 @@ def __init__(
2828 """
2929 :param alignment_basis: Basis to be used during alignment (eg FSPCA)
3030 :param source: Source of original images.
31- :param composite_basis: Basis to be used during class average composition (eg FFB2D)
31+ :param composite_basis: Basis to be used during class average composition (eg hi res Cartesian/ FFB2D)
3232 :param dtype: Numpy dtype to be used during alignment.
3333 """
3434
@@ -89,43 +89,24 @@ def align(self, classes, reflections, basis_coefficients):
8989 :returns: Image instance (stack of images)
9090 """
9191
92- def _images (self , cls ):
92+ def _images (self , cls , src = None ):
9393 """
9494 Util to return images as an array for class k (provided as array `cls` ),
9595 preserving the class/nbor order.
9696
97- :param cls: An iterable (0/1-D array or list) that holds the indices of images to align. In Class Averaging, this would be a class.
97+ :param cls: An iterable (0/1-D array or list) that holds the indices of images to align.
98+ In Class Averaging, this would be a class.
99+ :param src: Optionally overridee the src, for example, if you want to use a different
100+ source for a certain operation (ie aignment).
98101 """
102+ src = src or self .src
99103
100104 n_nbor = cls .shape [- 1 ] # Includes zero'th neighbor
101105
102- # Get the images. We'll loop over the source in batches.
103- # Note one day when the Source.images is more flexible,
104- # this code would mostly go away.
105- images = np .empty ((n_nbor , self .src .L , self .src .L ), dtype = self .dtype )
106-
107- # We want to only process batches that actually
108- # contain images for this class.
109- # First compute the batches' indices.
110- for start in range (0 , self .src .n + 1 , self .batch_size ):
111- # First cook up the batch boundaries
112- end = start + self .batch_size
113- # UBound, these are inclusive bounds
114- start = min (start , self .src .n - 1 )
115- end = min (end , self .src .n - 1 )
116- num = end - start + 1
117-
118- # Second, loop over the cls members
119- image_batch = None
120- for i , index in enumerate (cls ):
121- # Check if the member is in this chunk
122- if start <= index <= end :
123- # Get and cache this image_batch on first hit.
124- if image_batch is None :
125- image_batch = self .src .images (start , num )
126- # Translate the cls's index into this batch's
127- batch_index = index % self .batch_size
128- images [i ] = image_batch [batch_index ]
106+ images = np .empty ((n_nbor , src .L , src .L ), dtype = self .dtype )
107+
108+ for i , index in enumerate (cls ):
109+ images [i ] = src .images (index , 1 ).asnumpy ()
129110
130111 return images
131112
@@ -429,20 +410,44 @@ def __init__(
429410 alignment_basis ,
430411 source ,
431412 composite_basis = None ,
413+ alignment_source = None ,
432414 diagnostics = False ,
433415 batch_size = 512 ,
434416 dtype = None ,
435417 ):
436418 """
437- :param alignment_basis: Basis to be used during alignment (eg FSPCA)
419+ :param alignment_basis: Basis to be used during alignment.
420+ For current implementation of ReddyChatterjiAlign2D this should be `None`.
421+ Instead see `alignment_source`.
438422 :param source: Source of original images.
439- :param composite_basis: Basis to be used during class average composition (eg FFB2D)
423+ :param composite_basis: Basis to be used during class average composition.
424+ For current implementation of ReddyChatterjiAlign2D this should be `None`.
425+ Instead this method uses `source` for composition of the averaged stack.
426+ :param alignment_source: Basis to be used during class average composition.
427+ Must be the same resolution as `source`.
440428 :param dtype: Numpy dtype to be used during alignment.
441429 """
442430
443431 self .__cache = dict ()
444432 self .diagnostics = diagnostics
445433 self .do_cross_corr_translations = True
434+ self .alignment_src = alignment_source or source
435+
436+ # TODO, for accomodating different resolutions we minimally need to adapt shifting.
437+ # Outside of scope right now, but would make a nice PR later.
438+ if self .alignment_src .L != source .L :
439+ raise RuntimeError ("Currently `alignment_src.L` must equal `source.L`" )
440+ if self .alignment_src .dtype != source .dtype :
441+ raise RuntimeError (
442+ "Currently `alignment_src.dtype` must equal `source.dtype`"
443+ )
444+
445+ # Sanity check. This API should be rethought once all basis and
446+ # alignment methods have been incorporated.
447+ assert alignment_basis is None # We use sources directly for alignment
448+ assert (
449+ composite_basis is not None
450+ ) # However, we require a basis for rotating etc.
446451
447452 super ().__init__ (
448453 alignment_basis , source , composite_basis , batch_size = batch_size , dtype = dtype
@@ -497,8 +502,8 @@ def _align(self, classes, reflections, basis_coefficients):
497502 shifts = np .zeros ((* classes .shape , 2 ), dtype = int )
498503
499504 for k in trange (n_classes ):
500- # # Get the array of images for this class
501- images = self ._images (classes [k ])
505+ # # Get the array of images for this class, using the `alignment_src`.
506+ images = self ._images (classes [k ], src = self . alignment_src )
502507
503508 self ._reddychatterji (
504509 k , images , classes , reflections , rotations , correlations , shifts
@@ -817,7 +822,7 @@ def _translation_cross_corr_diagnostic(self, cross_correlation):
817822 plt .imshow (cross_correlation )
818823 plt .xlabel ("x shift (pixels)" )
819824 plt .ylabel ("y shift (pixels)" )
820- L = self .src .L
825+ L = self .alignment_src .L
821826 labels = [0 , 10 , 20 , 30 , 0 , - 10 , - 20 , - 30 ]
822827 tick_location = [0 , 10 , 20 , 30 , L , L - 10 , L - 20 , L - 30 ]
823828 plt .xticks (tick_location , labels )
@@ -875,17 +880,26 @@ def __init__(
875880 alignment_basis ,
876881 source ,
877882 composite_basis = None ,
883+ alignment_source = None ,
878884 radius = None ,
879885 diagnostics = False ,
880886 batch_size = 512 ,
881887 dtype = None ,
882888 ):
883889 """
884- :param alignment_basis: Basis to be used during alignment (eg FSPCA)
890+ :param alignment_basis: Basis to be used during alignment.
891+ For current implementation of ReddyChatterjiAlign2D this should be `None`.
892+ Instead see `alignment_source`.
885893 :param source: Source of original images.
886- :param composite_basis: Basis to be used during class average composition (eg FFB2D)
894+ :param composite_basis: Basis to be used during class average composition.
895+ For current implementation of ReddyChatterjiAlign2D this should be `None`.
896+ Instead this method uses `source` for composition of the averaged stack.
897+ :param alignment_source: Basis to be used during class average composition.
898+ Must be the same resolution as `source`.
887899 :param radius: Brute force translation search radius.
888900 Defaults to source.L//8.
901+ :param dtype: Numpy dtype to be used during alignment.
902+
889903 :param diagnostics: Plot interactive diagnostic graphics (for debugging).
890904 :param dtype: Numpy dtype to be used during alignment.
891905 """
@@ -894,6 +908,7 @@ def __init__(
894908 alignment_basis ,
895909 source ,
896910 composite_basis ,
911+ alignment_source ,
897912 diagnostics ,
898913 batch_size = batch_size ,
899914 dtype = dtype ,
@@ -915,7 +930,7 @@ def _align(self, classes, reflections, basis_coefficients):
915930 reflections = np .atleast_2d (reflections )
916931
917932 n_classes , n_nbor = classes .shape
918- L = self .src .L
933+ L = self .alignment_src .L
919934
920935 # Instantiate matrices for inner loop, and best results.
921936 _rotations = np .zeros (classes .shape , dtype = self .dtype )
0 commit comments