Skip to content

Commit 6b6469e

Browse files
committed
Refactor align2D to return stack of cls avg
this better fits future codes like EM
1 parent 0729291 commit 6b6469e

File tree

6 files changed

+216
-130
lines changed

6 files changed

+216
-130
lines changed
Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1-
from .align2d import Align2D, BFRAlign2D, BFSRAlign2D, EMAlign2D, FTKAlign2D
1+
from .align2d import (
2+
Align2D,
3+
AveragedAlign2D,
4+
BFRAlign2D,
5+
BFSRAlign2D,
6+
EMAlign2D,
7+
FTKAlign2D,
8+
)
29
from .class2d import Class2D
310
from .rir_class2d import RIRClass2D

src/aspire/classification/align2d.py

Lines changed: 129 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,61 @@
11
import logging
2+
from abc import ABC, abstractmethod
23
from itertools import product
34

45
import numpy as np
5-
from tqdm import trange
6+
from tqdm import tqdm, trange
7+
8+
from aspire.image import Image
9+
from aspire.source import ArrayImageSource
610

711
logger = logging.getLogger(__name__)
812

913

10-
class Align2D:
14+
class Align2D(ABC):
1115
"""
1216
Base class for 2D Image Alignment methods.
1317
"""
1418

15-
def __init__(self, basis, dtype):
19+
def __init__(self, alignment_basis, source, composite_basis=None, dtype=None):
1620
"""
17-
:param basis: Basis to be used for any methods during alignment.
21+
:param alignment_basis: Basis to be used during alignment (eg FSPCA)
22+
:param source: Source of original images.
23+
:param composite_basis: Basis to be used during class average composition (eg FFB2D)
1824
:param dtype: Numpy dtype to be used during alignment.
1925
"""
2026

21-
self.basis = basis
27+
self.alignment_basis = alignment_basis
28+
# if composite_basis is None, use alignment_basis
29+
self.composite_basis = composite_basis or self.alignment_basis
30+
self.src = source
2231
if dtype is None:
23-
self.dtype = self.basis.dtype
32+
self.dtype = self.alignment_basis.dtype
2433
else:
2534
self.dtype = np.dtype(dtype)
26-
if self.dtype != self.basis.dtype:
35+
if self.dtype != self.alignment_basis.dtype:
2736
logger.warning(
28-
f"Align2D basis.dtype {self.basis.dtype} does not match self.dtype {self.dtype}."
37+
f"Align2D alignment_basis.dtype {self.alignment_basis.dtype} does not match self.dtype {self.dtype}."
2938
)
3039

40+
@abstractmethod
3141
def align(self, classes, reflections, basis_coefficients):
3242
"""
3343
Any align2D alignment method should take in the following arguments
34-
and return the described tuple.
44+
and return aligned images.
3545
36-
Generally, the returned `classes` and `reflections` should be same as
37-
the input. They are passed through for convience,
38-
considering they would all be required for image output.
46+
During this process `rotations`, `reflections`, `shifts` and
47+
`correlations` propeties will be computed for aligners
48+
that implement them.
3949
40-
Returned `rotations` is an (n_classes, n_nbor) array of angles,
50+
`rotations` would be an (n_classes, n_nbor) array of angles,
4151
which should represent the rotations needed to align images within
4252
that class. `rotations` is measured in Radians.
4353
44-
Returned `correlations` is an (n_classes, n_nbor) array representing
54+
`correlations` is an (n_classes, n_nbor) array representing
4555
a correlation like measure between classified images and their base
4656
image (image index 0).
4757
48-
Returned `shifts` is None or an (n_classes, n_nbor) array of 2D shifts
58+
`shifts` is None or an (n_classes, n_nbor) array of 2D shifts
4959
which should represent the translation needed to best align the images
5060
within that class.
5161
@@ -55,12 +65,79 @@ def align(self, classes, reflections, basis_coefficients):
5565
:param refl: (n_classes, n_nbor) bool array of corresponding reflections
5666
:param coef: (n_img, self.pca_basis.count) compressed basis coefficients
5767
58-
:returns: (classes, reflections, rotations, shifts, correlations)
68+
:returns: Image instance (stack of images)
5969
"""
60-
raise NotImplementedError("Subclasses must implement align.")
6170

6271

63-
class BFRAlign2D(Align2D):
72+
class AveragedAlign2D(Align2D):
73+
"""
74+
Subclass supporting aligners which perform averaging during output.
75+
"""
76+
77+
def align(self, classes, reflections, basis_coefficients):
78+
"""
79+
See Align2D.align
80+
"""
81+
# Correlations are currently unused, but left for future extensions.
82+
cls, ref, rot, shf, corrs = self._align(
83+
classes, reflections, basis_coefficients
84+
)
85+
return self.average(cls, ref, rot, shf), cls, ref, rot, shf, corrs
86+
87+
def average(
88+
self,
89+
classes,
90+
reflections,
91+
rotations,
92+
shifts=None,
93+
coefs=None,
94+
):
95+
"""
96+
Combines images using averaging in provided `basis`.
97+
98+
:param classes: class indices (refering to src). (n_img, n_nbor)
99+
:param reflections: Bool representing whether to reflect image in `classes`
100+
:param rotations: Array of in-plane rotation angles (Radians) of image in `classes`
101+
:param shifts: Optional array of shifts for image in `classes`.
102+
:coefs: Optional Fourier bessel coefs (avoids recomputing).
103+
:return: Stack of Synthetic Class Average images as Image instance.
104+
"""
105+
n_classes, n_nbor = classes.shape
106+
107+
# TODO: don't load all the images here.
108+
imgs = self.src.images(0, self.src.n)
109+
b_avgs = np.empty((n_classes, self.composite_basis.count), dtype=self.src.dtype)
110+
111+
for i in tqdm(range(n_classes)):
112+
# Get the neighbors
113+
neighbors_ids = classes[i]
114+
115+
# Get coefs in Composite_Basis if not provided as an argument.
116+
if coefs is None:
117+
neighbors_imgs = Image(imgs[neighbors_ids])
118+
if shifts is not None:
119+
neighbors_imgs.shift(shifts[i])
120+
neighbors_coefs = self.composite_basis.evaluate_t(neighbors_imgs)
121+
else:
122+
neighbors_coefs = coefs[neighbors_ids]
123+
if shifts is not None:
124+
neighbors_coefs = self.composite_basis.shift(
125+
neighbors_coefs, shifts[i]
126+
)
127+
128+
# Rotate in composite_basis
129+
neighbors_coefs = self.composite_basis.rotate(
130+
neighbors_coefs, rotations[i], reflections[i]
131+
)
132+
133+
# Averaging in composite_basis
134+
b_avgs[i] = np.mean(neighbors_coefs, axis=0)
135+
136+
# Now we convert the averaged images from Basis to Cartesian.
137+
return ArrayImageSource(self.composite_basis.evaluate(b_avgs))
138+
139+
140+
class BFRAlign2D(AveragedAlign2D):
64141
"""
65142
This perfoms a Brute Force Rotational alignment.
66143
@@ -69,24 +146,29 @@ class BFRAlign2D(Align2D):
69146
and then identifies angle yielding largest correlation(dot).
70147
"""
71148

72-
def __init__(self, basis, n_angles=359, dtype=None):
149+
def __init__(
150+
self, alignment_basis, source, composite_basis=None, n_angles=359, dtype=None
151+
):
73152
"""
74-
:params basis: Basis providing a `rotate` method.
153+
:params alignment_basis: Basis providing a `rotate` method.
154+
:param source: Source of original images.
75155
:params n_angles: Number of brute force rotations to attempt, defaults 359.
76156
"""
77-
super().__init__(basis, dtype)
157+
super().__init__(alignment_basis, source, composite_basis, dtype)
78158

79159
self.n_angles = n_angles
80160

81-
if not hasattr(self.basis, "rotate"):
161+
if not hasattr(self.alignment_basis, "rotate"):
82162
raise RuntimeError(
83-
f"BFRAlign2D's basis {self.basis} must provide a `rotate` method."
163+
f"BFRAlign2D's alignment_basis {self.alignment_basis} must provide a `rotate` method."
84164
)
85165

86-
def align(self, classes, reflections, basis_coefficients):
166+
def _align(self, classes, reflections, basis_coefficients):
87167
"""
88-
See `Align2D.align`
168+
Performs the actual rotational alignment estimation,
169+
returning parameters needed for averaging.
89170
"""
171+
90172
# Admit simple case of single case alignment
91173
classes = np.atleast_2d(classes)
92174
reflections = np.atleast_2d(reflections)
@@ -108,7 +190,9 @@ def align(self, classes, reflections, basis_coefficients):
108190

109191
for i, angle in enumerate(test_angles):
110192
# Rotate the set of neighbors by angle,
111-
rotated_nbrs = self.basis.rotate(nbr_coef, angle, reflections[k])
193+
rotated_nbrs = self.alignment_basis.rotate(
194+
nbr_coef, angle, reflections[k]
195+
)
112196

113197
# then store dot between class base image (0) and each nbor
114198
for j, nbor in enumerate(rotated_nbrs):
@@ -124,7 +208,6 @@ def align(self, classes, reflections, basis_coefficients):
124208
for j in range(n_nbor):
125209
correlations[k, j] = results[j, angle_idx[j]]
126210

127-
# None is placeholder for shifts
128211
return classes, reflections, rotations, None, correlations
129212

130213

@@ -139,7 +222,16 @@ class BFSRAlign2D(BFRAlign2D):
139222
Return the rotation and shift yielding the best results.
140223
"""
141224

142-
def __init__(self, basis, n_angles=359, n_x_shifts=1, n_y_shifts=1, dtype=None):
225+
def __init__(
226+
self,
227+
alignment_basis,
228+
source,
229+
composite_basis=None,
230+
n_angles=359,
231+
n_x_shifts=1,
232+
n_y_shifts=1,
233+
dtype=None,
234+
):
143235
"""
144236
Note that n_x_shifts and n_y_shifts are the number of shifts to perform
145237
in each direction.
@@ -148,25 +240,25 @@ def __init__(self, basis, n_angles=359, n_x_shifts=1, n_y_shifts=1, dtype=None):
148240
149241
n_x_shifts=n_y_shifts=0 is the same as calling BFRAlign2D.
150242
151-
:params basis: Basis providing a `shift` and `rotate` method.
243+
:params alignment_basis: Basis providing a `shift` and `rotate` method.
152244
:params n_angles: Number of brute force rotations to attempt, defaults 359.
153245
:params n_x_shifts: +- Number of brute force xshifts to attempt, defaults 1.
154246
:params n_y_shifts: +- Number of brute force xshifts to attempt, defaults 1.
155247
"""
156-
super().__init__(basis, n_angles, dtype)
248+
super().__init__(alignment_basis, source, composite_basis, n_angles, dtype)
157249

158250
self.n_x_shifts = n_x_shifts
159251
self.n_y_shifts = n_y_shifts
160252

161-
if not hasattr(self.basis, "shift"):
253+
if not hasattr(self.alignment_basis, "shift"):
162254
raise RuntimeError(
163-
f"BFSRAlign2D's basis {self.basis} must provide a `shift` method."
255+
f"BFSRAlign2D's alignment_basis {self.alignment_basis} must provide a `shift` method."
164256
)
165257

166-
# Each shift will require calling the parent BFRAlign2D.align
167-
self._bfr_align = super().align
258+
# Each shift will require calling the parent BFRAlign2D._align
259+
self._bfr_align = super()._align
168260

169-
def align(self, classes, reflections, basis_coefficients):
261+
def _align(self, classes, reflections, basis_coefficients):
170262
"""
171263
See `Align2D.align`
172264
"""
@@ -196,7 +288,7 @@ def align(self, classes, reflections, basis_coefficients):
196288
# We want to maintain the original coefs for the base images,
197289
# because we will mutate them with shifts in the loop.
198290
original_coef = basis_coefficients[classes[:, 0], :]
199-
assert original_coef.shape == (n_classes, self.basis.count)
291+
assert original_coef.shape == (n_classes, self.alignment_basis.count)
200292

201293
# Loop over shift search space, updating best result
202294
for x, y in product(x_shifts, y_shifts):
@@ -206,7 +298,7 @@ def align(self, classes, reflections, basis_coefficients):
206298
# Shift the coef representing the first (base) entry in each class
207299
# by the negation of the shift
208300
# Shifting one image is more efficient than shifting every neighbor
209-
basis_coefficients[classes[:, 0], :] = self.basis.shift(
301+
basis_coefficients[classes[:, 0], :] = self.alignment_basis.shift(
210302
original_coef, -shift
211303
)
212304

@@ -242,18 +334,9 @@ class EMAlign2D(Align2D):
242334
Citation needed.
243335
"""
244336

245-
def __init__(self, basis, dtype=None):
246-
super().__init__(basis, dtype)
247-
248337

249338
class FTKAlign2D(Align2D):
250339
"""
251340
Factorization of the translation kernel for fast rigid image alignment.
252341
Rangan, A.V., Spivak, M., Anden, J., & Barnett, A.H. (2019).
253342
"""
254-
255-
def __init__(self, basis, dtype=None):
256-
super().__init__(basis, dtype)
257-
258-
def align(self, classes, reflections, basis_coefficients):
259-
raise NotImplementedError

src/aspire/classification/class2d.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import logging
2+
from abc import ABC
23

34
import numpy as np
45

56
logger = logging.getLogger(__name__)
67

78

8-
class Class2D:
9+
class Class2D(ABC):
910
"""
1011
Base class for 2D Image Classification methods.
1112
"""
@@ -41,3 +42,15 @@ def __init__(
4142
self.n_nbor = n_nbor
4243
self.n_classes = n_classes
4344
self.seed = seed
45+
46+
def classify(self):
47+
"""
48+
Classify the images from Source into classes with similar viewing angles.
49+
50+
Returns classes and associated metadata (classes, reflections, distances)
51+
"""
52+
53+
def averages(self, classes, refl, distances):
54+
"""
55+
Returns class averages using prescribed `aligner`.
56+
"""

0 commit comments

Comments
 (0)