Skip to content

Commit fec9f22

Browse files
committed
Rough in ability to use seperate alignment and composition sources for RC aligners.
1 parent d908119 commit fec9f22

File tree

1 file changed

+53
-38
lines changed

1 file changed

+53
-38
lines changed

src/aspire/classification/align2d.py

Lines changed: 53 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __init__(
2828
"""
2929
:param alignment_basis: Basis to be used during alignment (eg FSPCA)
3030
:param source: Source of original images.
31-
:param composite_basis: Basis to be used during class average composition (eg FFB2D)
31+
:param composite_basis: Basis to be used during class average composition (eg hi res Cartesian/FFB2D)
3232
:param dtype: Numpy dtype to be used during alignment.
3333
"""
3434

@@ -89,43 +89,24 @@ def align(self, classes, reflections, basis_coefficients):
8989
:returns: Image instance (stack of images)
9090
"""
9191

92-
def _images(self, cls):
92+
def _images(self, cls, src=None):
9393
"""
9494
Util to return images as an array for class k (provided as array `cls` ),
9595
preserving the class/nbor order.
9696
97-
:param cls: An iterable (0/1-D array or list) that holds the indices of images to align. In Class Averaging, this would be a class.
97+
:param cls: An iterable (0/1-D array or list) that holds the indices of images to align.
98+
In Class Averaging, this would be a class.
99+
:param src: Optionally overridee the src, for example, if you want to use a different
100+
source for a certain operation (ie aignment).
98101
"""
102+
src = src or self.src
99103

100104
n_nbor = cls.shape[-1] # Includes zero'th neighbor
101105

102-
# Get the images. We'll loop over the source in batches.
103-
# Note one day when the Source.images is more flexible,
104-
# this code would mostly go away.
105-
images = np.empty((n_nbor, self.src.L, self.src.L), dtype=self.dtype)
106-
107-
# We want to only process batches that actually
108-
# contain images for this class.
109-
# First compute the batches' indices.
110-
for start in range(0, self.src.n + 1, self.batch_size):
111-
# First cook up the batch boundaries
112-
end = start + self.batch_size
113-
# UBound, these are inclusive bounds
114-
start = min(start, self.src.n - 1)
115-
end = min(end, self.src.n - 1)
116-
num = end - start + 1
117-
118-
# Second, loop over the cls members
119-
image_batch = None
120-
for i, index in enumerate(cls):
121-
# Check if the member is in this chunk
122-
if start <= index <= end:
123-
# Get and cache this image_batch on first hit.
124-
if image_batch is None:
125-
image_batch = self.src.images(start, num)
126-
# Translate the cls's index into this batch's
127-
batch_index = index % self.batch_size
128-
images[i] = image_batch[batch_index]
106+
images = np.empty((n_nbor, src.L, src.L), dtype=self.dtype)
107+
108+
for i, index in enumerate(cls):
109+
images[i] = src.images(index, 1).asnumpy()
129110

130111
return images
131112

@@ -429,20 +410,44 @@ def __init__(
429410
alignment_basis,
430411
source,
431412
composite_basis=None,
413+
alignment_source=None,
432414
diagnostics=False,
433415
batch_size=512,
434416
dtype=None,
435417
):
436418
"""
437-
:param alignment_basis: Basis to be used during alignment (eg FSPCA)
419+
:param alignment_basis: Basis to be used during alignment.
420+
For current implementation of ReddyChatterjiAlign2D this should be `None`.
421+
Instead see `alignment_source`.
438422
:param source: Source of original images.
439-
:param composite_basis: Basis to be used during class average composition (eg FFB2D)
423+
:param composite_basis: Basis to be used during class average composition.
424+
For current implementation of ReddyChatterjiAlign2D this should be `None`.
425+
Instead this method uses `source` for composition of the averaged stack.
426+
:param alignment_source: Basis to be used during class average composition.
427+
Must be the same resolution as `source`.
440428
:param dtype: Numpy dtype to be used during alignment.
441429
"""
442430

443431
self.__cache = dict()
444432
self.diagnostics = diagnostics
445433
self.do_cross_corr_translations = True
434+
self.alignment_src = alignment_source or source
435+
436+
# TODO, for accomodating different resolutions we minimally need to adapt shifting.
437+
# Outside of scope right now, but would make a nice PR later.
438+
if self.alignment_src.L != source.L:
439+
raise RuntimeError("Currently `alignment_src.L` must equal `source.L`")
440+
if self.alignment_src.dtype != source.dtype:
441+
raise RuntimeError(
442+
"Currently `alignment_src.dtype` must equal `source.dtype`"
443+
)
444+
445+
# Sanity check. This API should be rethought once all basis and
446+
# alignment methods have been incorporated.
447+
assert alignment_basis is None # We use sources directly for alignment
448+
assert (
449+
composite_basis is not None
450+
) # However, we require a basis for rotating etc.
446451

447452
super().__init__(
448453
alignment_basis, source, composite_basis, batch_size=batch_size, dtype=dtype
@@ -497,8 +502,8 @@ def _align(self, classes, reflections, basis_coefficients):
497502
shifts = np.zeros((*classes.shape, 2), dtype=int)
498503

499504
for k in trange(n_classes):
500-
# # Get the array of images for this class
501-
images = self._images(classes[k])
505+
# # Get the array of images for this class, using the `alignment_src`.
506+
images = self._images(classes[k], src=self.alignment_src)
502507

503508
self._reddychatterji(
504509
k, images, classes, reflections, rotations, correlations, shifts
@@ -817,7 +822,7 @@ def _translation_cross_corr_diagnostic(self, cross_correlation):
817822
plt.imshow(cross_correlation)
818823
plt.xlabel("x shift (pixels)")
819824
plt.ylabel("y shift (pixels)")
820-
L = self.src.L
825+
L = self.alignment_src.L
821826
labels = [0, 10, 20, 30, 0, -10, -20, -30]
822827
tick_location = [0, 10, 20, 30, L, L - 10, L - 20, L - 30]
823828
plt.xticks(tick_location, labels)
@@ -875,17 +880,26 @@ def __init__(
875880
alignment_basis,
876881
source,
877882
composite_basis=None,
883+
alignment_source=None,
878884
radius=None,
879885
diagnostics=False,
880886
batch_size=512,
881887
dtype=None,
882888
):
883889
"""
884-
:param alignment_basis: Basis to be used during alignment (eg FSPCA)
890+
:param alignment_basis: Basis to be used during alignment.
891+
For current implementation of ReddyChatterjiAlign2D this should be `None`.
892+
Instead see `alignment_source`.
885893
:param source: Source of original images.
886-
:param composite_basis: Basis to be used during class average composition (eg FFB2D)
894+
:param composite_basis: Basis to be used during class average composition.
895+
For current implementation of ReddyChatterjiAlign2D this should be `None`.
896+
Instead this method uses `source` for composition of the averaged stack.
897+
:param alignment_source: Basis to be used during class average composition.
898+
Must be the same resolution as `source`.
887899
:param radius: Brute force translation search radius.
888900
Defaults to source.L//8.
901+
:param dtype: Numpy dtype to be used during alignment.
902+
889903
:param diagnostics: Plot interactive diagnostic graphics (for debugging).
890904
:param dtype: Numpy dtype to be used during alignment.
891905
"""
@@ -894,6 +908,7 @@ def __init__(
894908
alignment_basis,
895909
source,
896910
composite_basis,
911+
alignment_source,
897912
diagnostics,
898913
batch_size=batch_size,
899914
dtype=dtype,
@@ -915,7 +930,7 @@ def _align(self, classes, reflections, basis_coefficients):
915930
reflections = np.atleast_2d(reflections)
916931

917932
n_classes, n_nbor = classes.shape
918-
L = self.src.L
933+
L = self.alignment_src.L
919934

920935
# Instantiate matrices for inner loop, and best results.
921936
_rotations = np.zeros(classes.shape, dtype=self.dtype)

0 commit comments

Comments
 (0)