11import logging
2+ from abc import ABC , abstractmethod
23from itertools import product
34
45import numpy as np
5- from tqdm import trange
6+ from tqdm import tqdm , trange
7+
8+ from aspire .image import Image
9+ from aspire .source import ArrayImageSource
610
711logger = logging .getLogger (__name__ )
812
913
10- class Align2D :
14+ class Align2D ( ABC ) :
1115 """
1216 Base class for 2D Image Alignment methods.
1317 """
1418
15- def __init__ (self , basis , dtype ):
19+ def __init__ (self , alignment_basis , source , composite_basis = None , dtype = None ):
1620 """
17- :param basis: Basis to be used for any methods during alignment.
21+ :param alignment_basis: Basis to be used during alignment (eg FSPCA)
22+ :param source: Source of original images.
23+ :param composite_basis: Basis to be used during class average composition (eg FFB2D)
1824 :param dtype: Numpy dtype to be used during alignment.
1925 """
2026
21- self .basis = basis
27+ self .alignment_basis = alignment_basis
28+ # if composite_basis is None, use alignment_basis
29+ self .composite_basis = composite_basis or self .alignment_basis
30+ self .src = source
2231 if dtype is None :
23- self .dtype = self .basis .dtype
32+ self .dtype = self .alignment_basis .dtype
2433 else :
2534 self .dtype = np .dtype (dtype )
26- if self .dtype != self .basis .dtype :
35+ if self .dtype != self .alignment_basis .dtype :
2736 logger .warning (
28- f"Align2D basis .dtype { self .basis .dtype } does not match self.dtype { self .dtype } ."
37+ f"Align2D alignment_basis .dtype { self .alignment_basis .dtype } does not match self.dtype { self .dtype } ."
2938 )
3039
40+ @abstractmethod
3141 def align (self , classes , reflections , basis_coefficients ):
3242 """
3343 Any align2D alignment method should take in the following arguments
34- and return the described tuple .
44+ and return aligned images .
3545
36- Generally, the returned `classes` and `reflections` should be same as
37- the input. They are passed through for convience,
38- considering they would all be required for image output .
46+ During this process `rotations`, `reflections`, `shifts` and
47+ `correlations` propeties will be computed for aligners
48+ that implement them .
3949
40- Returned `rotations` is an (n_classes, n_nbor) array of angles,
50+ `rotations` would be an (n_classes, n_nbor) array of angles,
4151 which should represent the rotations needed to align images within
4252 that class. `rotations` is measured in Radians.
4353
44- Returned `correlations` is an (n_classes, n_nbor) array representing
54+ `correlations` is an (n_classes, n_nbor) array representing
4555 a correlation like measure between classified images and their base
4656 image (image index 0).
4757
48- Returned `shifts` is None or an (n_classes, n_nbor) array of 2D shifts
58+ `shifts` is None or an (n_classes, n_nbor) array of 2D shifts
4959 which should represent the translation needed to best align the images
5060 within that class.
5161
@@ -55,12 +65,79 @@ def align(self, classes, reflections, basis_coefficients):
5565 :param refl: (n_classes, n_nbor) bool array of corresponding reflections
5666 :param coef: (n_img, self.pca_basis.count) compressed basis coefficients
5767
58- :returns: (classes, reflections, rotations, shifts, correlations )
68+ :returns: Image instance (stack of images )
5969 """
60- raise NotImplementedError ("Subclasses must implement align." )
6170
6271
63- class BFRAlign2D (Align2D ):
72+ class AveragedAlign2D (Align2D ):
73+ """
74+ Subclass supporting aligners which perform averaging during output.
75+ """
76+
77+ def align (self , classes , reflections , basis_coefficients ):
78+ """
79+ See Align2D.align
80+ """
81+ # Correlations are currently unused, but left for future extensions.
82+ cls , ref , rot , shf , corrs = self ._align (
83+ classes , reflections , basis_coefficients
84+ )
85+ return self .average (cls , ref , rot , shf ), cls , ref , rot , shf , corrs
86+
87+ def average (
88+ self ,
89+ classes ,
90+ reflections ,
91+ rotations ,
92+ shifts = None ,
93+ coefs = None ,
94+ ):
95+ """
96+ Combines images using averaging in provided `basis`.
97+
98+ :param classes: class indices (refering to src). (n_img, n_nbor)
99+ :param reflections: Bool representing whether to reflect image in `classes`
100+ :param rotations: Array of in-plane rotation angles (Radians) of image in `classes`
101+ :param shifts: Optional array of shifts for image in `classes`.
102+ :coefs: Optional Fourier bessel coefs (avoids recomputing).
103+ :return: Stack of Synthetic Class Average images as Image instance.
104+ """
105+ n_classes , n_nbor = classes .shape
106+
107+ # TODO: don't load all the images here.
108+ imgs = self .src .images (0 , self .src .n )
109+ b_avgs = np .empty ((n_classes , self .composite_basis .count ), dtype = self .src .dtype )
110+
111+ for i in tqdm (range (n_classes )):
112+ # Get the neighbors
113+ neighbors_ids = classes [i ]
114+
115+ # Get coefs in Composite_Basis if not provided as an argument.
116+ if coefs is None :
117+ neighbors_imgs = Image (imgs [neighbors_ids ])
118+ if shifts is not None :
119+ neighbors_imgs .shift (shifts [i ])
120+ neighbors_coefs = self .composite_basis .evaluate_t (neighbors_imgs )
121+ else :
122+ neighbors_coefs = coefs [neighbors_ids ]
123+ if shifts is not None :
124+ neighbors_coefs = self .composite_basis .shift (
125+ neighbors_coefs , shifts [i ]
126+ )
127+
128+ # Rotate in composite_basis
129+ neighbors_coefs = self .composite_basis .rotate (
130+ neighbors_coefs , rotations [i ], reflections [i ]
131+ )
132+
133+ # Averaging in composite_basis
134+ b_avgs [i ] = np .mean (neighbors_coefs , axis = 0 )
135+
136+ # Now we convert the averaged images from Basis to Cartesian.
137+ return ArrayImageSource (self .composite_basis .evaluate (b_avgs ))
138+
139+
140+ class BFRAlign2D (AveragedAlign2D ):
64141 """
65142 This perfoms a Brute Force Rotational alignment.
66143
@@ -69,24 +146,29 @@ class BFRAlign2D(Align2D):
69146 and then identifies angle yielding largest correlation(dot).
70147 """
71148
72- def __init__ (self , basis , n_angles = 359 , dtype = None ):
149+ def __init__ (
150+ self , alignment_basis , source , composite_basis = None , n_angles = 359 , dtype = None
151+ ):
73152 """
74- :params basis: Basis providing a `rotate` method.
153+ :params alignment_basis: Basis providing a `rotate` method.
154+ :param source: Source of original images.
75155 :params n_angles: Number of brute force rotations to attempt, defaults 359.
76156 """
77- super ().__init__ (basis , dtype )
157+ super ().__init__ (alignment_basis , source , composite_basis , dtype )
78158
79159 self .n_angles = n_angles
80160
81- if not hasattr (self .basis , "rotate" ):
161+ if not hasattr (self .alignment_basis , "rotate" ):
82162 raise RuntimeError (
83- f"BFRAlign2D's basis { self .basis } must provide a `rotate` method."
163+ f"BFRAlign2D's alignment_basis { self .alignment_basis } must provide a `rotate` method."
84164 )
85165
86- def align (self , classes , reflections , basis_coefficients ):
166+ def _align (self , classes , reflections , basis_coefficients ):
87167 """
88- See `Align2D.align`
168+ Performs the actual rotational alignment estimation,
169+ returning parameters needed for averaging.
89170 """
171+
90172 # Admit simple case of single case alignment
91173 classes = np .atleast_2d (classes )
92174 reflections = np .atleast_2d (reflections )
@@ -108,7 +190,9 @@ def align(self, classes, reflections, basis_coefficients):
108190
109191 for i , angle in enumerate (test_angles ):
110192 # Rotate the set of neighbors by angle,
111- rotated_nbrs = self .basis .rotate (nbr_coef , angle , reflections [k ])
193+ rotated_nbrs = self .alignment_basis .rotate (
194+ nbr_coef , angle , reflections [k ]
195+ )
112196
113197 # then store dot between class base image (0) and each nbor
114198 for j , nbor in enumerate (rotated_nbrs ):
@@ -124,7 +208,6 @@ def align(self, classes, reflections, basis_coefficients):
124208 for j in range (n_nbor ):
125209 correlations [k , j ] = results [j , angle_idx [j ]]
126210
127- # None is placeholder for shifts
128211 return classes , reflections , rotations , None , correlations
129212
130213
@@ -139,7 +222,16 @@ class BFSRAlign2D(BFRAlign2D):
139222 Return the rotation and shift yielding the best results.
140223 """
141224
142- def __init__ (self , basis , n_angles = 359 , n_x_shifts = 1 , n_y_shifts = 1 , dtype = None ):
225+ def __init__ (
226+ self ,
227+ alignment_basis ,
228+ source ,
229+ composite_basis = None ,
230+ n_angles = 359 ,
231+ n_x_shifts = 1 ,
232+ n_y_shifts = 1 ,
233+ dtype = None ,
234+ ):
143235 """
144236 Note that n_x_shifts and n_y_shifts are the number of shifts to perform
145237 in each direction.
@@ -148,25 +240,25 @@ def __init__(self, basis, n_angles=359, n_x_shifts=1, n_y_shifts=1, dtype=None):
148240
149241 n_x_shifts=n_y_shifts=0 is the same as calling BFRAlign2D.
150242
151- :params basis : Basis providing a `shift` and `rotate` method.
243+ :params alignment_basis : Basis providing a `shift` and `rotate` method.
152244 :params n_angles: Number of brute force rotations to attempt, defaults 359.
153245 :params n_x_shifts: +- Number of brute force xshifts to attempt, defaults 1.
154246 :params n_y_shifts: +- Number of brute force xshifts to attempt, defaults 1.
155247 """
156- super ().__init__ (basis , n_angles , dtype )
248+ super ().__init__ (alignment_basis , source , composite_basis , n_angles , dtype )
157249
158250 self .n_x_shifts = n_x_shifts
159251 self .n_y_shifts = n_y_shifts
160252
161- if not hasattr (self .basis , "shift" ):
253+ if not hasattr (self .alignment_basis , "shift" ):
162254 raise RuntimeError (
163- f"BFSRAlign2D's basis { self .basis } must provide a `shift` method."
255+ f"BFSRAlign2D's alignment_basis { self .alignment_basis } must provide a `shift` method."
164256 )
165257
166- # Each shift will require calling the parent BFRAlign2D.align
167- self ._bfr_align = super ().align
258+ # Each shift will require calling the parent BFRAlign2D._align
259+ self ._bfr_align = super ()._align
168260
169- def align (self , classes , reflections , basis_coefficients ):
261+ def _align (self , classes , reflections , basis_coefficients ):
170262 """
171263 See `Align2D.align`
172264 """
@@ -196,7 +288,7 @@ def align(self, classes, reflections, basis_coefficients):
196288 # We want to maintain the original coefs for the base images,
197289 # because we will mutate them with shifts in the loop.
198290 original_coef = basis_coefficients [classes [:, 0 ], :]
199- assert original_coef .shape == (n_classes , self .basis .count )
291+ assert original_coef .shape == (n_classes , self .alignment_basis .count )
200292
201293 # Loop over shift search space, updating best result
202294 for x , y in product (x_shifts , y_shifts ):
@@ -206,7 +298,7 @@ def align(self, classes, reflections, basis_coefficients):
206298 # Shift the coef representing the first (base) entry in each class
207299 # by the negation of the shift
208300 # Shifting one image is more efficient than shifting every neighbor
209- basis_coefficients [classes [:, 0 ], :] = self .basis .shift (
301+ basis_coefficients [classes [:, 0 ], :] = self .alignment_basis .shift (
210302 original_coef , - shift
211303 )
212304
@@ -242,18 +334,9 @@ class EMAlign2D(Align2D):
242334 Citation needed.
243335 """
244336
245- def __init__ (self , basis , dtype = None ):
246- super ().__init__ (basis , dtype )
247-
248337
249338class FTKAlign2D (Align2D ):
250339 """
251340 Factorization of the translation kernel for fast rigid image alignment.
252341 Rangan, A.V., Spivak, M., Anden, J., & Barnett, A.H. (2019).
253342 """
254-
255- def __init__ (self , basis , dtype = None ):
256- super ().__init__ (basis , dtype )
257-
258- def align (self , classes , reflections , basis_coefficients ):
259- raise NotImplementedError
0 commit comments