From ea84bb040d0cfebdfc47b6a5e2432a6241ee5100 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Mon, 29 Nov 2021 13:27:40 -0500 Subject: [PATCH 01/40] Refactor align2D to return stack of cls avg this better fits future codes like EM --- src/aspire/classification/__init__.py | 9 +- src/aspire/classification/align2d.py | 175 +++++++++++++++++------ src/aspire/classification/class2d.py | 15 +- src/aspire/classification/rir_class2d.py | 90 ++++-------- tests/test_align2d.py | 37 +++-- tests/test_class2D.py | 20 ++- 6 files changed, 216 insertions(+), 130 deletions(-) diff --git a/src/aspire/classification/__init__.py b/src/aspire/classification/__init__.py index f11ad1d8e0..b7ca5ef5f7 100644 --- a/src/aspire/classification/__init__.py +++ b/src/aspire/classification/__init__.py @@ -1,3 +1,10 @@ -from .align2d import Align2D, BFRAlign2D, BFSRAlign2D, EMAlign2D, FTKAlign2D +from .align2d import ( + Align2D, + AveragedAlign2D, + BFRAlign2D, + BFSRAlign2D, + EMAlign2D, + FTKAlign2D, +) from .class2d import Class2D from .rir_class2d import RIRClass2D diff --git a/src/aspire/classification/align2d.py b/src/aspire/classification/align2d.py index bfada60c87..7cf800cbf6 100644 --- a/src/aspire/classification/align2d.py +++ b/src/aspire/classification/align2d.py @@ -1,51 +1,61 @@ import logging +from abc import ABC, abstractmethod from itertools import product import numpy as np -from tqdm import trange +from tqdm import tqdm, trange + +from aspire.image import Image +from aspire.source import ArrayImageSource logger = logging.getLogger(__name__) -class Align2D: +class Align2D(ABC): """ Base class for 2D Image Alignment methods. """ - def __init__(self, basis, dtype): + def __init__(self, alignment_basis, source, composite_basis=None, dtype=None): """ - :param basis: Basis to be used for any methods during alignment. + :param alignment_basis: Basis to be used during alignment (eg FSPCA) + :param source: Source of original images. + :param composite_basis: Basis to be used during class average composition (eg FFB2D) :param dtype: Numpy dtype to be used during alignment. """ - self.basis = basis + self.alignment_basis = alignment_basis + # if composite_basis is None, use alignment_basis + self.composite_basis = composite_basis or self.alignment_basis + self.src = source if dtype is None: - self.dtype = self.basis.dtype + self.dtype = self.alignment_basis.dtype else: self.dtype = np.dtype(dtype) - if self.dtype != self.basis.dtype: + if self.dtype != self.alignment_basis.dtype: logger.warning( - f"Align2D basis.dtype {self.basis.dtype} does not match self.dtype {self.dtype}." + f"Align2D alignment_basis.dtype {self.alignment_basis.dtype} does not match self.dtype {self.dtype}." ) + @abstractmethod def align(self, classes, reflections, basis_coefficients): """ Any align2D alignment method should take in the following arguments - and return the described tuple. + and return aligned images. - Generally, the returned `classes` and `reflections` should be same as - the input. They are passed through for convience, - considering they would all be required for image output. + During this process `rotations`, `reflections`, `shifts` and + `correlations` propeties will be computed for aligners + that implement them. - Returned `rotations` is an (n_classes, n_nbor) array of angles, + `rotations` would be an (n_classes, n_nbor) array of angles, which should represent the rotations needed to align images within that class. `rotations` is measured in Radians. - Returned `correlations` is an (n_classes, n_nbor) array representing + `correlations` is an (n_classes, n_nbor) array representing a correlation like measure between classified images and their base image (image index 0). - Returned `shifts` is None or an (n_classes, n_nbor) array of 2D shifts + `shifts` is None or an (n_classes, n_nbor) array of 2D shifts which should represent the translation needed to best align the images within that class. @@ -55,12 +65,79 @@ def align(self, classes, reflections, basis_coefficients): :param refl: (n_classes, n_nbor) bool array of corresponding reflections :param coef: (n_img, self.pca_basis.count) compressed basis coefficients - :returns: (classes, reflections, rotations, shifts, correlations) + :returns: Image instance (stack of images) """ - raise NotImplementedError("Subclasses must implement align.") -class BFRAlign2D(Align2D): +class AveragedAlign2D(Align2D): + """ + Subclass supporting aligners which perform averaging during output. + """ + + def align(self, classes, reflections, basis_coefficients): + """ + See Align2D.align + """ + # Correlations are currently unused, but left for future extensions. + cls, ref, rot, shf, corrs = self._align( + classes, reflections, basis_coefficients + ) + return self.average(cls, ref, rot, shf), cls, ref, rot, shf, corrs + + def average( + self, + classes, + reflections, + rotations, + shifts=None, + coefs=None, + ): + """ + Combines images using averaging in provided `basis`. + + :param classes: class indices (refering to src). (n_img, n_nbor) + :param reflections: Bool representing whether to reflect image in `classes` + :param rotations: Array of in-plane rotation angles (Radians) of image in `classes` + :param shifts: Optional array of shifts for image in `classes`. + :coefs: Optional Fourier bessel coefs (avoids recomputing). + :return: Stack of Synthetic Class Average images as Image instance. + """ + n_classes, n_nbor = classes.shape + + # TODO: don't load all the images here. + imgs = self.src.images(0, self.src.n) + b_avgs = np.empty((n_classes, self.composite_basis.count), dtype=self.src.dtype) + + for i in tqdm(range(n_classes)): + # Get the neighbors + neighbors_ids = classes[i] + + # Get coefs in Composite_Basis if not provided as an argument. + if coefs is None: + neighbors_imgs = Image(imgs[neighbors_ids]) + if shifts is not None: + neighbors_imgs.shift(shifts[i]) + neighbors_coefs = self.composite_basis.evaluate_t(neighbors_imgs) + else: + neighbors_coefs = coefs[neighbors_ids] + if shifts is not None: + neighbors_coefs = self.composite_basis.shift( + neighbors_coefs, shifts[i] + ) + + # Rotate in composite_basis + neighbors_coefs = self.composite_basis.rotate( + neighbors_coefs, rotations[i], reflections[i] + ) + + # Averaging in composite_basis + b_avgs[i] = np.mean(neighbors_coefs, axis=0) + + # Now we convert the averaged images from Basis to Cartesian. + return ArrayImageSource(self.composite_basis.evaluate(b_avgs)) + + +class BFRAlign2D(AveragedAlign2D): """ This perfoms a Brute Force Rotational alignment. @@ -69,24 +146,29 @@ class BFRAlign2D(Align2D): and then identifies angle yielding largest correlation(dot). """ - def __init__(self, basis, n_angles=359, dtype=None): + def __init__( + self, alignment_basis, source, composite_basis=None, n_angles=359, dtype=None + ): """ - :params basis: Basis providing a `rotate` method. + :params alignment_basis: Basis providing a `rotate` method. + :param source: Source of original images. :params n_angles: Number of brute force rotations to attempt, defaults 359. """ - super().__init__(basis, dtype) + super().__init__(alignment_basis, source, composite_basis, dtype) self.n_angles = n_angles - if not hasattr(self.basis, "rotate"): + if not hasattr(self.alignment_basis, "rotate"): raise RuntimeError( - f"BFRAlign2D's basis {self.basis} must provide a `rotate` method." + f"BFRAlign2D's alignment_basis {self.alignment_basis} must provide a `rotate` method." ) - def align(self, classes, reflections, basis_coefficients): + def _align(self, classes, reflections, basis_coefficients): """ - See `Align2D.align` + Performs the actual rotational alignment estimation, + returning parameters needed for averaging. """ + # Admit simple case of single case alignment classes = np.atleast_2d(classes) reflections = np.atleast_2d(reflections) @@ -108,7 +190,9 @@ def align(self, classes, reflections, basis_coefficients): for i, angle in enumerate(test_angles): # Rotate the set of neighbors by angle, - rotated_nbrs = self.basis.rotate(nbr_coef, angle, reflections[k]) + rotated_nbrs = self.alignment_basis.rotate( + nbr_coef, angle, reflections[k] + ) # then store dot between class base image (0) and each nbor for j, nbor in enumerate(rotated_nbrs): @@ -124,7 +208,6 @@ def align(self, classes, reflections, basis_coefficients): for j in range(n_nbor): correlations[k, j] = results[j, angle_idx[j]] - # None is placeholder for shifts return classes, reflections, rotations, None, correlations @@ -139,7 +222,16 @@ class BFSRAlign2D(BFRAlign2D): Return the rotation and shift yielding the best results. """ - def __init__(self, basis, n_angles=359, n_x_shifts=1, n_y_shifts=1, dtype=None): + def __init__( + self, + alignment_basis, + source, + composite_basis=None, + n_angles=359, + n_x_shifts=1, + n_y_shifts=1, + dtype=None, + ): """ Note that n_x_shifts and n_y_shifts are the number of shifts to perform in each direction. @@ -148,25 +240,25 @@ def __init__(self, basis, n_angles=359, n_x_shifts=1, n_y_shifts=1, dtype=None): n_x_shifts=n_y_shifts=0 is the same as calling BFRAlign2D. - :params basis: Basis providing a `shift` and `rotate` method. + :params alignment_basis: Basis providing a `shift` and `rotate` method. :params n_angles: Number of brute force rotations to attempt, defaults 359. :params n_x_shifts: +- Number of brute force xshifts to attempt, defaults 1. :params n_y_shifts: +- Number of brute force xshifts to attempt, defaults 1. """ - super().__init__(basis, n_angles, dtype) + super().__init__(alignment_basis, source, composite_basis, n_angles, dtype) self.n_x_shifts = n_x_shifts self.n_y_shifts = n_y_shifts - if not hasattr(self.basis, "shift"): + if not hasattr(self.alignment_basis, "shift"): raise RuntimeError( - f"BFSRAlign2D's basis {self.basis} must provide a `shift` method." + f"BFSRAlign2D's alignment_basis {self.alignment_basis} must provide a `shift` method." ) - # Each shift will require calling the parent BFRAlign2D.align - self._bfr_align = super().align + # Each shift will require calling the parent BFRAlign2D._align + self._bfr_align = super()._align - def align(self, classes, reflections, basis_coefficients): + def _align(self, classes, reflections, basis_coefficients): """ See `Align2D.align` """ @@ -196,7 +288,7 @@ def align(self, classes, reflections, basis_coefficients): # We want to maintain the original coefs for the base images, # because we will mutate them with shifts in the loop. original_coef = basis_coefficients[classes[:, 0], :] - assert original_coef.shape == (n_classes, self.basis.count) + assert original_coef.shape == (n_classes, self.alignment_basis.count) # Loop over shift search space, updating best result for x, y in product(x_shifts, y_shifts): @@ -206,7 +298,7 @@ def align(self, classes, reflections, basis_coefficients): # Shift the coef representing the first (base) entry in each class # by the negation of the shift # Shifting one image is more efficient than shifting every neighbor - basis_coefficients[classes[:, 0], :] = self.basis.shift( + basis_coefficients[classes[:, 0], :] = self.alignment_basis.shift( original_coef, -shift ) @@ -242,18 +334,9 @@ class EMAlign2D(Align2D): Citation needed. """ - def __init__(self, basis, dtype=None): - super().__init__(basis, dtype) - class FTKAlign2D(Align2D): """ Factorization of the translation kernel for fast rigid image alignment. Rangan, A.V., Spivak, M., Anden, J., & Barnett, A.H. (2019). """ - - def __init__(self, basis, dtype=None): - super().__init__(basis, dtype) - - def align(self, classes, reflections, basis_coefficients): - raise NotImplementedError diff --git a/src/aspire/classification/class2d.py b/src/aspire/classification/class2d.py index 6544b78ffc..67293031a3 100644 --- a/src/aspire/classification/class2d.py +++ b/src/aspire/classification/class2d.py @@ -1,11 +1,12 @@ import logging +from abc import ABC import numpy as np logger = logging.getLogger(__name__) -class Class2D: +class Class2D(ABC): """ Base class for 2D Image Classification methods. """ @@ -41,3 +42,15 @@ def __init__( self.n_nbor = n_nbor self.n_classes = n_classes self.seed = seed + + def classify(self): + """ + Classify the images from Source into classes with similar viewing angles. + + Returns classes and associated metadata (classes, reflections, distances) + """ + + def averages(self, classes, refl, distances): + """ + Returns class averages using prescribed `aligner`. + """ diff --git a/src/aspire/classification/rir_class2d.py b/src/aspire/classification/rir_class2d.py index a799ec7c27..83755c83ee 100644 --- a/src/aspire/classification/rir_class2d.py +++ b/src/aspire/classification/rir_class2d.py @@ -165,7 +165,6 @@ def classify(self, diagnostics=False): # default of 400 components was taken from legacy code. # Instantiate a new compressed (truncated) basis. if self.pca_basis is None: - # self.pca_basis = self.pca_basis.compress(self.fspca_components) self.pca_basis = FSPCABasis(self.src, components=self.fspca_components) # For convenience, assign the fb_basis used in the pca_basis. @@ -174,7 +173,9 @@ def classify(self, diagnostics=False): # When not provided by a user, the aligner is instantiated after # we are certain our pca_basis has been constructed. if self.aligner is None: - self.aligner = BFRAlign2D(self.pca_basis, dtype=self.dtype) + self.aligner = BFRAlign2D( + self.pca_basis, self.src, self.fb_basis, dtype=self.dtype + ) # Get the expanded coefs in the compressed FSPCA space. self.fspca_coef = self.pca_basis.spca_coef @@ -184,7 +185,7 @@ def classify(self, diagnostics=False): # # Stage 2: Compute Nearest Neighbors logger.info("Calculate Nearest Neighbors") - classes, refl, distances = self.nn_classification(coef_b, coef_b_r) + classes, reflections, distances = self.nn_classification(coef_b, coef_b_r) if diagnostics: # Lets peek at the distribution of distances @@ -193,8 +194,14 @@ def classify(self, diagnostics=False): plt.show() # Report some information about reflections - logger.info(f"Count reflected: {np.sum(refl)}" f" {100 * np.mean(refl) } %") + logger.info( + f"Count reflected: {np.sum(reflections)}" + f" {100 * np.mean(reflections) } %" + ) + return classes, reflections, distances + + def averages(self, classes, reflections, distances): # # Stage 3: Class Selection logger.info(f"Select {self.n_classes} Classes from Nearest Neighbors") # This is an area open to active research. @@ -206,11 +213,21 @@ def classify(self, diagnostics=False): logger.info( f"Begin Rotational Alignment of {classes.shape[0]} Classes using {self.aligner}." ) - if not self.aligner.basis == self.pca_basis: - raise RuntimeError( - f"Aligner {self.aligner} basis does not match FSPCA basis." - ) - return self.aligner.align(classes, refl, self.fspca_coef) + + logger.info(f"Select {self.n_classes} Classes from Nearest Neighbors") + classes, reflections = self.select_classes(classes, reflections) + + return self.aligner.align(classes, reflections, self.fspca_coef) + + def select_classes(self, classes, reflections): + """ + Select the `n_classes` to align from the (n_images) population of classes. + """ + # generate indices for random sample (can do something smart with corr later). + # For testing just take the first n_classes so it matches earlier plots for manual comparison + # This is assumed to be reasonably random. + selection = np.arange(self.n_classes) + return classes[selection], reflections[selection] def pca(self, M): """ @@ -298,61 +315,6 @@ def _sk_nn_classification(self, coeff_b, coeff_b_r): return indices - def output( - self, - classes, - classes_refl, - rot, - shifts=None, - coefs=None, - ): - """ - Return class averages. - - :param classes: class indices (refering to src). (n_img, n_nbor) - :param classes_refl: Bool representing whether to reflect image in `classes` - :param rot: Array of in-plane rotation angles (Radians) of image in `classes` - :param shifts: Optional array of shifts for image in `classes`. - :coefs: Optional Fourier bessel coefs (avoids recomputing). - :return: Stack of Synthetic Class Average images as Image instance. - """ - - logger.info(f"Select {self.n_classes} Classes from Nearest Neighbors") - # generate indices for random sample (can do something smart with corr later). - # For testing just take the first n_classes so it matches earlier plots for manual comparison - # This is assumed to be reasonably random. - selection = np.arange(self.n_classes) - - imgs = self.src.images(0, self.src.n) - fb_avgs = np.empty((self.n_classes, self.fb_basis.count), dtype=self.src.dtype) - - for i in tqdm(range(self.n_classes)): - j = selection[i] - # Get the neighbors - neighbors_ids = classes[j] - - # Get coefs in Fourier Bessel Basis if not provided as an argument. - if coefs is None: - neighbors_imgs = Image(imgs[neighbors_ids]) - if shifts is not None: - neighbors_imgs.shift(shifts[i]) - neighbors_coefs = self.fb_basis.evaluate_t(neighbors_imgs) - else: - neighbors_coefs = coefs[neighbors_ids] - if shifts is not None: - neighbors_coefs = self.fb_basis.shift(neighbors_coefs, shifts[i]) - - # Rotate in Fourier Bessel - neighbors_coefs = self.fb_basis.rotate( - neighbors_coefs, rot[j], classes_refl[j] - ) - - # Averaging in FB - fb_avgs[i] = np.mean(neighbors_coefs, axis=0) - - # Now we convert the averaged images from FB to Cartesian. - return ArrayImageSource(self.fb_basis.evaluate(fb_avgs)) - def _legacy_nn_classification(self, coeff_b, coeff_b_r, batch_size=2000): """ Perform nearest neighbor classification and alignment. diff --git a/tests/test_align2d.py b/tests/test_align2d.py index e4f3bc9085..abb6af5925 100644 --- a/tests/test_align2d.py +++ b/tests/test_align2d.py @@ -6,7 +6,7 @@ import pytest from aspire.basis import DiracBasis, FFBBasis2D -from aspire.classification import Align2D, BFRAlign2D, BFSRAlign2D +from aspire.classification import AveragedAlign2D, BFRAlign2D, BFSRAlign2D from aspire.source import Simulation from aspire.utils import Rotation from aspire.volume import Volume @@ -21,7 +21,7 @@ @pytest.mark.filterwarnings("ignore:Gimbal lock detected") class Align2DTestCase(TestCase): # Subclasses should override `aligner` with a different class. - aligner = Align2D + aligner = AveragedAlign2D def setUp(self): @@ -48,6 +48,11 @@ def inject_fixtures(self, caplog): def tearDown(self): pass + def _getSrc(self): + # Base Align2D does not require anything from source. + # Subclasses implement specific src + return None + def testTypeMismatch(self): # Intentionally mismatch Basis and Aligner dtypes @@ -57,7 +62,7 @@ def testTypeMismatch(self): test_dtype = np.float32 with self._caplog.at_level(logging.WARN): - self.aligner(self.basis, dtype=test_dtype) + self.aligner(self.basis, self._getSrc, dtype=test_dtype) assert " does not match self.dtype" in self._caplog.text def _construct_rotations(self): @@ -91,6 +96,7 @@ def r(theta): self.rots = Rotation.from_matrix(_rots) +@pytest.mark.filterwarnings("ignore:Gimbal lock detected") class BFRAlign2DTestCase(Align2DTestCase): aligner = BFRAlign2D @@ -135,7 +141,7 @@ def testNoRot(self): # and that should raise an error during instantiation. with pytest.raises(RuntimeError, match=r".* must provide a `rotate` method."): - _ = self.aligner(basis) + _ = self.aligner(basis, self._getSrc()) def testAlign(self): """ @@ -145,9 +151,10 @@ def testAlign(self): """ # Construction the Aligner and then call the main `align` method - _classes, _reflections, _rotations, _shifts, _ = self.aligner( - self.basis, n_angles=self.n_search_angles - ).align(self.classes, self.reflections, self.coefs) + algnr = self.aligner(self.basis, self._getSrc(), n_angles=self.n_search_angles) + _, _classes, _reflections, _rotations, _shifts, _ = algnr.align( + self.classes, self.reflections, self.coefs + ) self.assertTrue(np.all(_classes == self.classes)) self.assertTrue(np.all(_reflections == self.reflections)) @@ -162,6 +169,7 @@ def testAlign(self): ) +@pytest.mark.filterwarnings("ignore:Gimbal lock detected") class BFSRAlign2DTestCase(BFRAlign2DTestCase): aligner = BFSRAlign2D @@ -192,7 +200,7 @@ def testNoShift(self): # and that should raise an error during instantiation. with pytest.raises(RuntimeError, match=r".* must provide a `shift` method."): - _ = self.aligner(basis) + _ = self.aligner(basis, self._getSrc()) def testAlign(self): """ @@ -202,9 +210,16 @@ def testAlign(self): """ # Construction the Aligner and then call the main `align` method - _classes, _reflections, _rotations, _shifts, _ = self.aligner( - self.basis, n_angles=self.n_search_angles, n_x_shifts=1, n_y_shifts=1 - ).align(self.classes, self.reflections, self.coefs) + algnr = self.aligner( + self.basis, + self._getSrc(), + n_angles=self.n_search_angles, + n_x_shifts=1, + n_y_shifts=1, + ) + _, _classes, _reflections, _rotations, _shifts, _ = algnr.align( + self.classes, self.reflections, self.coefs + ) self.assertTrue(np.all(_classes == self.classes)) self.assertTrue(np.all(_reflections == self.reflections)) diff --git a/tests/test_class2D.py b/tests/test_class2D.py index 0dac11e5a6..a7aea3fd3c 100644 --- a/tests/test_class2D.py +++ b/tests/test_class2D.py @@ -191,8 +191,8 @@ def testRIRLegacy(self): bispectrum_implementation="legacy", ) - result = rir.classify() - _ = rir.output(*result[:3]) + classification_results = rir.classify() + _ = rir.averages(*classification_results) def testRIRDevelBisp(self): """ @@ -207,8 +207,8 @@ def testRIRDevelBisp(self): bispectrum_implementation="devel", ) - result = rir.classify() - _ = rir.output(*result[:3]) + classification_results = rir.classify() + _ = rir.averages(*classification_results) def testRIRsk(self): """ @@ -226,11 +226,17 @@ def testRIRsk(self): large_pca_implementation="sklearn", nn_implementation="sklearn", bispectrum_implementation="devel", - aligner=BFSRAlign2D(self.noisy_fspca_basis, n_angles=100, n_x_shifts=0), + aligner=BFSRAlign2D( + self.noisy_fspca_basis, + self.noisy_src, + self.basis, + n_angles=100, + n_x_shifts=0, + ), ) - result = rir.classify() - _ = rir.output(*result[:4]) + classification_results = rir.classify() + _ = rir.averages(*classification_results) def testEigenImages(self): """ From d90811913422ab1c9c234d64b0e288803a9a7b22 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 7 Jan 2022 15:46:21 -0500 Subject: [PATCH 02/40] Add Reddy Chatterji Log Polar image aligner, and BFS Reddy Chatterji. --- gallery/tutorials/class_averaging.py | 12 +- setup.py | 1 + src/aspire/classification/__init__.py | 2 + src/aspire/classification/align2d.py | 685 ++++++++++++++++++++++- src/aspire/classification/rir_class2d.py | 9 +- tests/test_align2d.py | 4 +- 6 files changed, 682 insertions(+), 31 deletions(-) diff --git a/gallery/tutorials/class_averaging.py b/gallery/tutorials/class_averaging.py index dcf1272b60..7d973e33c8 100644 --- a/gallery/tutorials/class_averaging.py +++ b/gallery/tutorials/class_averaging.py @@ -116,13 +116,15 @@ bispectrum_implementation="legacy", ) -classes, reflections, rotations, shifts, corr = rir.classify() +classes, reflections, dists = rir.classify() +avgs, classes, reflections, rotations, shifts, corrs = rir.averages( + classes, reflections, dists +) # %% # Display Classes # ^^^^^^^^^^^^^^^ -avgs = rir.output(classes, reflections, rotations) avgs.images(0, 10).show() # %% @@ -169,13 +171,15 @@ bispectrum_implementation="legacy", ) -classes, reflections, rotations, shifts, corr = noisy_rir.classify() +classes, reflections, dists = noisy_rir.classify() +avgs, classes, reflections, rotations, shifts, corrs = noisy_rir.averages( + classes, reflections, dists +) # %% # Display Classes # ^^^^^^^^^^^^^^^ -avgs = noisy_rir.output(classes, reflections, rotations) avgs.images(0, 10).show() diff --git a/setup.py b/setup.py index d9c4fae207..20fb9dedb4 100644 --- a/setup.py +++ b/setup.py @@ -36,6 +36,7 @@ def read(fname): "pillow", "scipy==1.7.3", "scikit-learn", + "scikit-image", "setuptools>=0.41", "tqdm", ], diff --git a/src/aspire/classification/__init__.py b/src/aspire/classification/__init__.py index b7ca5ef5f7..21fbde28e4 100644 --- a/src/aspire/classification/__init__.py +++ b/src/aspire/classification/__init__.py @@ -3,8 +3,10 @@ AveragedAlign2D, BFRAlign2D, BFSRAlign2D, + BFSReddyChatterjiAlign2D, EMAlign2D, FTKAlign2D, + ReddyChatterjiAlign2D, ) from .class2d import Class2D from .rir_class2d import RIRClass2D diff --git a/src/aspire/classification/align2d.py b/src/aspire/classification/align2d.py index 7cf800cbf6..83c96bcf66 100644 --- a/src/aspire/classification/align2d.py +++ b/src/aspire/classification/align2d.py @@ -2,11 +2,17 @@ from abc import ABC, abstractmethod from itertools import product +import matplotlib.pyplot as plt import numpy as np +from skimage.filters import difference_of_gaussians, window + +# import skimage.io +from skimage.transform import rotate, warp_polar from tqdm import tqdm, trange from aspire.image import Image from aspire.source import ArrayImageSource +from aspire.utils.coor_trans import grid_2d logger = logging.getLogger(__name__) @@ -16,7 +22,9 @@ class Align2D(ABC): Base class for 2D Image Alignment methods. """ - def __init__(self, alignment_basis, source, composite_basis=None, dtype=None): + def __init__( + self, alignment_basis, source, composite_basis=None, batch_size=512, dtype=None + ): """ :param alignment_basis: Basis to be used during alignment (eg FSPCA) :param source: Source of original images. @@ -28,14 +36,27 @@ def __init__(self, alignment_basis, source, composite_basis=None, dtype=None): # if composite_basis is None, use alignment_basis self.composite_basis = composite_basis or self.alignment_basis self.src = source + self.batch_size = batch_size if dtype is None: - self.dtype = self.alignment_basis.dtype + if self.composite_basis: + self.dtype = self.composite_basis.dtype + elif self.src: + self.dtype = self.src.dtype + else: + raise RuntimeError("You must supply a basis/src/dtype.") else: self.dtype = np.dtype(dtype) - if self.dtype != self.alignment_basis.dtype: - logger.warning( - f"Align2D alignment_basis.dtype {self.alignment_basis.dtype} does not match self.dtype {self.dtype}." - ) + + if self.src and self.dtype != self.src.dtype: + logger.warning( + f"{self.__class__.__name__} dtype {dtype}" + "does not match dtype of source {self.src.dtype}." + ) + if self.composite_basis and self.dtype != self.composite_basis.dtype: + logger.warning( + f"{self.__class__.__name__} dtype {dtype}" + "does not match dtype of basis {self.composite_basis.dtype}." + ) @abstractmethod def align(self, classes, reflections, basis_coefficients): @@ -44,10 +65,10 @@ def align(self, classes, reflections, basis_coefficients): and return aligned images. During this process `rotations`, `reflections`, `shifts` and - `correlations` propeties will be computed for aligners + `correlations` properties will be computed for aligners that implement them. - `rotations` would be an (n_classes, n_nbor) array of angles, + `rotations` is an (n_classes, n_nbor) array of angles, which should represent the rotations needed to align images within that class. `rotations` is measured in Radians. @@ -68,6 +89,46 @@ def align(self, classes, reflections, basis_coefficients): :returns: Image instance (stack of images) """ + def _images(self, cls): + """ + Util to return images as an array for class k (provided as array `cls` ), + preserving the class/nbor order. + + :param cls: An iterable (0/1-D array or list) that holds the indices of images to align. In Class Averaging, this would be a class. + """ + + n_nbor = cls.shape[-1] # Includes zero'th neighbor + + # Get the images. We'll loop over the source in batches. + # Note one day when the Source.images is more flexible, + # this code would mostly go away. + images = np.empty((n_nbor, self.src.L, self.src.L), dtype=self.dtype) + + # We want to only process batches that actually + # contain images for this class. + # First compute the batches' indices. + for start in range(0, self.src.n + 1, self.batch_size): + # First cook up the batch boundaries + end = start + self.batch_size + # UBound, these are inclusive bounds + start = min(start, self.src.n - 1) + end = min(end, self.src.n - 1) + num = end - start + 1 + + # Second, loop over the cls members + image_batch = None + for i, index in enumerate(cls): + # Check if the member is in this chunk + if start <= index <= end: + # Get and cache this image_batch on first hit. + if image_batch is None: + image_batch = self.src.images(start, num) + # Translate the cls's index into this batch's + batch_index = index % self.batch_size + images[i] = image_batch[batch_index] + + return images + class AveragedAlign2D(Align2D): """ @@ -104,21 +165,23 @@ def average( """ n_classes, n_nbor = classes.shape - # TODO: don't load all the images here. - imgs = self.src.images(0, self.src.n) b_avgs = np.empty((n_classes, self.composite_basis.count), dtype=self.src.dtype) for i in tqdm(range(n_classes)): - # Get the neighbors - neighbors_ids = classes[i] - # Get coefs in Composite_Basis if not provided as an argument. + # Get coefs in Composite_Basis if not provided as an argumen. if coefs is None: - neighbors_imgs = Image(imgs[neighbors_ids]) + # Retrieve relavent images directly from source. + neighbors_imgs = Image(self._images(classes[i])) + + # Do shifts if shifts is not None: neighbors_imgs.shift(shifts[i]) + neighbors_coefs = self.composite_basis.evaluate_t(neighbors_imgs) else: + # Get the neighbors + neighbors_ids = classes[i] neighbors_coefs = coefs[neighbors_ids] if shifts is not None: neighbors_coefs = self.composite_basis.shift( @@ -147,14 +210,20 @@ class BFRAlign2D(AveragedAlign2D): """ def __init__( - self, alignment_basis, source, composite_basis=None, n_angles=359, dtype=None + self, + alignment_basis, + source, + composite_basis=None, + n_angles=359, + batch_size=512, + dtype=None, ): """ :params alignment_basis: Basis providing a `rotate` method. :param source: Source of original images. :params n_angles: Number of brute force rotations to attempt, defaults 359. """ - super().__init__(alignment_basis, source, composite_basis, dtype) + super().__init__(alignment_basis, source, composite_basis, batch_size, dtype) self.n_angles = n_angles @@ -230,6 +299,7 @@ def __init__( n_angles=359, n_x_shifts=1, n_y_shifts=1, + batch_size=512, dtype=None, ): """ @@ -245,7 +315,14 @@ def __init__( :params n_x_shifts: +- Number of brute force xshifts to attempt, defaults 1. :params n_y_shifts: +- Number of brute force xshifts to attempt, defaults 1. """ - super().__init__(alignment_basis, source, composite_basis, n_angles, dtype) + super().__init__( + alignment_basis, + source, + composite_basis, + n_angles, + batch_size=batch_size, + dtype=dtype, + ) self.n_x_shifts = n_x_shifts self.n_y_shifts = n_y_shifts @@ -293,7 +370,7 @@ def _align(self, classes, reflections, basis_coefficients): # Loop over shift search space, updating best result for x, y in product(x_shifts, y_shifts): shift = np.array([x, y], dtype=int) - logger.info(f"Computing Rotational alignment after shift ({x},{y}).") + logger.debug(f"Computing Rotational alignment after shift ({x},{y}).") # Shift the coef representing the first (base) entry in each class # by the negation of the shift @@ -317,18 +394,586 @@ def _align(self, classes, reflections, basis_coefficients): basis_coefficients[classes[:, 0], :] = original_coef if (x, y) == (0, 0): - logger.info("Initial rotational alignment complete (shift (0,0))") + logger.debug("Initial rotational alignment complete (shift (0,0))") assert np.sum(improved_indices) == np.size( classes ), f"{np.sum(improved_indices)} =?= {np.size(classes)}" else: - logger.info( + logger.debug( f"Shift ({x},{y}) complete. Improved {np.sum(improved_indices)} alignments." ) return classes, reflections, rotations, shifts, correlations +class ReddyChatterjiAlign2D(AveragedAlign2D): + """ + Attempts rotational estimation using Reddy Chatterji log polar Fourier cross correlation. + Then attempts shift (translational) estimation using cross correlation. + + When averaging, performs rotations then shifts. + + Note, it may be possible to iterate this algorithm... + + Adopted from Reddy Chatterji (1996) + An FFT-Based Technique for Translation, + Rotation, and Scale-Invariant Image Registration + IEEE TRANSACTIONS ON IMAGE PROCESSING, VOL. 5, NO. 8, AUGUST 1996 + + This method intentionally does not use any of ASPIRE's basis + so that it may be used as a reference for more ASPIRE approaches. + """ + + def __init__( + self, + alignment_basis, + source, + composite_basis=None, + diagnostics=False, + batch_size=512, + dtype=None, + ): + """ + :param alignment_basis: Basis to be used during alignment (eg FSPCA) + :param source: Source of original images. + :param composite_basis: Basis to be used during class average composition (eg FFB2D) + :param dtype: Numpy dtype to be used during alignment. + """ + + self.__cache = dict() + self.diagnostics = diagnostics + self.do_cross_corr_translations = True + + super().__init__( + alignment_basis, source, composite_basis, batch_size=batch_size, dtype=dtype + ) + + def _phase_cross_correlation(self, img0, img1): + """ + # Adapted from skimage.registration.phase_cross_correlation + + :param img0: Fixed image. + :param img1: Translated image. + :returns: (cross-correlation magnitudes (2D array), shifts) + """ + + # Cache img0 transform, this saves n_classes*(n_nbor-1) transforms + # Note we use the `id` because ndarray are unhashable + src_f = self.__cache.setdefault(id(img0), np.fft.fft2(img0)) + + target_f = np.fft.fft2(img1) + + # Whole-pixel shifts - Compute cross-correlation by an IFFT + shape = src_f.shape + image_product = src_f * target_f.conj() + cross_correlation = np.fft.ifft2(image_product) + + # Locate maximum + maxima = np.unravel_index( + np.argmax(np.abs(cross_correlation)), cross_correlation.shape + ) + midpoints = np.array([np.fix(axis_size / 2) for axis_size in shape]) + + shifts = np.array(maxima, dtype=np.float64) + shifts[shifts > midpoints] -= np.array(shape)[shifts > midpoints] + + return np.abs(cross_correlation), shifts + + def _align(self, classes, reflections, basis_coefficients): + """ + Performs the actual rotational alignment estimation, + returning parameters needed for averaging. + """ + + # Admit simple case of single case alignment + classes = np.atleast_2d(classes) + reflections = np.atleast_2d(reflections) + + n_classes = classes.shape[0] + + # Instantiate matrices for results + rotations = np.zeros(classes.shape, dtype=self.dtype) + correlations = np.zeros(classes.shape, dtype=self.dtype) + shifts = np.zeros((*classes.shape, 2), dtype=int) + + for k in trange(n_classes): + # # Get the array of images for this class + images = self._images(classes[k]) + + self._reddychatterji( + k, images, classes, reflections, rotations, correlations, shifts + ) + + return classes, reflections, rotations, shifts, correlations + + def _reddychatterji( + self, k, images, classes, reflections, rotations, correlations, shifts + ): + """ + Compute the Reddy Chatterji registering images[1:] to image[0]. + + This differs from papers and published scikit implimentations by + computing the fixed base image[0] pipeline once then reusing. + """ + + # De-Mean + images -= images.mean(axis=(-1, -2))[:, np.newaxis, np.newaxis] + + # Precompute fixed_img data used repeatedly in the loop below. + fixed_img = images[0] + # Difference of Gaussians (Band Filter) + fixed_img_dog = difference_of_gaussians(fixed_img, 1, 4) + # Window Images (Fix spectral boundary) + wfixed_img = fixed_img_dog * window("hann", fixed_img.shape) + # Transform image to Fourier space + fixed_img_fs = np.abs(np.fft.fftshift(np.fft.fft2(wfixed_img))) ** 2 + # Compute Log Polar Transform + radius = fixed_img_fs.shape[0] // 8 # Low Pass + warped_fixed_img_fs = warp_polar( + fixed_img_fs, + radius=radius, + output_shape=fixed_img_fs.shape, + scaling="log", + ) + # Only use half of FFT, because it's symmetrical + warped_fixed_img_fs = warped_fixed_img_fs[: fixed_img_fs.shape[0] // 2, :] + + # Now prepare for rotating original images, + # and searching for translations. + # We start back at the raw fixed_img. + twfixed_img = fixed_img * window("hann", fixed_img.shape) + + # Register image `m` against image[0] + for m in range(1, len(images)): + # Get the image to register + regis_img = images[m] + + # Reflect images when nessecary + if reflections[k][m]: + regis_img = np.flipud(regis_img) + + # Difference of Gaussians (Band Filter) + regis_img_dog = difference_of_gaussians(regis_img, 1, 4) + + # Window Images (Fix spectral boundary) + wregis_img = regis_img_dog * window("hann", regis_img.shape) + + self._input_images_diagnostic( + classes[k][0], wfixed_img, classes[k][m], wregis_img + ) + + # Transform image to Fourier space + regis_img_fs = np.abs(np.fft.fftshift(np.fft.fft2(wregis_img))) ** 2 + + self._windowed_psd_diagnostic( + classes[k][0], fixed_img_fs, classes[k][m], regis_img_fs + ) + + # Compute Log Polar Transform + warped_regis_img_fs = warp_polar( + regis_img_fs, + radius=radius, # Low Pass + output_shape=fixed_img_fs.shape, + scaling="log", + ) + + self._log_polar_diagnostic( + classes[k][0], warped_fixed_img_fs, classes[k][m], warped_regis_img_fs + ) + + # Only use half of FFT, because it's symmetrical + warped_regis_img_fs = warped_regis_img_fs[: fixed_img_fs.shape[0] // 2, :] + + # Compute the Cross_Correlation to estimate rotation + # Note that _phase_cross_correlation uses the mangnitudes (abs()), + # ie it is using both freq and phase information. + cross_correlation, shift = self._phase_cross_correlation( + warped_fixed_img_fs, warped_regis_img_fs + ) + + cross_correlation_score = cross_correlation[:, 0].ravel() + + self._rotation_cross_corr_diagnostic( + cross_correlation, cross_correlation_score + ) + + # Recover the angle from index representing maximal cross_correlation + recovered_angle_degrees = (360 / regis_img_fs.shape[0]) * np.argmax( + cross_correlation_score + ) + + if recovered_angle_degrees > 90: + r = 180 - recovered_angle_degrees + else: + r = -recovered_angle_degrees + + # Dont like this, but I got stumped/frustrated. + # For now, try the hack below, attempting two cases ... + # Most of the papers mention running the whole algo /twice/, + # when admitting reflections, so this hack is not + # the worst you could do :). + # if reflections[k][m]: + # if 0<= r < 90: + # r -= 180 + # Hack + regis_img_estimated = rotate(regis_img, r) + regis_img_rotated_p180 = rotate(regis_img, r + 180) + da = np.dot(fixed_img.flatten(), regis_img_estimated.flatten()) + db = np.dot(fixed_img.flatten(), regis_img_rotated_p180.flatten()) + if db > da: + regis_img_estimated = regis_img_rotated_p180 + r += 180 + + self._rotated_diagnostic( + classes[k][0], + fixed_img, + classes[k][m], + regis_img_estimated, + reflections[k][m], + r, + ) + + # Assign estimated rotations results + rotations[k][m] = -r * np.pi / 180 # Reverse rot and convert to radians + + if self.do_cross_corr_translations: + # Prepare for searching over translations using cross-correlation with the rotated image. + twregis_img = regis_img_estimated * window("hann", regis_img.shape) + cross_correlation, shift = self._phase_cross_correlation( + twfixed_img, twregis_img + ) + + self._translation_cross_corr_diagnostic(cross_correlation) + + # Compute the shifts as integer number of pixels, + shift_x, shift_y = int(shift[1]), int(shift[0]) + # then apply the shifts + regis_img_estimated = np.roll(regis_img_estimated, shift_y, axis=0) + regis_img_estimated = np.roll(regis_img_estimated, shift_x, axis=1) + # Assign estimated shift to results + shifts[k][m] = shift[::-1].astype(int) + + self._averaged_diagnostic( + classes[k][0], + fixed_img, + classes[k][m], + regis_img_estimated, + reflections[k][m], + r, + ) + else: + shift = None # For logger line + + # Estimated `corr` metric + corr = np.dot(fixed_img.flatten(), regis_img_estimated.flatten()) + correlations[k][m] = corr + + logger.debug( + f"Class {k}, ref {classes[k][0]}, Neighbor {m} Index {classes[k][m]}" + f" Estimates: {r}*, Shift: {shift}," + f" Corr: {corr}, Refl?: {reflections[k][m]}" + ) + + # Cleanup some cached stuff for this class + self.__cache.pop(id(warped_fixed_img_fs), None) + self.__cache.pop(id(twfixed_img), None) + + def average( + self, + classes, + reflections, + rotations, + shifts=None, + coefs=None, + ): + """ + This averages classes performing rotations then shifts. + Otherwise is similar to `AveragedAlign2D.average`. + """ + n_classes, n_nbor = classes.shape + + b_avgs = np.empty((n_classes, self.composite_basis.count), dtype=self.src.dtype) + + for i in tqdm(range(n_classes)): + + # Get coefs in Composite_Basis if not provided as an argument. + if coefs is None: + # Retrieve relavent images directly from source. + neighbors_imgs = Image(self._images(classes[i])) + neighbors_coefs = self.composite_basis.evaluate_t(neighbors_imgs) + else: + # Get the neighbors + neighbors_ids = classes[i] + neighbors_coefs = coefs[neighbors_ids] + + # Rotate in composite_basis + neighbors_coefs = self.composite_basis.rotate( + neighbors_coefs, rotations[i], reflections[i] + ) + + # Note shifts are after rotation for this approach! + if shifts is not None: + neighbors_coefs = self.composite_basis.shift(neighbors_coefs, shifts[i]) + + # Averaging in composite_basis + b_avgs[i] = np.mean(neighbors_coefs, axis=0) + + # Now we convert the averaged images from Basis to Cartesian. + return ArrayImageSource(self.composite_basis.evaluate(b_avgs)) + + def _input_images_diagnostic(self, ia, a, ib, b): + if not self.diagnostics: + return + fig, axes = plt.subplots(1, 2) + ax = axes.ravel() + ax[0].set_title(f"Image {ia}") + ax[0].imshow(a) + ax[1].set_title(f"Image {ib}") + ax[1].imshow(b) + plt.show() + + def _windowed_psd_diagnostic(self, ia, a, ib, b): + if not self.diagnostics: + return + fig, axes = plt.subplots(1, 2) + ax = axes.ravel() + ax[0].set_title(f"Image {ia} PSD") + ax[0].imshow(np.log(a)) + ax[1].set_title(f"Image {ib} PSD") + ax[1].imshow(np.log(b)) + plt.show() + + def _log_polar_diagnostic(self, ia, a, ib, b): + if not self.diagnostics: + return + labels = np.arange(0, 360, 60) + y = labels / (360 / a.shape[0]) + + fig, axes = plt.subplots(1, 2) + ax = axes.ravel() + ax[0].set_title(f"Image {ia}") + ax[0].imshow(a) + ax[0].set_yticks(y, minor=False) + ax[0].set_yticklabels(labels) + ax[0].set_ylabel("Theta (Degrees)") + + ax[1].set_title(f"Image {ib}") + ax[1].imshow(b) + ax[1].set_yticks(y, minor=False) + ax[1].set_yticklabels(labels) + plt.show() + + def _rotation_cross_corr_diagnostic( + self, cross_correlation, cross_correlation_score + ): + if not self.diagnostics: + return + labels = [0, 30, 60, 90, -60, -30] + x = y = np.arange(0, 180, 30) / (180 / cross_correlation.shape[0]) + plt.title("Rotation Cross Correlation Map") + plt.imshow(cross_correlation) + plt.xlabel("Scale") + plt.yticks(y, labels, rotation="vertical") + plt.ylabel("Theta (Degrees)") + plt.show() + + plt.plot(cross_correlation_score) + plt.title("Angle vs Cross Correlation Score") + plt.xticks(x, labels) + plt.xlabel("Theta (Degrees)") + plt.ylabel("Cross Correlation Score") + plt.grid() + plt.show() + + def _rotated_diagnostic(self, ia, a, ib, b, sb, rb): + """ + Plot the image after estimated rotation and reflection. + + :param ia: index image `a` + :param a: image `a` + :param ib: index image `b` + :param b: image `b` after reflection `sb` and rotion `rb` + :param sb: Reflection, Boolean + :param rb: Estimated rotation, degrees + """ + + if not self.diagnostics: + return + + fig, axes = plt.subplots(1, 2) + ax = axes.ravel() + ax[0].set_title(f"Image {ia}") + ax[0].imshow(a) + ax[0].grid() + ax[1].set_title(f"Image {ib} Refl: {str(sb)[0]} Rotated {rb:.1f}") + ax[1].imshow(b) + ax[1].grid() + plt.show() + + def _translation_cross_corr_diagnostic(self, cross_correlation): + if not self.diagnostics: + return + plt.title("Translation Cross Correlation Map") + plt.imshow(cross_correlation) + plt.xlabel("x shift (pixels)") + plt.ylabel("y shift (pixels)") + L = self.src.L + labels = [0, 10, 20, 30, 0, -10, -20, -30] + tick_location = [0, 10, 20, 30, L, L - 10, L - 20, L - 30] + plt.xticks(tick_location, labels) + plt.yticks(tick_location, labels) + plt.show() + + def _averaged_diagnostic(self, ia, a, ib, b, sb, rb): + """ + Plot the stacked average image after + estimated rotation and reflections. + + Compare in a three way plot. + + :param ia: index image `a` + :param a: image `a` + :param ib: index image `b` + :param b: image `b` after reflection `sb` and rotion `rb` + :param sb: Reflection, Boolean + :param rb: Estimated rotation, degrees + """ + if not self.diagnostics: + return + fig, axes = plt.subplots(1, 3) + ax = axes.ravel() + ax[0].set_title(f"{ia}") + ax[0].imshow(a) + ax[0].grid() + ax[1].set_title(f"{ib} Refl: {str(sb)[0]} Rot: {rb:.1f}") + ax[1].imshow(b) + ax[1].grid() + ax[2].set_title("Stacked Avg") + plt.imshow((a + b) / 2.0) + ax[2].grid() + plt.show() + + +class BFSReddyChatterjiAlign2D(ReddyChatterjiAlign2D): + """ + Brute Force Shifts (Translations) - ReddyChatterji (Log-Polar) Rotations + + For each shift within `radius`, attempts rotational match using ReddyChatterji. + When averaging, performs shift before rotations, + + Adopted from Reddy Chatterji (1996) + An FFT-Based Technique for Translation, + Rotation, and Scale-Invariant Image Registration + IEEE TRANSACTIONS ON IMAGE PROCESSING, VOL. 5, NO. 8, AUGUST 1996 + + This method intentionally does not use any of ASPIRE's basis + so that it may be used as a reference for more ASPIRE approaches. + """ + + def __init__( + self, + alignment_basis, + source, + composite_basis=None, + radius=None, + diagnostics=False, + batch_size=512, + dtype=None, + ): + """ + :param alignment_basis: Basis to be used during alignment (eg FSPCA) + :param source: Source of original images. + :param composite_basis: Basis to be used during class average composition (eg FFB2D) + :param radius: Brute force translation search radius. + Defaults to source.L//8. + :param diagnostics: Plot interactive diagnostic graphics (for debugging). + :param dtype: Numpy dtype to be used during alignment. + """ + + super().__init__( + alignment_basis, + source, + composite_basis, + diagnostics, + batch_size=batch_size, + dtype=dtype, + ) + + # For brute force we disable the cross_corr translation code + self.do_cross_corr_translations = False + # Assign search radius + self.radius = radius or source.L // 8 + + def _align(self, classes, reflections, basis_coefficients): + """ + Performs the actual rotational alignment estimation, + returning parameters needed for averaging. + """ + + # Admit simple case of single case alignment + classes = np.atleast_2d(classes) + reflections = np.atleast_2d(reflections) + + n_classes, n_nbor = classes.shape + L = self.src.L + + # Instantiate matrices for inner loop, and best results. + _rotations = np.zeros(classes.shape, dtype=self.dtype) + rotations = np.zeros(classes.shape, dtype=self.dtype) + _correlations = np.zeros(classes.shape, dtype=self.dtype) + correlations = np.ones(classes.shape, dtype=self.dtype) * -np.inf + _shifts = np.zeros((*classes.shape, 2), dtype=int) + shifts = np.zeros((*classes.shape, 2), dtype=int) + + # We'll brute force all shifts in a grid. + g = grid_2d(L, normalized=False) + disc = g["r"] <= L // 8 # make param later + X, Y = g["x"][disc], g["y"][disc] + + for k in trange(n_classes): + unshifted_images = self._images(classes[k]) + + for xs, ys in zip(X, Y): + s = np.array([xs, ys]) + # Get the array of images for this class + + images = unshifted_images.copy() + # Don't shift the base image + images[1:] = Image(unshifted_images[1:]).shift(s).asnumpy() + + self._reddychatterji( + k, images, classes, reflections, _rotations, _correlations, _shifts + ) + + # Where corr has improved + # update our rolling best results with this loop. + improved = _correlations > correlations + correlations = np.where(improved, _correlations, correlations) + rotations = np.where(improved, _rotations, rotations) + shifts = np.where(shifts, _shifts, shifts) + logger.debug(f"Shift {s} has improved {np.sum(improved)} results") + + return classes, reflections, rotations, shifts, correlations + + def average( + self, + classes, + reflections, + rotations, + shifts=None, + coefs=None, + ): + """ + See AveragedAlign2D.average. + """ + # ReddyChatterjiAlign2D does rotations then shifts. + # For brute force, we'd like shifts then rotations, + # as is done in gerneral via AveragedAlign2D. + return AveragedAlign2D.average( + self, classes, reflections, rotations, shifts, coefs + ) + + class EMAlign2D(Align2D): """ Citation needed. diff --git a/src/aspire/classification/rir_class2d.py b/src/aspire/classification/rir_class2d.py index 83755c83ee..f3daf59736 100644 --- a/src/aspire/classification/rir_class2d.py +++ b/src/aspire/classification/rir_class2d.py @@ -6,11 +6,10 @@ from tqdm import tqdm from aspire.basis import FSPCABasis -from aspire.classification import BFRAlign2D, Class2D +from aspire.classification import Class2D +from aspire.classification.align2d import BFSReddyChatterjiAlign2D from aspire.classification.legacy_implementations import bispec_2drot_large, pca_y -from aspire.image import Image from aspire.numeric import ComplexPCA -from aspire.source import ArrayImageSource from aspire.utils.random import rand logger = logging.getLogger(__name__) @@ -173,8 +172,8 @@ def classify(self, diagnostics=False): # When not provided by a user, the aligner is instantiated after # we are certain our pca_basis has been constructed. if self.aligner is None: - self.aligner = BFRAlign2D( - self.pca_basis, self.src, self.fb_basis, dtype=self.dtype + self.aligner = BFSReddyChatterjiAlign2D( + None, self.src, self.fb_basis, dtype=self.dtype ) # Get the expanded coefs in the compressed FSPCA space. diff --git a/tests/test_align2d.py b/tests/test_align2d.py index abb6af5925..612e76a381 100644 --- a/tests/test_align2d.py +++ b/tests/test_align2d.py @@ -62,8 +62,8 @@ def testTypeMismatch(self): test_dtype = np.float32 with self._caplog.at_level(logging.WARN): - self.aligner(self.basis, self._getSrc, dtype=test_dtype) - assert " does not match self.dtype" in self._caplog.text + self.aligner(self.basis, self._getSrc(), dtype=test_dtype) + assert "does not match dtype" in self._caplog.text def _construct_rotations(self): """ From fec9f22f259c9d9d0734d83b83c777c0ec305b38 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Mon, 7 Feb 2022 09:54:40 -0500 Subject: [PATCH 03/40] Rough in ability to use seperate alignment and composition sources for RC aligners. --- src/aspire/classification/align2d.py | 91 ++++++++++++++++------------ 1 file changed, 53 insertions(+), 38 deletions(-) diff --git a/src/aspire/classification/align2d.py b/src/aspire/classification/align2d.py index 83c96bcf66..4d5fef7398 100644 --- a/src/aspire/classification/align2d.py +++ b/src/aspire/classification/align2d.py @@ -28,7 +28,7 @@ def __init__( """ :param alignment_basis: Basis to be used during alignment (eg FSPCA) :param source: Source of original images. - :param composite_basis: Basis to be used during class average composition (eg FFB2D) + :param composite_basis: Basis to be used during class average composition (eg hi res Cartesian/FFB2D) :param dtype: Numpy dtype to be used during alignment. """ @@ -89,43 +89,24 @@ def align(self, classes, reflections, basis_coefficients): :returns: Image instance (stack of images) """ - def _images(self, cls): + def _images(self, cls, src=None): """ Util to return images as an array for class k (provided as array `cls` ), preserving the class/nbor order. - :param cls: An iterable (0/1-D array or list) that holds the indices of images to align. In Class Averaging, this would be a class. + :param cls: An iterable (0/1-D array or list) that holds the indices of images to align. + In Class Averaging, this would be a class. + :param src: Optionally overridee the src, for example, if you want to use a different + source for a certain operation (ie aignment). """ + src = src or self.src n_nbor = cls.shape[-1] # Includes zero'th neighbor - # Get the images. We'll loop over the source in batches. - # Note one day when the Source.images is more flexible, - # this code would mostly go away. - images = np.empty((n_nbor, self.src.L, self.src.L), dtype=self.dtype) - - # We want to only process batches that actually - # contain images for this class. - # First compute the batches' indices. - for start in range(0, self.src.n + 1, self.batch_size): - # First cook up the batch boundaries - end = start + self.batch_size - # UBound, these are inclusive bounds - start = min(start, self.src.n - 1) - end = min(end, self.src.n - 1) - num = end - start + 1 - - # Second, loop over the cls members - image_batch = None - for i, index in enumerate(cls): - # Check if the member is in this chunk - if start <= index <= end: - # Get and cache this image_batch on first hit. - if image_batch is None: - image_batch = self.src.images(start, num) - # Translate the cls's index into this batch's - batch_index = index % self.batch_size - images[i] = image_batch[batch_index] + images = np.empty((n_nbor, src.L, src.L), dtype=self.dtype) + + for i, index in enumerate(cls): + images[i] = src.images(index, 1).asnumpy() return images @@ -429,20 +410,44 @@ def __init__( alignment_basis, source, composite_basis=None, + alignment_source=None, diagnostics=False, batch_size=512, dtype=None, ): """ - :param alignment_basis: Basis to be used during alignment (eg FSPCA) + :param alignment_basis: Basis to be used during alignment. + For current implementation of ReddyChatterjiAlign2D this should be `None`. + Instead see `alignment_source`. :param source: Source of original images. - :param composite_basis: Basis to be used during class average composition (eg FFB2D) + :param composite_basis: Basis to be used during class average composition. + For current implementation of ReddyChatterjiAlign2D this should be `None`. + Instead this method uses `source` for composition of the averaged stack. + :param alignment_source: Basis to be used during class average composition. + Must be the same resolution as `source`. :param dtype: Numpy dtype to be used during alignment. """ self.__cache = dict() self.diagnostics = diagnostics self.do_cross_corr_translations = True + self.alignment_src = alignment_source or source + + # TODO, for accomodating different resolutions we minimally need to adapt shifting. + # Outside of scope right now, but would make a nice PR later. + if self.alignment_src.L != source.L: + raise RuntimeError("Currently `alignment_src.L` must equal `source.L`") + if self.alignment_src.dtype != source.dtype: + raise RuntimeError( + "Currently `alignment_src.dtype` must equal `source.dtype`" + ) + + # Sanity check. This API should be rethought once all basis and + # alignment methods have been incorporated. + assert alignment_basis is None # We use sources directly for alignment + assert ( + composite_basis is not None + ) # However, we require a basis for rotating etc. super().__init__( alignment_basis, source, composite_basis, batch_size=batch_size, dtype=dtype @@ -497,8 +502,8 @@ def _align(self, classes, reflections, basis_coefficients): shifts = np.zeros((*classes.shape, 2), dtype=int) for k in trange(n_classes): - # # Get the array of images for this class - images = self._images(classes[k]) + # # Get the array of images for this class, using the `alignment_src`. + images = self._images(classes[k], src=self.alignment_src) self._reddychatterji( k, images, classes, reflections, rotations, correlations, shifts @@ -817,7 +822,7 @@ def _translation_cross_corr_diagnostic(self, cross_correlation): plt.imshow(cross_correlation) plt.xlabel("x shift (pixels)") plt.ylabel("y shift (pixels)") - L = self.src.L + L = self.alignment_src.L labels = [0, 10, 20, 30, 0, -10, -20, -30] tick_location = [0, 10, 20, 30, L, L - 10, L - 20, L - 30] plt.xticks(tick_location, labels) @@ -875,17 +880,26 @@ def __init__( alignment_basis, source, composite_basis=None, + alignment_source=None, radius=None, diagnostics=False, batch_size=512, dtype=None, ): """ - :param alignment_basis: Basis to be used during alignment (eg FSPCA) + :param alignment_basis: Basis to be used during alignment. + For current implementation of ReddyChatterjiAlign2D this should be `None`. + Instead see `alignment_source`. :param source: Source of original images. - :param composite_basis: Basis to be used during class average composition (eg FFB2D) + :param composite_basis: Basis to be used during class average composition. + For current implementation of ReddyChatterjiAlign2D this should be `None`. + Instead this method uses `source` for composition of the averaged stack. + :param alignment_source: Basis to be used during class average composition. + Must be the same resolution as `source`. :param radius: Brute force translation search radius. Defaults to source.L//8. + :param dtype: Numpy dtype to be used during alignment. + :param diagnostics: Plot interactive diagnostic graphics (for debugging). :param dtype: Numpy dtype to be used during alignment. """ @@ -894,6 +908,7 @@ def __init__( alignment_basis, source, composite_basis, + alignment_source, diagnostics, batch_size=batch_size, dtype=dtype, @@ -915,7 +930,7 @@ def _align(self, classes, reflections, basis_coefficients): reflections = np.atleast_2d(reflections) n_classes, n_nbor = classes.shape - L = self.src.L + L = self.alignment_src.L # Instantiate matrices for inner loop, and best results. _rotations = np.zeros(classes.shape, dtype=self.dtype) From 2cffd1ad60c11ae22e47211fa8897c776bed7669 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Mon, 7 Feb 2022 10:50:50 -0500 Subject: [PATCH 04/40] Rough in code to write out NN graph as Weighted Adjacency List --- src/aspire/classification/rir_class2d.py | 62 +++++++++++++++++++++++- 1 file changed, 60 insertions(+), 2 deletions(-) diff --git a/src/aspire/classification/rir_class2d.py b/src/aspire/classification/rir_class2d.py index f3daf59736..947d367f85 100644 --- a/src/aspire/classification/rir_class2d.py +++ b/src/aspire/classification/rir_class2d.py @@ -29,6 +29,7 @@ def __init__( bispectrum_freq_cutoff=None, large_pca_implementation="legacy", nn_implementation="legacy", + output_nn_filename=None, bispectrum_implementation="legacy", aligner=None, dtype=None, @@ -47,7 +48,8 @@ def __init__( Z. Zhao, Y. Shkolnisky, A. Singer, Rotationally Invariant Image Representation for Viewing Direction Classification in Cryo-EM. (2014) - :param src: Source instance + :param src: Source instance. Note it is possible to use one `source` for classification (ie CWF), + and a different `source` for stacking in the `aligner`. :param pca_basis: Optional FSPCA Basis instance :param fspca_components: Components (top eigvals) to keep from full FSCPA, default truncates to 400. :param alpha: Amplitude Power Scale, default 1/3 (eq 20 from RIIR paper). @@ -119,6 +121,7 @@ def __init__( f"Provided nn_implementation={nn_implementation} not in {nn_implementations.keys()}" ) self._nn_classification = nn_implementations[nn_implementation] + self.output_nn_filename = output_nn_filename # # Do we have a sane Large Dataset PCA large_pca_implementations = { @@ -185,6 +188,8 @@ def classify(self, diagnostics=False): # # Stage 2: Compute Nearest Neighbors logger.info("Calculate Nearest Neighbors") classes, reflections, distances = self.nn_classification(coef_b, coef_b_r) + if self.output_nn_filename is not None: + self._save_nn(classes, reflections, distances) if diagnostics: # Lets peek at the distribution of distances @@ -351,7 +356,7 @@ def _legacy_nn_classification(self, coeff_b, coeff_b_r, batch_size=2000): # Check with Joakim about preference. # I (GBW) think class[i] should have class[i][0] be the original image index. classes[start:finish] = np.argsort(-corr, axis=1)[:, :n_nbor] - # Store the corr values for the n_nhors in this batch + # Store the corr values for the n_nbors in this batch distances[start:finish] = np.take_along_axis( corr, classes[start:finish], axis=1 ) @@ -366,6 +371,59 @@ def _legacy_nn_classification(self, coeff_b, coeff_b_r, batch_size=2000): return classes, refl, distances + def _save_nn(self, classes, reflections, distances): + """ + Output the Nearest Neighbors graph as a weighted adjacency list. + + Vertices are indexed by their natural index in `source`. + Note reflected images are represented by `index + src.n`. + + Only the output of the Nearest Neighbor call is saved. + If you want a complete graph, specify 2*src.n neighbors, + that is all images and their reflections. + + Because this is mixed datatypes (int and floating), + this will be output as a space delimited text file. + + Vi1 Vj1 W_i1_j1 Vj2 Wi1_j2 ... + Vi2 Vj1 W_i2_j1 Vj2 Wi2_j2 ... + ... + + """ + + # Construct the weighted adjacency list + AdjList = [] + for k in range(len(classes)): + + row = [] + vik = classes[k][0] + row.append(vik) + + for j in range(1, len(classes[k])): + + # Neighbor index + vj = classes[k][j] + if reflections[k][j]: + vj += self.src.n + row.append(vj) + + # Neighbor Weight (distance) + wt = distances[k][j] + row.append(wt) + + # Store this row of the AdjList + AdjList.append(row) + + logger.info( + "Writing Nearest Neighbors as Weighted Adjacency List" + f" to {self.output_nn_filename}" + ) + + # Output + with open(self.output_nn_filename, "w") as fh: + for row in AdjList: + fh.write(" ".join(str(x) for x in row) + "\n") + def _legacy_pca(self, M): """ This is more or less the historic implementation ported From 9c876c79a762195d40dbd4c29c9de10e081b3a39 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Mon, 7 Feb 2022 11:11:21 -0500 Subject: [PATCH 05/40] Some cleanup --- src/aspire/classification/align2d.py | 10 +++------- src/aspire/classification/rir_class2d.py | 18 ++++++++---------- 2 files changed, 11 insertions(+), 17 deletions(-) diff --git a/src/aspire/classification/align2d.py b/src/aspire/classification/align2d.py index 4d5fef7398..16fb84dff8 100644 --- a/src/aspire/classification/align2d.py +++ b/src/aspire/classification/align2d.py @@ -135,7 +135,7 @@ def average( coefs=None, ): """ - Combines images using averaging in provided `basis`. + Combines images using averaging in `composite_basis`. :param classes: class indices (refering to src). (n_img, n_nbor) :param reflections: Bool representing whether to reflect image in `classes` @@ -515,7 +515,7 @@ def _reddychatterji( self, k, images, classes, reflections, rotations, correlations, shifts ): """ - Compute the Reddy Chatterji registering images[1:] to image[0]. + Compute the Reddy Chatterji registering images[1:] to image[0]. This differs from papers and published scikit implimentations by computing the fixed base image[0] pipeline once then reusing. @@ -612,14 +612,10 @@ def _reddychatterji( else: r = -recovered_angle_degrees - # Dont like this, but I got stumped/frustrated. # For now, try the hack below, attempting two cases ... - # Most of the papers mention running the whole algo /twice/, + # Some papers mention running entire algos /twice/, # when admitting reflections, so this hack is not # the worst you could do :). - # if reflections[k][m]: - # if 0<= r < 90: - # r -= 180 # Hack regis_img_estimated = rotate(regis_img, r) regis_img_rotated_p180 = rotate(regis_img, r + 180) diff --git a/src/aspire/classification/rir_class2d.py b/src/aspire/classification/rir_class2d.py index 947d367f85..c1aeaaa6c7 100644 --- a/src/aspire/classification/rir_class2d.py +++ b/src/aspire/classification/rir_class2d.py @@ -6,8 +6,7 @@ from tqdm import tqdm from aspire.basis import FSPCABasis -from aspire.classification import Class2D -from aspire.classification.align2d import BFSReddyChatterjiAlign2D +from aspire.classification import BFSReddyChatterjiAlign2D, Class2D from aspire.classification.legacy_implementations import bispec_2drot_large, pca_y from aspire.numeric import ComplexPCA from aspire.utils.random import rand @@ -207,29 +206,28 @@ def classify(self, diagnostics=False): def averages(self, classes, reflections, distances): # # Stage 3: Class Selection - logger.info(f"Select {self.n_classes} Classes from Nearest Neighbors") # This is an area open to active research. # Currently we take a naive approach by selecting the # first n_classes assuming they are quasi random. - classes = classes[: self.n_classes] + logger.info(f"Select {self.n_classes} Classes from Nearest Neighbors") + classes, reflections = self.select_classes(classes, reflections) # # Stage 4: Align logger.info( f"Begin Rotational Alignment of {classes.shape[0]} Classes using {self.aligner}." ) - logger.info(f"Select {self.n_classes} Classes from Nearest Neighbors") - classes, reflections = self.select_classes(classes, reflections) - return self.aligner.align(classes, reflections, self.fspca_coef) def select_classes(self, classes, reflections): """ Select the `n_classes` to align from the (n_images) population of classes. """ - # generate indices for random sample (can do something smart with corr later). - # For testing just take the first n_classes so it matches earlier plots for manual comparison - # This is assumed to be reasonably random. + # Generate indices for random sample (can do something smarter, or build this out later). + # For testing/poc just take the first n_classes so it matches earlier plots for manual comparison + # If image stack is assumed to be reasonably random, this is a reasonable thing to do. + # Another reasonable thing would be to take a random selection over the whole dataset, + # in case the head of a dataset is too similar or has artifacts. selection = np.arange(self.n_classes) return classes[selection], reflections[selection] From acc15643cbd8873984ee96573e6dadc6f4c73e99 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Mon, 7 Feb 2022 11:35:25 -0500 Subject: [PATCH 06/40] add simulated abinitio pipeline experiment --- .../simulated_abinitio_pipeline.py | 229 ++++++++++++++++++ 1 file changed, 229 insertions(+) create mode 100644 gallery/experiments/simulated_abinitio_pipeline.py diff --git a/gallery/experiments/simulated_abinitio_pipeline.py b/gallery/experiments/simulated_abinitio_pipeline.py new file mode 100644 index 0000000000..300ac08e2e --- /dev/null +++ b/gallery/experiments/simulated_abinitio_pipeline.py @@ -0,0 +1,229 @@ +""" +ASPIRE-Python Abinitio Pipeline +================================ + +In this notebook we will introduce a selection of +components corresponding to a pipeline. +""" + +# %% +# Imports +# ------- +# First we import some of the usual suspects. +# In addition, we import some classes from +# the ASPIRE package that we will use throughout this experiment. + +import logging + +import matplotlib.pyplot as plt +import numpy as np + +from aspire.abinitio import CLSyncVoting +from aspire.basis import FFBBasis3D +from aspire.classification import RIRClass2D +from aspire.denoising import DenoiserCov2D +from aspire.noise import AnisotropicNoiseEstimator +from aspire.operators import FunctionFilter, RadialCTFFilter +from aspire.reconstruction import MeanEstimator +from aspire.source import ArrayImageSource, Simulation +from aspire.utils.coor_trans import ( + get_aligned_rotations, + get_rots_mse, + register_rotations, +) +from aspire.volume import Volume + +logger = logging.getLogger(__name__) + + +# %% +# Parameters +# --------------- +# Some example simulation configurations. +# Small sim: img_size 32, num_imgs 10000, n_classes 1000, n_nbor 10 +# Medium sim: img_size 64, num_imgs 20000, n_classes 2000, n_nbor 10 +# Large sim: img_size 129, num_imgs 30000, n_classes 2000, n_nbor 20 + +interactive = True # Do we want to draw blocking interactive plots? +do_cov2d = False # Use CWF coefficients +img_size = 32 # Downsample the volume to a desired resolution +num_imgs = 10000 # How many images in our source. +n_classes = 1000 # How many class averages to compute. +n_nbor = 10 # How many neighbors to stack +noise_variance = 1e-4 # Set a target noise variance + + +# %% +# Simulation Data +# --------------- +# We'll start with a fairly hi-res volume available from EMPIAR/EMDB. +# https://www.ebi.ac.uk/emdb/EMD-2660 +# https://ftp.ebi.ac.uk/pub/databases/emdb/structures/EMD-2660/map/emd_2660.map.gz +og_v = Volume.load("emd_2660.map", dtype=np.float64) +logger.info("Original volume map data" f" shape: {og_v.shape} dtype:{og_v.dtype}") + +logger.info(f"Downsampling to {(img_size,)*3}") +v = og_v.downsample(img_size) +L = v.resolution + + +# Then create a filter based on that variance +# This is an example of a custom noise profile +def noise_function(x, y): + alpha = 1 + beta = 1 + # White + f1 = noise_variance + # Violet-ish + f2 = noise_variance * (x * x + y * y) / L * L + return (alpha * f1 + beta * f2) / 2.0 + + +custom_noise_filter = FunctionFilter(noise_function) + +logger.info("Initialize CTF filters.") +# Create some CTF effects +pixel_size = 5 * 65 / img_size # Pixel size of the images (in angstroms) +voltage = 200 # Voltage (in KV) +defocus_min = 1.5e4 # Minimum defocus value (in angstroms) +defocus_max = 2.5e4 # Maximum defocus value (in angstroms) +defocus_ct = 7 # Number of defocus groups. +Cs = 2.0 # Spherical aberration +alpha = 0.1 # Amplitude contrast + +# Create filters +ctf_filters = [ + RadialCTFFilter(pixel_size, voltage, defocus=d, Cs=2.0, alpha=0.1) + for d in np.linspace(defocus_min, defocus_max, defocus_ct) +] + +# Finally create the Simulation +src = Simulation( + L=v.resolution, + n=num_imgs, + vols=v, + noise_filter=custom_noise_filter, + unique_filters=ctf_filters, +) +# Peek +if interactive: + src.images(0, 10).show() + +# # TODO: Seemed to cause a crash, maybe dtype/blkdiag related +# logger.info("Normalize images to background noise.") +# src.normalize_background() +# # Peek +# if interactive: src.images(0, 10).show() + +# Currently we use phase_flip to attempt correcting for CTF. +logger.info("Perform phase flip to input images.") +src.phase_flip() + +# We should estimate the noise and `Whiten` based on the estimated noise +aiso_noise_estimator = AnisotropicNoiseEstimator(src) +src.whiten(aiso_noise_estimator.filter) + +# Plot the noise profile for inspection +if interactive: + plt.imshow(aiso_noise_estimator.filter.evaluate_grid(L)) + plt.show() + +# Peek, what do the whitened images look like... +if interactive: + src.images(0, 10).show() + +# logger.info("Invert the global density contrast") +# src.invert_contrast() + +# # On Simulation data, better results so far were achieved without cov2d. +if do_cov2d: + # Use CWF denoising + cwf_denoiser = DenoiserCov2D(src) + src = cwf_denoiser.denoise() + +# Peek, what do the denoised images look like... +if interactive: + src.images(0, 10).show() + + +# Cache to memory for some speedup +src = ArrayImageSource(src.images(0, num_imgs).asnumpy(), angles=src.angles) + +# %% +# Class Averaging +# ---------------------- +# +# Now we perform classification and averaging for each class. + +logger.info("Begin Class Averaging") + +rir = RIRClass2D( + src, + fspca_components=400, + bispectrum_components=300, # Compressed Features after last PCA stage. + n_nbor=n_nbor, + n_classes=n_classes, + large_pca_implementation="legacy", + nn_implementation="sklearn", + bispectrum_implementation="legacy", +) + +classes, reflections, distances = rir.classify() +# Only care about the averages returned right now. +avgs = rir.averages(classes, reflections, distances)[0] +if interactive: + avgs.images(0, 10).show() + +# %% +# Common Line Estimation +# ---------------------- +# +# Now we can create a CL instance for estimating orientation of projections +# using the Common Line with Synchronization Voting method. + +logger.info("Begin Orientation Estimation") + +# Stash true rotations for later comparison, +# note this line only works with naive class selection... +true_rotations = src.rots[:n_classes] + +orient_est = CLSyncVoting(avgs, n_theta=36) +# Get the estimated rotations +orient_est.estimate_rotations() +rots_est = orient_est.rotations + +logger.info("Compare with known rotations") +# Compare with known true rotations +Q_mat, flag = register_rotations(rots_est, true_rotations) +regrot = get_aligned_rotations(rots_est, Q_mat, flag) +mse_reg = get_rots_mse(regrot, true_rotations) +logger.info( + f"MSE deviation of the estimated rotations using register_rotations : {mse_reg}\n" +) + +# %% +# Volume Reconstruction +# ---------------------- +# +# Using the estimated rotations, attempt to reconstruct a volume. + +logger.info("Begin Volume reconstruction") + +# Assign the estimated rotations to the class averages +avgs.rots = rots_est + +# Create a reasonable Basis for the 3d Volume +basis = FFBBasis3D((v.resolution,) * 3, dtype=v.dtype) + +# Setup an estimator to perform the back projection. +estimator = MeanEstimator(avgs, basis) + +# Perform the estimation and save the volume. +estimated_volume = estimator.estimate() +fn = f"estimated_volume_n{num_imgs}_c{n_classes}_m{n_nbor}_{img_size}.mrc" +estimated_volume.save(fn, overwrite=True) + +# Peek at result +if interactive: + plt.imshow(np.sum(estimated_volume[0], axis=-1)) + plt.show() From 0ea373b85dd7833a06cd22939059edebc511f1d1 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Mon, 7 Feb 2022 11:38:02 -0500 Subject: [PATCH 07/40] minor denoiser patches --- src/aspire/denoising/__init__.py | 3 ++- src/aspire/denoising/denoised_src.py | 3 +++ src/aspire/denoising/denoiser_cov2d.py | 6 +++++- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/aspire/denoising/__init__.py b/src/aspire/denoising/__init__.py index 30aaf0ccb0..492f1f6aec 100644 --- a/src/aspire/denoising/__init__.py +++ b/src/aspire/denoising/__init__.py @@ -1,3 +1,4 @@ from .adaptive_support import adaptive_support +from .denoised_src import DenoisedImageSource from .denoiser import Denoiser -from .denoiser_cov2d import src_wiener_coords +from .denoiser_cov2d import DenoiserCov2D, src_wiener_coords diff --git a/src/aspire/denoising/denoised_src.py b/src/aspire/denoising/denoised_src.py index db061a0c81..83e2ec3162 100644 --- a/src/aspire/denoising/denoised_src.py +++ b/src/aspire/denoising/denoised_src.py @@ -45,6 +45,9 @@ def _images(self, start=0, num=np.inf, indices=None, batch_size=512): nimgs = len(indices) im = np.empty((nimgs, self.L, self.L)) + # If we request less than a whole batch, don't crash + batch_size = min(nimgs, batch_size) + logger.info(f"Loading {nimgs} images complete") for batch_start in range(start, end + 1, batch_size): imgs_denoised = self.denoiser.images(batch_start, batch_size) diff --git a/src/aspire/denoising/denoiser_cov2d.py b/src/aspire/denoising/denoiser_cov2d.py index 71ef7e97e5..142dcc3765 100644 --- a/src/aspire/denoising/denoiser_cov2d.py +++ b/src/aspire/denoising/denoiser_cov2d.py @@ -103,7 +103,7 @@ class DenoiserCov2D(Denoiser): Define a derived class for denoising 2D images using Cov2D method """ - def __init__(self, src, basis, var_noise=None): + def __init__(self, src, basis=None, var_noise=None): """ Initialize an object for denoising 2D images using Cov2D method @@ -122,8 +122,12 @@ def __init__(self, src, basis, var_noise=None): logger.info(f"Estimated Noise Variance: {var_noise}") self.var_noise = var_noise + if basis is None: + basis = FFBBasis2D((self.src.L, self.src.L)) + if not isinstance(basis, FFBBasis2D): raise NotImplementedError("Currently only fast FB method is supported") + self.basis = basis self.cov2d = None self.mean_est = None From 1cdee3661cc21c516dff7ae01022e63b2b7b7b3f Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Mon, 7 Feb 2022 12:05:16 -0500 Subject: [PATCH 08/40] Update config to parse (but not execute) gallery `experiment` examples --- docs/source/conf.py | 2 +- gallery/experiments/README.rst | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 6cd5cc084b..e0655e3cb4 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -49,7 +49,7 @@ 'gallery_dirs': ['auto_tutorials', 'auto_experiments'], # path to where to save gallery generated output 'download_all_examples': False, 'within_subsection_order': ExampleTitleSortKey, - 'filename_pattern': '/*.py', + 'filename_pattern': r'/tutorials/.*\.py', # Parse all gallery python files, but only execute tutorials. } # Add any paths that contain templates here, relative to this directory. diff --git a/gallery/experiments/README.rst b/gallery/experiments/README.rst index 5791749161..8531b9f7b0 100644 --- a/gallery/experiments/README.rst +++ b/gallery/experiments/README.rst @@ -1,4 +1,4 @@ -Experiments - **COMING SOON** +Experiments ============================= This gallery will be for demonstrating the functionality of ASPIRE tools using experimental data. From 3d54c2247fd9123248c26a25679ee9290e6c9257 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Mon, 7 Feb 2022 13:25:30 -0500 Subject: [PATCH 09/40] Fix bug introduced in recent grids merge 3d covar code apparently not under testing.. --- src/aspire/covariance/covar.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aspire/covariance/covar.py b/src/aspire/covariance/covar.py index 181b6840e6..397dcb5366 100644 --- a/src/aspire/covariance/covar.py +++ b/src/aspire/covariance/covar.py @@ -61,7 +61,7 @@ def compute_kernel(self): weights[:, 0, :] = 0 # TODO: This is where this differs from MeanEstimator - pts_rot = np.moveaxis(pts_rot, -1, 0).reshape(-1, 3, L**2) + pts_rot = np.moveaxis(pts_rot[::-1], 1, 0).reshape(-1, 3, L**2) weights = weights.T.reshape((-1, L**2)) batch_n = weights.shape[0] From d7df6e39bdb247bd364b16f207500d9291c3d569 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Mon, 7 Feb 2022 15:43:10 -0500 Subject: [PATCH 10/40] Fixup the CWF experiment example (some dtypes issues) --- .../simulated_abinitio_pipeline.py | 33 ++++++++++++------- src/aspire/classification/align2d.py | 8 ++--- src/aspire/covariance/covar2d.py | 2 +- src/aspire/denoising/denoiser_cov2d.py | 2 +- 4 files changed, 26 insertions(+), 19 deletions(-) diff --git a/gallery/experiments/simulated_abinitio_pipeline.py b/gallery/experiments/simulated_abinitio_pipeline.py index 300ac08e2e..70978d9207 100644 --- a/gallery/experiments/simulated_abinitio_pipeline.py +++ b/gallery/experiments/simulated_abinitio_pipeline.py @@ -19,8 +19,8 @@ import numpy as np from aspire.abinitio import CLSyncVoting -from aspire.basis import FFBBasis3D -from aspire.classification import RIRClass2D +from aspire.basis import FFBBasis2D, FFBBasis3D +from aspire.classification import BFSReddyChatterjiAlign2D, RIRClass2D from aspire.denoising import DenoiserCov2D from aspire.noise import AnisotropicNoiseEstimator from aspire.operators import FunctionFilter, RadialCTFFilter @@ -132,22 +132,32 @@ def noise_function(x, y): if interactive: src.images(0, 10).show() +# # Optionally invert image contrast, depends on data. # logger.info("Invert the global density contrast") # src.invert_contrast() -# # On Simulation data, better results so far were achieved without cov2d. +# Cache to memory for some speedup +src = ArrayImageSource(src.images(0, num_imgs).asnumpy(), angles=src.angles) + +# On Simulation data, better results so far were achieved without cov2d +# However, we can demonstrate using CWF denoised images for classification. +classification_src = src +custom_aligner = None if do_cov2d: # Use CWF denoising cwf_denoiser = DenoiserCov2D(src) - src = cwf_denoiser.denoise() - -# Peek, what do the denoised images look like... -if interactive: - src.images(0, 10).show() + # Use denoised src for classification + classification_src = cwf_denoiser.denoise() + # Peek, what do the denoised images look like... + if interactive: + classification_src.images(0, 10).show() + # Use regular `src` for the alignment and composition (averaging). + composite_basis = FFBBasis2D((src.L,) * 2, dtype=src.dtype) + custom_aligner = BFSReddyChatterjiAlign2D( + None, src, composite_basis, dtype=src.dtype + ) -# Cache to memory for some speedup -src = ArrayImageSource(src.images(0, num_imgs).asnumpy(), angles=src.angles) # %% # Class Averaging @@ -158,7 +168,7 @@ def noise_function(x, y): logger.info("Begin Class Averaging") rir = RIRClass2D( - src, + classification_src, # Source used for classification fspca_components=400, bispectrum_components=300, # Compressed Features after last PCA stage. n_nbor=n_nbor, @@ -166,6 +176,7 @@ def noise_function(x, y): large_pca_implementation="legacy", nn_implementation="sklearn", bispectrum_implementation="legacy", + aligner=custom_aligner, ) classes, reflections, distances = rir.classify() diff --git a/src/aspire/classification/align2d.py b/src/aspire/classification/align2d.py index 16fb84dff8..d02a7129b0 100644 --- a/src/aspire/classification/align2d.py +++ b/src/aspire/classification/align2d.py @@ -421,9 +421,7 @@ def __init__( Instead see `alignment_source`. :param source: Source of original images. :param composite_basis: Basis to be used during class average composition. - For current implementation of ReddyChatterjiAlign2D this should be `None`. - Instead this method uses `source` for composition of the averaged stack. - :param alignment_source: Basis to be used during class average composition. + :param alignment_source: Optional, source to be used during class average alignment. Must be the same resolution as `source`. :param dtype: Numpy dtype to be used during alignment. """ @@ -888,9 +886,7 @@ def __init__( Instead see `alignment_source`. :param source: Source of original images. :param composite_basis: Basis to be used during class average composition. - For current implementation of ReddyChatterjiAlign2D this should be `None`. - Instead this method uses `source` for composition of the averaged stack. - :param alignment_source: Basis to be used during class average composition. + :param alignment_source: Optional, source to be used during class average alignment. Must be the same resolution as `source`. :param radius: Brute force translation search radius. Defaults to source.L//8. diff --git a/src/aspire/covariance/covar2d.py b/src/aspire/covariance/covar2d.py index 95cb734150..b8599f27fb 100644 --- a/src/aspire/covariance/covar2d.py +++ b/src/aspire/covariance/covar2d.py @@ -284,7 +284,7 @@ def identity(x): for k in np.unique(ctf_idx[:]): - coeff_k = coeffs[ctf_idx == k] + coeff_k = coeffs[ctf_idx == k].astype(self.dtype) weight = coeff_k.shape[0] / coeffs.shape[0] ctf_fb_k = ctf_fb[k] diff --git a/src/aspire/denoising/denoiser_cov2d.py b/src/aspire/denoising/denoiser_cov2d.py index 142dcc3765..9e3ac9554a 100644 --- a/src/aspire/denoising/denoiser_cov2d.py +++ b/src/aspire/denoising/denoiser_cov2d.py @@ -123,7 +123,7 @@ def __init__(self, src, basis=None, var_noise=None): self.var_noise = var_noise if basis is None: - basis = FFBBasis2D((self.src.L, self.src.L)) + basis = FFBBasis2D((self.src.L, self.src.L), dtype=src.dtype) if not isinstance(basis, FFBBasis2D): raise NotImplementedError("Currently only fast FB method is supported") From e0507965d8d1d1ac4809762609cb3f2a16328047 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 8 Feb 2022 11:15:55 -0500 Subject: [PATCH 11/40] cleanup few strings/typos --- .../simulated_abinitio_pipeline.py | 6 ----- src/aspire/classification/align2d.py | 25 ++++++++++--------- 2 files changed, 13 insertions(+), 18 deletions(-) diff --git a/gallery/experiments/simulated_abinitio_pipeline.py b/gallery/experiments/simulated_abinitio_pipeline.py index 70978d9207..59e6aca25f 100644 --- a/gallery/experiments/simulated_abinitio_pipeline.py +++ b/gallery/experiments/simulated_abinitio_pipeline.py @@ -109,12 +109,6 @@ def noise_function(x, y): if interactive: src.images(0, 10).show() -# # TODO: Seemed to cause a crash, maybe dtype/blkdiag related -# logger.info("Normalize images to background noise.") -# src.normalize_background() -# # Peek -# if interactive: src.images(0, 10).show() - # Currently we use phase_flip to attempt correcting for CTF. logger.info("Perform phase flip to input images.") src.phase_flip() diff --git a/src/aspire/classification/align2d.py b/src/aspire/classification/align2d.py index d02a7129b0..369718901f 100644 --- a/src/aspire/classification/align2d.py +++ b/src/aspire/classification/align2d.py @@ -61,12 +61,13 @@ def __init__( @abstractmethod def align(self, classes, reflections, basis_coefficients): """ - Any align2D alignment method should take in the following arguments + Any align2D alignment method should take in the below arguments and return aligned images. During this process `rotations`, `reflections`, `shifts` and `correlations` properties will be computed for aligners - that implement them. + that implement them. Some future aligners (example. EM based) + may not produce these intermediates. `rotations` is an (n_classes, n_nbor) array of angles, which should represent the rotations needed to align images within @@ -83,13 +84,13 @@ def align(self, classes, reflections, basis_coefficients): Subclasses of `align` should extend this method with optional arguments. :param classes: (n_classes, n_nbor) integer array of img indices - :param refl: (n_classes, n_nbor) bool array of corresponding reflections - :param coef: (n_img, self.pca_basis.count) compressed basis coefficients + :param reflections: (n_classes, n_nbor) bool array of corresponding reflections + :param basis_coefficients: (n_img, self.pca_basis.count) compressed basis coefficients :returns: Image instance (stack of images) """ - def _images(self, cls, src=None): + def _cls_images(self, cls, src=None): """ Util to return images as an array for class k (provided as array `cls` ), preserving the class/nbor order. @@ -135,7 +136,7 @@ def average( coefs=None, ): """ - Combines images using averaging in `composite_basis`. + Combines images using averaging in `self.composite_basis`. :param classes: class indices (refering to src). (n_img, n_nbor) :param reflections: Bool representing whether to reflect image in `classes` @@ -153,7 +154,7 @@ def average( # Get coefs in Composite_Basis if not provided as an argumen. if coefs is None: # Retrieve relavent images directly from source. - neighbors_imgs = Image(self._images(classes[i])) + neighbors_imgs = Image(self._cls_images(classes[i])) # Do shifts if shifts is not None: @@ -501,7 +502,7 @@ def _align(self, classes, reflections, basis_coefficients): for k in trange(n_classes): # # Get the array of images for this class, using the `alignment_src`. - images = self._images(classes[k], src=self.alignment_src) + images = self._cls_images(classes[k], src=self.alignment_src) self._reddychatterji( k, images, classes, reflections, rotations, correlations, shifts @@ -513,7 +514,7 @@ def _reddychatterji( self, k, images, classes, reflections, rotations, correlations, shifts ): """ - Compute the Reddy Chatterji registering images[1:] to image[0]. + Compute the Reddy Chatterji method registering images[1:] to image[0]. This differs from papers and published scikit implimentations by computing the fixed base image[0] pipeline once then reusing. @@ -551,7 +552,7 @@ def _reddychatterji( # Get the image to register regis_img = images[m] - # Reflect images when nessecary + # Reflect images when necessary if reflections[k][m]: regis_img = np.flipud(regis_img) @@ -698,7 +699,7 @@ def average( # Get coefs in Composite_Basis if not provided as an argument. if coefs is None: # Retrieve relavent images directly from source. - neighbors_imgs = Image(self._images(classes[i])) + neighbors_imgs = Image(self._cls_images(classes[i])) neighbors_coefs = self.composite_basis.evaluate_t(neighbors_imgs) else: # Get the neighbors @@ -938,7 +939,7 @@ def _align(self, classes, reflections, basis_coefficients): X, Y = g["x"][disc], g["y"][disc] for k in trange(n_classes): - unshifted_images = self._images(classes[k]) + unshifted_images = self._cls_images(classes[k]) for xs, ys in zip(X, Y): s = np.array([xs, ys]) From b40b038a7240c79b3c514dd6b5b9e1dfbebbb769 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 8 Feb 2022 12:23:47 -0500 Subject: [PATCH 12/40] Add 10028 experiment pipeline example. --- gallery/experiments/simulated_pipeline.py | 168 ++++++++++++++++++++++ 1 file changed, 168 insertions(+) create mode 100644 gallery/experiments/simulated_pipeline.py diff --git a/gallery/experiments/simulated_pipeline.py b/gallery/experiments/simulated_pipeline.py new file mode 100644 index 0000000000..657cbd3d3c --- /dev/null +++ b/gallery/experiments/simulated_pipeline.py @@ -0,0 +1,168 @@ +""" +ASPIRE-Python Abinitio Pipeline +================================ + +In this notebook we will introduce a selection of +components corresponding to a pipeline using +the EMD-10028 picked particles dataset. +""" + +# %% +# Imports +# ------- +# First we import some of the usual suspects. +# In addition, we import some classes from +# the ASPIRE package that we will use throughout this experiment. + +import logging + +import matplotlib.pyplot as plt +import numpy as np + +from aspire.abinitio import CLSyncVoting +from aspire.basis import FFBBasis2D, FFBBasis3D +from aspire.classification import BFSReddyChatterjiAlign2D, RIRClass2D +from aspire.denoising import DenoiserCov2D +from aspire.noise import AnisotropicNoiseEstimator +from aspire.operators import FunctionFilter, RadialCTFFilter +from aspire.reconstruction import MeanEstimator +from aspire.source import RelionSource +from aspire.volume import Volume + +logger = logging.getLogger(__name__) + + +# %% +# Parameters +# --------------- +# Example simulation configuration. + +interactive = False # Do we want to draw blocking interactive plots? +do_cov2d = True # Use CWF coefficients +n_imgs = None # Set to None for all images in starfile +img_size = 64 # Downsample the images/reconstruction to a desired resolution +n_classes = 2000 # How many class averages to compute. +n_nbor = 50 # How many neighbors to stack +starfile_in = "10028/data/shiny_2sets.star" +pixel_size = 1.34 + +# Create a source object for the experimental images +src = RelionSource( + starfile_in, pixel_size=pixel_size, max_rows=n_imgs +) + +# Downsample the images +logger.info(f"Set the resolution to {img_size} X {img_size}") +src.downsample(img_size) + +# Peek +if interactive: + src.images(0, 10).show() + +# Currently we use phase_flip to attempt correcting for CTF. +logger.info("Perform phase flip to input images.") +src.phase_flip() + +# We should estimate the noise and `Whiten` based on the estimated noise +aiso_noise_estimator = AnisotropicNoiseEstimator(src) +src.whiten(aiso_noise_estimator.filter) + +# Plot the noise profile for inspection +if interactive: + plt.imshow(aiso_noise_estimator.filter.evaluate_grid(img_size)) + plt.show() + +# Peek, what do the whitened images look like... +if interactive: + src.images(0, 10).show() + +# # Optionally invert image contrast, depends on data. +# logger.info("Invert the global density contrast") +# src.invert_contrast() + +# On Simulation data, better results so far were achieved without cov2d +# However, we can demonstrate using CWF denoised images for classification. +classification_src = src +custom_aligner = None +if do_cov2d: + # Use CWF denoising + cwf_denoiser = DenoiserCov2D(src) + # Use denoised src for classification + classification_src = cwf_denoiser.denoise() + # Peek, what do the denoised images look like... + if interactive: + classification_src.images(0, 10).show() + + # Use regular `src` for the alignment and composition (averaging). + composite_basis = FFBBasis2D((src.L,) * 2, dtype=src.dtype) + custom_aligner = BFSReddyChatterjiAlign2D( + None, src, composite_basis, dtype=src.dtype + ) + + +# %% +# Class Averaging +# ---------------------- +# +# Now we perform classification and averaging for each class. + +logger.info("Begin Class Averaging") + +rir = RIRClass2D( + classification_src, # Source used for classification + fspca_components=400, + bispectrum_components=300, # Compressed Features after last PCA stage. + n_nbor=n_nbor, + n_classes=n_classes, + large_pca_implementation="legacy", + nn_implementation="sklearn", + bispectrum_implementation="legacy", + aligner=custom_aligner, +) + +classes, reflections, distances = rir.classify() +# Only care about the averages returned right now. +avgs = rir.averages(classes, reflections, distances)[0] +if interactive: + avgs.images(0, 10).show() + +# %% +# Common Line Estimation +# ---------------------- +# +# Now we can create a CL instance for estimating orientation of projections +# using the Common Line with Synchronization Voting method. + +logger.info("Begin Orientation Estimation") + +orient_est = CLSyncVoting(avgs, n_theta=36) +# Get the estimated rotations +orient_est.estimate_rotations() +rots_est = orient_est.rotations + +# %% +# Volume Reconstruction +# ---------------------- +# +# Using the estimated rotations, attempt to reconstruct a volume. + +logger.info("Begin Volume reconstruction") + +# Assign the estimated rotations to the class averages +avgs.rots = rots_est + +# Create a reasonable Basis for the 3d Volume +basis = FFBBasis3D((img_size,) * 3, dtype=src.dtype) + +# Setup an estimator to perform the back projection. +estimator = MeanEstimator(avgs, basis) + +# Perform the estimation and save the volume. +estimated_volume = estimator.estimate() +fn = f"estimated_volume_n{num_imgs}_c{n_classes}_m{n_nbor}_{img_size}.mrc" +estimated_volume.save(fn, overwrite=True) + +# Peek at result +if interactive: + plt.imshow(np.sum(estimated_volume[0], axis=-1)) + plt.show() From c1cce91721746b7682b476feac8c11694d074eaf Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 8 Feb 2022 12:51:15 -0500 Subject: [PATCH 13/40] tweak filename output for simulated_pipeline --- gallery/experiments/simulated_pipeline.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/gallery/experiments/simulated_pipeline.py b/gallery/experiments/simulated_pipeline.py index 657cbd3d3c..cb6f6d2ba9 100644 --- a/gallery/experiments/simulated_pipeline.py +++ b/gallery/experiments/simulated_pipeline.py @@ -40,10 +40,11 @@ interactive = False # Do we want to draw blocking interactive plots? do_cov2d = True # Use CWF coefficients n_imgs = None # Set to None for all images in starfile -img_size = 64 # Downsample the images/reconstruction to a desired resolution +img_size = 77 # Downsample the images/reconstruction to a desired resolution n_classes = 2000 # How many class averages to compute. -n_nbor = 50 # How many neighbors to stack +n_nbor = 100 # How many neighbors to stack starfile_in = "10028/data/shiny_2sets.star" +volume_filename_prefix_out = f"10028_recon_{num_imgs}_c{n_classes}_m{n_nbor}_{img_size}.mrc" pixel_size = 1.34 # Create a source object for the experimental images @@ -159,8 +160,7 @@ # Perform the estimation and save the volume. estimated_volume = estimator.estimate() -fn = f"estimated_volume_n{num_imgs}_c{n_classes}_m{n_nbor}_{img_size}.mrc" -estimated_volume.save(fn, overwrite=True) +estimated_volume.save(volume_filename_prefix_out, overwrite=True) # Peek at result if interactive: From 7bafbedbd64d21fffa48ee460d57184789627c32 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 8 Feb 2022 12:55:48 -0500 Subject: [PATCH 14/40] linter/syntax cleanup --- gallery/experiments/simulated_pipeline.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/gallery/experiments/simulated_pipeline.py b/gallery/experiments/simulated_pipeline.py index cb6f6d2ba9..fb2c2cef58 100644 --- a/gallery/experiments/simulated_pipeline.py +++ b/gallery/experiments/simulated_pipeline.py @@ -24,10 +24,8 @@ from aspire.classification import BFSReddyChatterjiAlign2D, RIRClass2D from aspire.denoising import DenoiserCov2D from aspire.noise import AnisotropicNoiseEstimator -from aspire.operators import FunctionFilter, RadialCTFFilter from aspire.reconstruction import MeanEstimator from aspire.source import RelionSource -from aspire.volume import Volume logger = logging.getLogger(__name__) @@ -39,18 +37,16 @@ interactive = False # Do we want to draw blocking interactive plots? do_cov2d = True # Use CWF coefficients -n_imgs = None # Set to None for all images in starfile +n_imgs = None # Set to None for all images in starfile, can set smaller for tests. img_size = 77 # Downsample the images/reconstruction to a desired resolution n_classes = 2000 # How many class averages to compute. n_nbor = 100 # How many neighbors to stack starfile_in = "10028/data/shiny_2sets.star" -volume_filename_prefix_out = f"10028_recon_{num_imgs}_c{n_classes}_m{n_nbor}_{img_size}.mrc" +volume_filename_prefix_out = f"10028_recon_c{n_classes}_m{n_nbor}_{img_size}.mrc" pixel_size = 1.34 # Create a source object for the experimental images -src = RelionSource( - starfile_in, pixel_size=pixel_size, max_rows=n_imgs -) +src = RelionSource(starfile_in, pixel_size=pixel_size, max_rows=n_imgs) # Downsample the images logger.info(f"Set the resolution to {img_size} X {img_size}") From afca0ea5d71a55b0c8a49ff80f2e60faa6144343 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 8 Feb 2022 13:24:33 -0500 Subject: [PATCH 15/40] cleanup examples for sphinx gallery --- ...e.py => experimental_abinitio_pipeline.py} | 63 ++++++++++++++----- .../simulated_abinitio_pipeline.py | 50 +++++++++------ 2 files changed, 76 insertions(+), 37 deletions(-) rename gallery/experiments/{simulated_pipeline.py => experimental_abinitio_pipeline.py} (67%) diff --git a/gallery/experiments/simulated_pipeline.py b/gallery/experiments/experimental_abinitio_pipeline.py similarity index 67% rename from gallery/experiments/simulated_pipeline.py rename to gallery/experiments/experimental_abinitio_pipeline.py index fb2c2cef58..3a6e949afa 100644 --- a/gallery/experiments/simulated_pipeline.py +++ b/gallery/experiments/experimental_abinitio_pipeline.py @@ -1,18 +1,26 @@ """ -ASPIRE-Python Abinitio Pipeline -================================ +Abinitio Pipeline - Experimental Data +===================================== -In this notebook we will introduce a selection of -components corresponding to a pipeline using -the EMD-10028 picked particles dataset. +This notebook introduces a selection of +components corresponding to loading real Relion picked +particle Cryo-EM data and running key ASPIRE-Python +Abinitio model components as a pipeline. + +Specifically in this pipeline uses the +EMPIAR 10028 picked particles data, available here: + +https://www.ebi.ac.uk/empiar/EMPIAR-10028 + +https://www.ebi.ac.uk/emdb/EMD-10028 """ # %% # Imports # ------- -# First we import some of the usual suspects. -# In addition, we import some classes from -# the ASPIRE package that we will use throughout this experiment. +# First import some of the usual suspects. +# In addition, import some classes from +# the ASPIRE package that will be used throughout this experiment. import logging @@ -35,7 +43,7 @@ # --------------- # Example simulation configuration. -interactive = False # Do we want to draw blocking interactive plots? +interactive = False # Draw blocking interactive plots? do_cov2d = True # Use CWF coefficients n_imgs = None # Set to None for all images in starfile, can set smaller for tests. img_size = 77 # Downsample the images/reconstruction to a desired resolution @@ -45,6 +53,14 @@ volume_filename_prefix_out = f"10028_recon_c{n_classes}_m{n_nbor}_{img_size}.mrc" pixel_size = 1.34 +# %% +# Source data and Preprocessing +# ----------------------------- +# +# `RelionSource` is used to access the experimental data via a `starfile`. +# Begin by downsampling to our chosen resolution, then preprocess +# to correct for CTF and noise. + # Create a source object for the experimental images src = RelionSource(starfile_in, pixel_size=pixel_size, max_rows=n_imgs) @@ -56,11 +72,11 @@ if interactive: src.images(0, 10).show() -# Currently we use phase_flip to attempt correcting for CTF. +# Use phase_flip to attempt correcting for CTF. logger.info("Perform phase flip to input images.") src.phase_flip() -# We should estimate the noise and `Whiten` based on the estimated noise +# Estimate the noise and `Whiten` based on the estimated noise aiso_noise_estimator = AnisotropicNoiseEstimator(src) src.whiten(aiso_noise_estimator.filter) @@ -73,12 +89,25 @@ if interactive: src.images(0, 10).show() -# # Optionally invert image contrast, depends on data. +# # Optionally invert image contrast, depends on data convention. +# # This is not needed for 10028, but included anyway. # logger.info("Invert the global density contrast") # src.invert_contrast() -# On Simulation data, better results so far were achieved without cov2d -# However, we can demonstrate using CWF denoised images for classification. +# %% +# Optional: CWF Denoising +# ----------------------- +# +# Optionally generate an alternative source that is denoised with `cov2d`, +# then configure a customized aligner. This allows the use of CWF denoised +# images for classification, but stacks the original images for averages +# used in the remainder of the reconstruction pipeline. +# +# In this example, this behavior is controlled by the `do_cov2d` boolean variable. +# When disabled, the original src and default aligner is used. +# If you will not be using cov2d, +# you may remove this code block and associated variables. + classification_src = src custom_aligner = None if do_cov2d: @@ -101,7 +130,7 @@ # Class Averaging # ---------------------- # -# Now we perform classification and averaging for each class. +# Now perform classification and averaging for each class. logger.info("Begin Class Averaging") @@ -118,7 +147,7 @@ ) classes, reflections, distances = rir.classify() -# Only care about the averages returned right now. +# Only care about the averages returned right now (index 0) avgs = rir.averages(classes, reflections, distances)[0] if interactive: avgs.images(0, 10).show() @@ -127,7 +156,7 @@ # Common Line Estimation # ---------------------- # -# Now we can create a CL instance for estimating orientation of projections +# Next create a CL instance for estimating orientation of projections # using the Common Line with Synchronization Voting method. logger.info("Begin Orientation Estimation") diff --git a/gallery/experiments/simulated_abinitio_pipeline.py b/gallery/experiments/simulated_abinitio_pipeline.py index 59e6aca25f..38b5d5a779 100644 --- a/gallery/experiments/simulated_abinitio_pipeline.py +++ b/gallery/experiments/simulated_abinitio_pipeline.py @@ -1,17 +1,19 @@ """ -ASPIRE-Python Abinitio Pipeline -================================ +Abinitio Pipeline - Simulated Data +================================== -In this notebook we will introduce a selection of -components corresponding to a pipeline. +This notebook introduces a selection of +components corresponding to generating realistic +simulated Cryo-EM data and running key ASPIRE-Python +Abinitio model components as a pipeline. """ # %% # Imports # ------- -# First we import some of the usual suspects. -# In addition, we import some classes from -# the ASPIRE package that we will use throughout this experiment. +# First import some of the usual suspects. +# In addition, import some classes from +# the ASPIRE package that will be used throughout this experiment. import logging @@ -44,7 +46,7 @@ # Medium sim: img_size 64, num_imgs 20000, n_classes 2000, n_nbor 10 # Large sim: img_size 129, num_imgs 30000, n_classes 2000, n_nbor 20 -interactive = True # Do we want to draw blocking interactive plots? +interactive = True # Draw blocking interactive plots? do_cov2d = False # Use CWF coefficients img_size = 32 # Downsample the volume to a desired resolution num_imgs = 10000 # How many images in our source. @@ -56,7 +58,7 @@ # %% # Simulation Data # --------------- -# We'll start with a fairly hi-res volume available from EMPIAR/EMDB. +# Start with a fairly hi-res volume available from EMPIAR/EMDB. # https://www.ebi.ac.uk/emdb/EMD-2660 # https://ftp.ebi.ac.uk/pub/databases/emdb/structures/EMD-2660/map/emd_2660.map.gz og_v = Volume.load("emd_2660.map", dtype=np.float64) @@ -109,11 +111,11 @@ def noise_function(x, y): if interactive: src.images(0, 10).show() -# Currently we use phase_flip to attempt correcting for CTF. +# Use phase_flip to attempt correcting for CTF. logger.info("Perform phase flip to input images.") src.phase_flip() -# We should estimate the noise and `Whiten` based on the estimated noise +# Estimate the noise and `Whiten` based on the estimated noise aiso_noise_estimator = AnisotropicNoiseEstimator(src) src.whiten(aiso_noise_estimator.filter) @@ -126,15 +128,23 @@ def noise_function(x, y): if interactive: src.images(0, 10).show() -# # Optionally invert image contrast, depends on data. -# logger.info("Invert the global density contrast") -# src.invert_contrast() - # Cache to memory for some speedup src = ArrayImageSource(src.images(0, num_imgs).asnumpy(), angles=src.angles) -# On Simulation data, better results so far were achieved without cov2d -# However, we can demonstrate using CWF denoised images for classification. +# %% +# Optional: CWF Denoising +# ----------------------- +# +# Optionally generate an alternative source that is denoised with `cov2d`, +# then configure a customized aligner. This allows the use of CWF denoised +# images for classification, but stacks the original images for averages +# used in the remainder of the reconstruction pipeline. +# +# In this example, this behavior is controlled by the `do_cov2d` boolean variable. +# When disabled, the original src and default aligner is used. +# If you will not be using cov2d, +# you may remove this code block and associated variables. + classification_src = src custom_aligner = None if do_cov2d: @@ -157,7 +167,7 @@ def noise_function(x, y): # Class Averaging # ---------------------- # -# Now we perform classification and averaging for each class. +# Now perform classification and averaging for each class. logger.info("Begin Class Averaging") @@ -174,7 +184,7 @@ def noise_function(x, y): ) classes, reflections, distances = rir.classify() -# Only care about the averages returned right now. +# Only care about the averages returned right now (index 0) avgs = rir.averages(classes, reflections, distances)[0] if interactive: avgs.images(0, 10).show() @@ -183,7 +193,7 @@ def noise_function(x, y): # Common Line Estimation # ---------------------- # -# Now we can create a CL instance for estimating orientation of projections +# Next create a CL instance for estimating orientation of projections # using the Common Line with Synchronization Voting method. logger.info("Begin Orientation Estimation") From 570212582c984af14bcef060f5db9e62a7899906 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 8 Feb 2022 13:26:50 -0500 Subject: [PATCH 16/40] typo --- gallery/experiments/experimental_abinitio_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gallery/experiments/experimental_abinitio_pipeline.py b/gallery/experiments/experimental_abinitio_pipeline.py index 3a6e949afa..ff9293f328 100644 --- a/gallery/experiments/experimental_abinitio_pipeline.py +++ b/gallery/experiments/experimental_abinitio_pipeline.py @@ -7,7 +7,7 @@ particle Cryo-EM data and running key ASPIRE-Python Abinitio model components as a pipeline. -Specifically in this pipeline uses the +Specifically this pipeline uses the EMPIAR 10028 picked particles data, available here: https://www.ebi.ac.uk/empiar/EMPIAR-10028 From 27586925ab0ce2f22a3d8dc6fb5dae32e90fbefd Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 8 Feb 2022 17:41:38 -0500 Subject: [PATCH 17/40] less aggresive example settings for now --- gallery/experiments/experimental_abinitio_pipeline.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/gallery/experiments/experimental_abinitio_pipeline.py b/gallery/experiments/experimental_abinitio_pipeline.py index ff9293f328..e137b4546c 100644 --- a/gallery/experiments/experimental_abinitio_pipeline.py +++ b/gallery/experiments/experimental_abinitio_pipeline.py @@ -45,10 +45,10 @@ interactive = False # Draw blocking interactive plots? do_cov2d = True # Use CWF coefficients -n_imgs = None # Set to None for all images in starfile, can set smaller for tests. -img_size = 77 # Downsample the images/reconstruction to a desired resolution -n_classes = 2000 # How many class averages to compute. -n_nbor = 100 # How many neighbors to stack +n_imgs = 20000 # Set to None for all images in starfile, can set smaller for tests. +img_size = 32 # Downsample the images/reconstruction to a desired resolution +n_classes = 1000 # How many class averages to compute. +n_nbor = 50 # How many neighbors to stack starfile_in = "10028/data/shiny_2sets.star" volume_filename_prefix_out = f"10028_recon_c{n_classes}_m{n_nbor}_{img_size}.mrc" pixel_size = 1.34 From fb684ea5db36c6438a760536fa203e417f5c2d45 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 11 Feb 2022 11:57:25 -0500 Subject: [PATCH 18/40] rm import comment --- src/aspire/classification/align2d.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/aspire/classification/align2d.py b/src/aspire/classification/align2d.py index 369718901f..7d0157e7ed 100644 --- a/src/aspire/classification/align2d.py +++ b/src/aspire/classification/align2d.py @@ -6,7 +6,6 @@ import numpy as np from skimage.filters import difference_of_gaussians, window -# import skimage.io from skimage.transform import rotate, warp_polar from tqdm import tqdm, trange From 028d062d0492eda31dde224258346edef6f0943d Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 11 Feb 2022 12:13:57 -0500 Subject: [PATCH 19/40] try to use the internal source cache --- gallery/experiments/simulated_abinitio_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gallery/experiments/simulated_abinitio_pipeline.py b/gallery/experiments/simulated_abinitio_pipeline.py index 38b5d5a779..8b20c53a90 100644 --- a/gallery/experiments/simulated_abinitio_pipeline.py +++ b/gallery/experiments/simulated_abinitio_pipeline.py @@ -129,7 +129,7 @@ def noise_function(x, y): src.images(0, 10).show() # Cache to memory for some speedup -src = ArrayImageSource(src.images(0, num_imgs).asnumpy(), angles=src.angles) +src.cache() # %% # Optional: CWF Denoising From 0d1408c1d5af1f8e18515c3b0d7c0dd239cc5ef2 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 11 Feb 2022 12:21:32 -0500 Subject: [PATCH 20/40] use ASPIRE fft module --- src/aspire/classification/align2d.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/aspire/classification/align2d.py b/src/aspire/classification/align2d.py index 7d0157e7ed..fb3fa01bb4 100644 --- a/src/aspire/classification/align2d.py +++ b/src/aspire/classification/align2d.py @@ -10,6 +10,7 @@ from tqdm import tqdm, trange from aspire.image import Image +from aspire.numeric import fft from aspire.source import ArrayImageSource from aspire.utils.coor_trans import grid_2d @@ -462,14 +463,14 @@ def _phase_cross_correlation(self, img0, img1): # Cache img0 transform, this saves n_classes*(n_nbor-1) transforms # Note we use the `id` because ndarray are unhashable - src_f = self.__cache.setdefault(id(img0), np.fft.fft2(img0)) + src_f = self.__cache.setdefault(id(img0), fft.fft2(img0)) - target_f = np.fft.fft2(img1) + target_f = fft.fft2(img1) # Whole-pixel shifts - Compute cross-correlation by an IFFT shape = src_f.shape image_product = src_f * target_f.conj() - cross_correlation = np.fft.ifft2(image_product) + cross_correlation = fft.ifft2(image_product) # Locate maximum maxima = np.unravel_index( @@ -529,7 +530,7 @@ def _reddychatterji( # Window Images (Fix spectral boundary) wfixed_img = fixed_img_dog * window("hann", fixed_img.shape) # Transform image to Fourier space - fixed_img_fs = np.abs(np.fft.fftshift(np.fft.fft2(wfixed_img))) ** 2 + fixed_img_fs = np.abs(fft.fftshift(fft.fft2(wfixed_img))) ** 2 # Compute Log Polar Transform radius = fixed_img_fs.shape[0] // 8 # Low Pass warped_fixed_img_fs = warp_polar( @@ -566,7 +567,7 @@ def _reddychatterji( ) # Transform image to Fourier space - regis_img_fs = np.abs(np.fft.fftshift(np.fft.fft2(wregis_img))) ** 2 + regis_img_fs = np.abs(fft.fftshift(fft.fft2(wregis_img))) ** 2 self._windowed_psd_diagnostic( classes[k][0], fixed_img_fs, classes[k][m], regis_img_fs From 2b485a84be328a9d887668db1a34fa8f29822680 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 11 Feb 2022 13:00:44 -0500 Subject: [PATCH 21/40] missing abcs --- src/aspire/classification/class2d.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/aspire/classification/class2d.py b/src/aspire/classification/class2d.py index 67293031a3..9df27fa0d4 100644 --- a/src/aspire/classification/class2d.py +++ b/src/aspire/classification/class2d.py @@ -1,5 +1,5 @@ import logging -from abc import ABC +from abc import ABC, abstractmethod import numpy as np @@ -43,6 +43,7 @@ def __init__( self.n_classes = n_classes self.seed = seed + @abstractmethod def classify(self): """ Classify the images from Source into classes with similar viewing angles. @@ -50,6 +51,7 @@ def classify(self): Returns classes and associated metadata (classes, reflections, distances) """ + @abstractmethod def averages(self, classes, refl, distances): """ Returns class averages using prescribed `aligner`. From cd09537f52186ff9355bbb1f52612a990cf2f530 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 11 Feb 2022 13:00:53 -0500 Subject: [PATCH 22/40] update docstring --- src/aspire/classification/rir_class2d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aspire/classification/rir_class2d.py b/src/aspire/classification/rir_class2d.py index c1aeaaa6c7..a89ec994d3 100644 --- a/src/aspire/classification/rir_class2d.py +++ b/src/aspire/classification/rir_class2d.py @@ -60,7 +60,7 @@ def __init__( :param large_pca_implementation: See `pca`. :param nn_implementation: See `nn_classification`. :param bispectrum_implementation: See `bispectrum`. - :param aligner: An Align2D subclass. Defaults to BFRAlign2D. + :param aligner: An Align2D subclass. Defaults to BFSReddyChatterjiAlign2D. :param dtype: Optional dtype, otherwise taken from src. :param seed: Optional RNG seed to be passed to random methods, (example Random NN). :return: RIRClass2D instance to be used to compute bispectrum-like rotationally invariant 2D classification. From f3d2f9fedfd35d777bb54fcaa8822a93959e4fc0 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 11 Feb 2022 13:21:06 -0500 Subject: [PATCH 23/40] Remove UT for ABC --- tests/test_class2D.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/tests/test_class2D.py b/tests/test_class2D.py index a7aea3fd3c..69cd20a38b 100644 --- a/tests/test_class2D.py +++ b/tests/test_class2D.py @@ -7,7 +7,7 @@ from sklearn import datasets from aspire.basis import FFBBasis2D, FSPCABasis -from aspire.classification import BFSRAlign2D, Class2D, RIRClass2D +from aspire.classification import BFSRAlign2D, RIRClass2D from aspire.classification.legacy_implementations import bispec_2drot_large, pca_y from aspire.operators import ScalarFilter from aspire.source import Simulation @@ -139,14 +139,6 @@ def setUp(self): # Ceate another fspca_basis, use autogeneration FFB2D Basis self.noisy_fspca_basis = FSPCABasis(self.noisy_src) - def testClass2DBase(self): - """ - Make sure the base class doesn't crash when using arguments. - """ - _ = Class2D(self.clean_src) # Default dtype - _ = Class2D(self.clean_src, dtype=self.dtype) # Consistent dtype - _ = Class2D(self.clean_src, dtype=np.float16) # Different dtype - def testIncorrectComponents(self): """ Check we raise with inconsistent configuration of FSPCA components. From d534847a5d7663efc518b83f0670daa2bb899632 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 11 Feb 2022 13:50:04 -0500 Subject: [PATCH 24/40] unused import cleanup --- gallery/experiments/simulated_abinitio_pipeline.py | 2 +- src/aspire/classification/align2d.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/gallery/experiments/simulated_abinitio_pipeline.py b/gallery/experiments/simulated_abinitio_pipeline.py index 8b20c53a90..514640af49 100644 --- a/gallery/experiments/simulated_abinitio_pipeline.py +++ b/gallery/experiments/simulated_abinitio_pipeline.py @@ -27,7 +27,7 @@ from aspire.noise import AnisotropicNoiseEstimator from aspire.operators import FunctionFilter, RadialCTFFilter from aspire.reconstruction import MeanEstimator -from aspire.source import ArrayImageSource, Simulation +from aspire.source import Simulation from aspire.utils.coor_trans import ( get_aligned_rotations, get_rots_mse, diff --git a/src/aspire/classification/align2d.py b/src/aspire/classification/align2d.py index fb3fa01bb4..18f2b070b1 100644 --- a/src/aspire/classification/align2d.py +++ b/src/aspire/classification/align2d.py @@ -5,7 +5,6 @@ import matplotlib.pyplot as plt import numpy as np from skimage.filters import difference_of_gaussians, window - from skimage.transform import rotate, warp_polar from tqdm import tqdm, trange From 90322c8b2b9dc11058194889d6238c9d72550b85 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Mon, 14 Feb 2022 08:30:23 -0500 Subject: [PATCH 25/40] use basis expand method for image coef --- src/aspire/basis/basis.py | 8 +++++++- src/aspire/classification/align2d.py | 4 ++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/aspire/basis/basis.py b/src/aspire/basis/basis.py index ddafca620e..314a4e9ffe 100644 --- a/src/aspire/basis/basis.py +++ b/src/aspire/basis/basis.py @@ -4,6 +4,8 @@ from scipy.sparse.linalg import LinearOperator, cg from aspire.basis.basis_utils import num_besselj_zeros +from aspire.image import Image +from aspire.volume import Volume from aspire.utils import ensure, mdim_mat_fun_conj from aspire.utils.matlab_compat import m_reshape @@ -174,6 +176,10 @@ def expand(self, x): those first dimensions of `x`. """ + + if isinstance(x, Image) or isinstance(x,Volume): + x = x.asnumpy() + # ensure the first dimensions with size of self.sz sz_roll = x.shape[: -self.ndim] @@ -206,5 +212,5 @@ def expand(self, x): raise RuntimeError("Unable to converge!") # return v coefficients with the last dimension of self.count - v = v.reshape((-1, *sz_roll)) + v = v.reshape((*sz_roll, -1)) return v diff --git a/src/aspire/classification/align2d.py b/src/aspire/classification/align2d.py index 18f2b070b1..a324787fb3 100644 --- a/src/aspire/classification/align2d.py +++ b/src/aspire/classification/align2d.py @@ -159,7 +159,7 @@ def average( if shifts is not None: neighbors_imgs.shift(shifts[i]) - neighbors_coefs = self.composite_basis.evaluate_t(neighbors_imgs) + neighbors_coefs = self.composite_basis.expand(neighbors_imgs) else: # Get the neighbors neighbors_ids = classes[i] @@ -699,7 +699,7 @@ def average( if coefs is None: # Retrieve relavent images directly from source. neighbors_imgs = Image(self._cls_images(classes[i])) - neighbors_coefs = self.composite_basis.evaluate_t(neighbors_imgs) + neighbors_coefs = self.composite_basis.expand(neighbors_imgs) else: # Get the neighbors neighbors_ids = classes[i] From 28edbf99b6085b6cfd5e06f50123ca2d8b7dd115 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Mon, 14 Feb 2022 08:31:50 -0500 Subject: [PATCH 26/40] remove output_nn_filename --- src/aspire/classification/rir_class2d.py | 57 ------------------------ 1 file changed, 57 deletions(-) diff --git a/src/aspire/classification/rir_class2d.py b/src/aspire/classification/rir_class2d.py index a89ec994d3..13ba624eab 100644 --- a/src/aspire/classification/rir_class2d.py +++ b/src/aspire/classification/rir_class2d.py @@ -28,7 +28,6 @@ def __init__( bispectrum_freq_cutoff=None, large_pca_implementation="legacy", nn_implementation="legacy", - output_nn_filename=None, bispectrum_implementation="legacy", aligner=None, dtype=None, @@ -120,7 +119,6 @@ def __init__( f"Provided nn_implementation={nn_implementation} not in {nn_implementations.keys()}" ) self._nn_classification = nn_implementations[nn_implementation] - self.output_nn_filename = output_nn_filename # # Do we have a sane Large Dataset PCA large_pca_implementations = { @@ -187,8 +185,6 @@ def classify(self, diagnostics=False): # # Stage 2: Compute Nearest Neighbors logger.info("Calculate Nearest Neighbors") classes, reflections, distances = self.nn_classification(coef_b, coef_b_r) - if self.output_nn_filename is not None: - self._save_nn(classes, reflections, distances) if diagnostics: # Lets peek at the distribution of distances @@ -369,59 +365,6 @@ def _legacy_nn_classification(self, coeff_b, coeff_b_r, batch_size=2000): return classes, refl, distances - def _save_nn(self, classes, reflections, distances): - """ - Output the Nearest Neighbors graph as a weighted adjacency list. - - Vertices are indexed by their natural index in `source`. - Note reflected images are represented by `index + src.n`. - - Only the output of the Nearest Neighbor call is saved. - If you want a complete graph, specify 2*src.n neighbors, - that is all images and their reflections. - - Because this is mixed datatypes (int and floating), - this will be output as a space delimited text file. - - Vi1 Vj1 W_i1_j1 Vj2 Wi1_j2 ... - Vi2 Vj1 W_i2_j1 Vj2 Wi2_j2 ... - ... - - """ - - # Construct the weighted adjacency list - AdjList = [] - for k in range(len(classes)): - - row = [] - vik = classes[k][0] - row.append(vik) - - for j in range(1, len(classes[k])): - - # Neighbor index - vj = classes[k][j] - if reflections[k][j]: - vj += self.src.n - row.append(vj) - - # Neighbor Weight (distance) - wt = distances[k][j] - row.append(wt) - - # Store this row of the AdjList - AdjList.append(row) - - logger.info( - "Writing Nearest Neighbors as Weighted Adjacency List" - f" to {self.output_nn_filename}" - ) - - # Output - with open(self.output_nn_filename, "w") as fh: - for row in AdjList: - fh.write(" ".join(str(x) for x in row) + "\n") - def _legacy_pca(self, M): """ This is more or less the historic implementation ported From 6ce830295b59e85ab364b27d3c9dfa1dbbd50e36 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Mon, 14 Feb 2022 09:34:13 -0500 Subject: [PATCH 27/40] change _reddy ... util method args from in place to slices --- src/aspire/basis/basis.py | 4 +- src/aspire/classification/align2d.py | 65 ++++++++++++++++++---------- 2 files changed, 44 insertions(+), 25 deletions(-) diff --git a/src/aspire/basis/basis.py b/src/aspire/basis/basis.py index 314a4e9ffe..0b2fe01f5d 100644 --- a/src/aspire/basis/basis.py +++ b/src/aspire/basis/basis.py @@ -5,9 +5,9 @@ from aspire.basis.basis_utils import num_besselj_zeros from aspire.image import Image -from aspire.volume import Volume from aspire.utils import ensure, mdim_mat_fun_conj from aspire.utils.matlab_compat import m_reshape +from aspire.volume import Volume logger = logging.getLogger(__name__) @@ -177,7 +177,7 @@ def expand(self, x): """ - if isinstance(x, Image) or isinstance(x,Volume): + if isinstance(x, Image) or isinstance(x, Volume): x = x.asnumpy() # ensure the first dimensions with size of self.sz diff --git a/src/aspire/classification/align2d.py b/src/aspire/classification/align2d.py index a324787fb3..f3fca3e2c1 100644 --- a/src/aspire/classification/align2d.py +++ b/src/aspire/classification/align2d.py @@ -503,23 +503,38 @@ def _align(self, classes, reflections, basis_coefficients): # # Get the array of images for this class, using the `alignment_src`. images = self._cls_images(classes[k], src=self.alignment_src) - self._reddychatterji( - k, images, classes, reflections, rotations, correlations, shifts + rotations[k], correlations[k], shifts[k] = self._reddychatterji( + images, classes[k], reflections[k] ) return classes, reflections, rotations, shifts, correlations - def _reddychatterji( - self, k, images, classes, reflections, rotations, correlations, shifts - ): + def _reddychatterji(self, images, class_k, reflection_k): """ Compute the Reddy Chatterji method registering images[1:] to image[0]. This differs from papers and published scikit implimentations by computing the fixed base image[0] pipeline once then reusing. + + This is a util function to help loop over `classes`. + + :param images: Image data + :param class_k: Image indices + :param reflection_k: Image reflections + :returns: (rotations_k, correlations_k, shifts_k) corresponding to `images` """ - # De-Mean + # Result arrays + M = len(images) + rotations_k = np.empty(M, dtype=self.dtype) + correlations_k = np.empty(M, dtype=self.dtype) + shifts_k = np.empty((M, 2), dtype=self.dtype) + # Initialize for Image 0, others will populate in loop. + rotations_k[0] = 0 + correlations_k[0] = 0 + shifts_k[0] = 0 + + # De-Mean, note images is mutated and should be a `copy`. images -= images.mean(axis=(-1, -2))[:, np.newaxis, np.newaxis] # Precompute fixed_img data used repeatedly in the loop below. @@ -552,7 +567,7 @@ def _reddychatterji( regis_img = images[m] # Reflect images when necessary - if reflections[k][m]: + if reflection_k[m]: regis_img = np.flipud(regis_img) # Difference of Gaussians (Band Filter) @@ -562,14 +577,14 @@ def _reddychatterji( wregis_img = regis_img_dog * window("hann", regis_img.shape) self._input_images_diagnostic( - classes[k][0], wfixed_img, classes[k][m], wregis_img + class_k[0], wfixed_img, class_k[m], wregis_img ) # Transform image to Fourier space regis_img_fs = np.abs(fft.fftshift(fft.fft2(wregis_img))) ** 2 self._windowed_psd_diagnostic( - classes[k][0], fixed_img_fs, classes[k][m], regis_img_fs + class_k[0], fixed_img_fs, class_k[m], regis_img_fs ) # Compute Log Polar Transform @@ -581,7 +596,7 @@ def _reddychatterji( ) self._log_polar_diagnostic( - classes[k][0], warped_fixed_img_fs, classes[k][m], warped_regis_img_fs + class_k[0], warped_fixed_img_fs, class_k[m], warped_regis_img_fs ) # Only use half of FFT, because it's symmetrical @@ -624,16 +639,16 @@ def _reddychatterji( r += 180 self._rotated_diagnostic( - classes[k][0], + class_k[0], fixed_img, - classes[k][m], + class_k[m], regis_img_estimated, - reflections[k][m], + reflection_k[m], r, ) # Assign estimated rotations results - rotations[k][m] = -r * np.pi / 180 # Reverse rot and convert to radians + rotations_k[m] = -r * np.pi / 180 # Reverse rot and convert to radians if self.do_cross_corr_translations: # Prepare for searching over translations using cross-correlation with the rotated image. @@ -650,14 +665,14 @@ def _reddychatterji( regis_img_estimated = np.roll(regis_img_estimated, shift_y, axis=0) regis_img_estimated = np.roll(regis_img_estimated, shift_x, axis=1) # Assign estimated shift to results - shifts[k][m] = shift[::-1].astype(int) + shifts_k[m] = shift[::-1].astype(int) self._averaged_diagnostic( - classes[k][0], + class_k[0], fixed_img, - classes[k][m], + class_k[m], regis_img_estimated, - reflections[k][m], + reflection_k[m], r, ) else: @@ -665,18 +680,20 @@ def _reddychatterji( # Estimated `corr` metric corr = np.dot(fixed_img.flatten(), regis_img_estimated.flatten()) - correlations[k][m] = corr + correlations_k[m] = corr logger.debug( - f"Class {k}, ref {classes[k][0]}, Neighbor {m} Index {classes[k][m]}" + f"ref {class_k[0]}, Neighbor {m} Index {class_k[m]}" f" Estimates: {r}*, Shift: {shift}," - f" Corr: {corr}, Refl?: {reflections[k][m]}" + f" Corr: {corr}, Refl?: {reflection_k[m]}" ) # Cleanup some cached stuff for this class self.__cache.pop(id(warped_fixed_img_fs), None) self.__cache.pop(id(twfixed_img), None) + return rotations_k, correlations_k, shifts_k + def average( self, classes, @@ -944,12 +961,14 @@ def _align(self, classes, reflections, basis_coefficients): s = np.array([xs, ys]) # Get the array of images for this class + # Note we mutate `images` here with shifting, + # then later in `_reddychatterji` images = unshifted_images.copy() # Don't shift the base image images[1:] = Image(unshifted_images[1:]).shift(s).asnumpy() - self._reddychatterji( - k, images, classes, reflections, _rotations, _correlations, _shifts + rotations[k], correlations[k], shifts[k] = self._reddychatterji( + images, classes[k], reflections[k] ) # Where corr has improved From fa91a30bdc30ea72b2a5e56105a1457ad075f908 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Mon, 14 Feb 2022 12:46:47 -0500 Subject: [PATCH 28/40] Refactor aligner~>averager --- .../experimental_abinitio_pipeline.py | 15 +- .../simulated_abinitio_pipeline.py | 11 +- gallery/tutorials/class_averaging.py | 4 +- src/aspire/classification/__init__.py | 18 +- .../{align2d.py => averager2d.py} | 155 +++++++++--------- src/aspire/classification/rir_class2d.py | 26 +-- tests/test_align2d.py | 64 ++++---- tests/test_class2D.py | 5 +- 8 files changed, 142 insertions(+), 156 deletions(-) rename src/aspire/classification/{align2d.py => averager2d.py} (92%) diff --git a/gallery/experiments/experimental_abinitio_pipeline.py b/gallery/experiments/experimental_abinitio_pipeline.py index e137b4546c..7e57fdbef8 100644 --- a/gallery/experiments/experimental_abinitio_pipeline.py +++ b/gallery/experiments/experimental_abinitio_pipeline.py @@ -29,7 +29,7 @@ from aspire.abinitio import CLSyncVoting from aspire.basis import FFBBasis2D, FFBBasis3D -from aspire.classification import BFSReddyChatterjiAlign2D, RIRClass2D +from aspire.classification import BFSReddyChatterjiAverager2D, RIRClass2D from aspire.denoising import DenoiserCov2D from aspire.noise import AnisotropicNoiseEstimator from aspire.reconstruction import MeanEstimator @@ -99,17 +99,17 @@ # ----------------------- # # Optionally generate an alternative source that is denoised with `cov2d`, -# then configure a customized aligner. This allows the use of CWF denoised +# then configure a customized averager. This allows the use of CWF denoised # images for classification, but stacks the original images for averages # used in the remainder of the reconstruction pipeline. # # In this example, this behavior is controlled by the `do_cov2d` boolean variable. -# When disabled, the original src and default aligner is used. +# When disabled, the original src and default averager is used. # If you will not be using cov2d, # you may remove this code block and associated variables. classification_src = src -custom_aligner = None +custom_averager = None if do_cov2d: # Use CWF denoising cwf_denoiser = DenoiserCov2D(src) @@ -121,7 +121,7 @@ # Use regular `src` for the alignment and composition (averaging). composite_basis = FFBBasis2D((src.L,) * 2, dtype=src.dtype) - custom_aligner = BFSReddyChatterjiAlign2D( + custom_averager = BFSReddyChatterjiAverager2D( None, src, composite_basis, dtype=src.dtype ) @@ -143,12 +143,11 @@ large_pca_implementation="legacy", nn_implementation="sklearn", bispectrum_implementation="legacy", - aligner=custom_aligner, + averager=custom_averager, ) classes, reflections, distances = rir.classify() -# Only care about the averages returned right now (index 0) -avgs = rir.averages(classes, reflections, distances)[0] +avgs = rir.averages(classes, reflections, distances) if interactive: avgs.images(0, 10).show() diff --git a/gallery/experiments/simulated_abinitio_pipeline.py b/gallery/experiments/simulated_abinitio_pipeline.py index 514640af49..a93807687b 100644 --- a/gallery/experiments/simulated_abinitio_pipeline.py +++ b/gallery/experiments/simulated_abinitio_pipeline.py @@ -22,7 +22,7 @@ from aspire.abinitio import CLSyncVoting from aspire.basis import FFBBasis2D, FFBBasis3D -from aspire.classification import BFSReddyChatterjiAlign2D, RIRClass2D +from aspire.classification import BFSReddyChatterjiAverager2D, RIRClass2D from aspire.denoising import DenoiserCov2D from aspire.noise import AnisotropicNoiseEstimator from aspire.operators import FunctionFilter, RadialCTFFilter @@ -146,7 +146,7 @@ def noise_function(x, y): # you may remove this code block and associated variables. classification_src = src -custom_aligner = None +custom_averager = None if do_cov2d: # Use CWF denoising cwf_denoiser = DenoiserCov2D(src) @@ -158,7 +158,7 @@ def noise_function(x, y): # Use regular `src` for the alignment and composition (averaging). composite_basis = FFBBasis2D((src.L,) * 2, dtype=src.dtype) - custom_aligner = BFSReddyChatterjiAlign2D( + custom_averager = BFSReddyChatterjiAverager2D( None, src, composite_basis, dtype=src.dtype ) @@ -180,12 +180,11 @@ def noise_function(x, y): large_pca_implementation="legacy", nn_implementation="sklearn", bispectrum_implementation="legacy", - aligner=custom_aligner, + averager=custom_averager, ) classes, reflections, distances = rir.classify() -# Only care about the averages returned right now (index 0) -avgs = rir.averages(classes, reflections, distances)[0] +avgs = rir.averages(classes, reflections, distances) if interactive: avgs.images(0, 10).show() diff --git a/gallery/tutorials/class_averaging.py b/gallery/tutorials/class_averaging.py index 7d973e33c8..5bebccbe5d 100644 --- a/gallery/tutorials/class_averaging.py +++ b/gallery/tutorials/class_averaging.py @@ -117,9 +117,7 @@ ) classes, reflections, dists = rir.classify() -avgs, classes, reflections, rotations, shifts, corrs = rir.averages( - classes, reflections, dists -) +avgs = rir.averages(classes, reflections, dists) # %% # Display Classes diff --git a/src/aspire/classification/__init__.py b/src/aspire/classification/__init__.py index 21fbde28e4..b1aea6ea5a 100644 --- a/src/aspire/classification/__init__.py +++ b/src/aspire/classification/__init__.py @@ -1,12 +1,12 @@ -from .align2d import ( - Align2D, - AveragedAlign2D, - BFRAlign2D, - BFSRAlign2D, - BFSReddyChatterjiAlign2D, - EMAlign2D, - FTKAlign2D, - ReddyChatterjiAlign2D, +from .averager2d import ( + AligningAverager2D, + Averager2D, + BFRAverager2D, + BFSRAverager2D, + BFSReddyChatterjiAverager2D, + EMAverager2D, + FTKAverager2D, + ReddyChatterjiAverager2D, ) from .class2d import Class2D from .rir_class2d import RIRClass2D diff --git a/src/aspire/classification/align2d.py b/src/aspire/classification/averager2d.py similarity index 92% rename from src/aspire/classification/align2d.py rename to src/aspire/classification/averager2d.py index f3fca3e2c1..4f6ea862cd 100644 --- a/src/aspire/classification/align2d.py +++ b/src/aspire/classification/averager2d.py @@ -16,9 +16,9 @@ logger = logging.getLogger(__name__) -class Align2D(ABC): +class Averager2D(ABC): """ - Base class for 2D Image Alignment methods. + Base class for 2D Image Averaging methods. """ def __init__( @@ -58,35 +58,24 @@ def __init__( ) @abstractmethod - def align(self, classes, reflections, basis_coefficients): + def average( + self, + classes, + reflections, + coefs=None, + ): """ - Any align2D alignment method should take in the below arguments - and return aligned images. - - During this process `rotations`, `reflections`, `shifts` and - `correlations` properties will be computed for aligners - that implement them. Some future aligners (example. EM based) - may not produce these intermediates. - - `rotations` is an (n_classes, n_nbor) array of angles, - which should represent the rotations needed to align images within - that class. `rotations` is measured in Radians. - - `correlations` is an (n_classes, n_nbor) array representing - a correlation like measure between classified images and their base - image (image index 0). - - `shifts` is None or an (n_classes, n_nbor) array of 2D shifts - which should represent the translation needed to best align the images - within that class. + Combines images using stacking in `self.composite_basis`. - Subclasses of `align` should extend this method with optional arguments. + Subclasses should implement this. + (Example EM algos use radically different averaging). - :param classes: (n_classes, n_nbor) integer array of img indices - :param reflections: (n_classes, n_nbor) bool array of corresponding reflections - :param basis_coefficients: (n_img, self.pca_basis.count) compressed basis coefficients + Should return an Image source of synthetic class averages. - :returns: Image instance (stack of images) + :param classes: class indices (refering to src). (n_img, n_nbor) + :param reflections: Bool representing whether to reflect image in `classes` + :coefs: Optional Fourier bessel coefs (avoids recomputing). + :return: Stack of Synthetic Class Average images as Image instance. """ def _cls_images(self, cls, src=None): @@ -111,39 +100,50 @@ def _cls_images(self, cls, src=None): return images -class AveragedAlign2D(Align2D): +class AligningAverager2D(Averager2D): """ - Subclass supporting aligners which perform averaging during output. + Subclass supporting averagers which perfom an aligning stage. """ + @abstractmethod def align(self, classes, reflections, basis_coefficients): """ - See Align2D.align + During this process `rotations`, `reflections`, `shifts` and + `correlations` properties will be computed for aligners. + + `rotations` is an (n_classes, n_nbor) array of angles, + which should represent the rotations needed to align images within + that class. `rotations` is measured in Radians. + + `correlations` is an (n_classes, n_nbor) array representing + a correlation like measure between classified images and their base + image (image index 0). + + `shifts` is None or an (n_classes, n_nbor) array of 2D shifts + which should represent the translation needed to best align the images + within that class. + + Subclasses of should implement and extend this method. + + :param classes: (n_classes, n_nbor) integer array of img indices + :param reflections: (n_classes, n_nbor) bool array of corresponding reflections + :param basis_coefficients: (n_img, self.pca_basis.count) compressed basis coefficients + + :returns: (reflections, rotations, shifts) """ - # Correlations are currently unused, but left for future extensions. - cls, ref, rot, shf, corrs = self._align( - classes, reflections, basis_coefficients - ) - return self.average(cls, ref, rot, shf), cls, ref, rot, shf, corrs def average( self, classes, reflections, - rotations, - shifts=None, coefs=None, ): """ - Combines images using averaging in `self.composite_basis`. - - :param classes: class indices (refering to src). (n_img, n_nbor) - :param reflections: Bool representing whether to reflect image in `classes` - :param rotations: Array of in-plane rotation angles (Radians) of image in `classes` - :param shifts: Optional array of shifts for image in `classes`. - :coefs: Optional Fourier bessel coefs (avoids recomputing). - :return: Stack of Synthetic Class Average images as Image instance. + This subclass assumes we get alignment details from `align` method. Otherwise. see Averager2D.average """ + + rotations, shifts, _ = self.align(classes, reflections, coefs) + n_classes, n_nbor = classes.shape b_avgs = np.empty((n_classes, self.composite_basis.count), dtype=self.src.dtype) @@ -181,7 +181,7 @@ def average( return ArrayImageSource(self.composite_basis.evaluate(b_avgs)) -class BFRAlign2D(AveragedAlign2D): +class BFRAverager2D(AligningAverager2D): """ This perfoms a Brute Force Rotational alignment. @@ -210,10 +210,10 @@ def __init__( if not hasattr(self.alignment_basis, "rotate"): raise RuntimeError( - f"BFRAlign2D's alignment_basis {self.alignment_basis} must provide a `rotate` method." + f"BFRAverager2D's alignment_basis {self.alignment_basis} must provide a `rotate` method." ) - def _align(self, classes, reflections, basis_coefficients): + def align(self, classes, reflections, basis_coefficients): """ Performs the actual rotational alignment estimation, returning parameters needed for averaging. @@ -258,10 +258,10 @@ def _align(self, classes, reflections, basis_coefficients): for j in range(n_nbor): correlations[k, j] = results[j, angle_idx[j]] - return classes, reflections, rotations, None, correlations + return rotations, None, correlations -class BFSRAlign2D(BFRAlign2D): +class BFSRAverager2D(BFRAverager2D): """ This perfoms a Brute Force Shift and Rotational alignment. It is potentially expensive to brute force this search space. @@ -289,7 +289,7 @@ def __init__( Example: n_x_shifts=1, n_y_shifts=0 would test {-1,0,1} X {0}. - n_x_shifts=n_y_shifts=0 is the same as calling BFRAlign2D. + n_x_shifts=n_y_shifts=0 is the same as calling BFRAverager2D. :params alignment_basis: Basis providing a `shift` and `rotate` method. :params n_angles: Number of brute force rotations to attempt, defaults 359. @@ -310,15 +310,15 @@ def __init__( if not hasattr(self.alignment_basis, "shift"): raise RuntimeError( - f"BFSRAlign2D's alignment_basis {self.alignment_basis} must provide a `shift` method." + f"BFSRAverager2D's alignment_basis {self.alignment_basis} must provide a `shift` method." ) - # Each shift will require calling the parent BFRAlign2D._align - self._bfr_align = super()._align + # Each shift will require calling the parent BFRAverager2D.align + self._bfr_align = super().align - def _align(self, classes, reflections, basis_coefficients): + def align(self, classes, reflections, basis_coefficients): """ - See `Align2D.align` + See `AligningAverager2D.align` """ # Admit simple case of single case alignment @@ -360,7 +360,7 @@ def _align(self, classes, reflections, basis_coefficients): original_coef, -shift ) - _, _, _rotations, _, _correlations = self._bfr_align( + _rotations, _, _correlations = self._bfr_align( classes, reflections, basis_coefficients ) @@ -384,10 +384,10 @@ def _align(self, classes, reflections, basis_coefficients): f"Shift ({x},{y}) complete. Improved {np.sum(improved_indices)} alignments." ) - return classes, reflections, rotations, shifts, correlations + return rotations, shifts, correlations -class ReddyChatterjiAlign2D(AveragedAlign2D): +class ReddyChatterjiAverager2D(AligningAverager2D): """ Attempts rotational estimation using Reddy Chatterji log polar Fourier cross correlation. Then attempts shift (translational) estimation using cross correlation. @@ -417,7 +417,7 @@ def __init__( ): """ :param alignment_basis: Basis to be used during alignment. - For current implementation of ReddyChatterjiAlign2D this should be `None`. + For current implementation of ReddyChatterjiAverager2D this should be `None`. Instead see `alignment_source`. :param source: Source of original images. :param composite_basis: Basis to be used during class average composition. @@ -482,7 +482,7 @@ def _phase_cross_correlation(self, img0, img1): return np.abs(cross_correlation), shifts - def _align(self, classes, reflections, basis_coefficients): + def align(self, classes, reflections, basis_coefficients): """ Performs the actual rotational alignment estimation, returning parameters needed for averaging. @@ -507,7 +507,7 @@ def _align(self, classes, reflections, basis_coefficients): images, classes[k], reflections[k] ) - return classes, reflections, rotations, shifts, correlations + return rotations, shifts, correlations def _reddychatterji(self, images, class_k, reflection_k): """ @@ -698,14 +698,15 @@ def average( self, classes, reflections, - rotations, - shifts=None, coefs=None, ): """ This averages classes performing rotations then shifts. - Otherwise is similar to `AveragedAlign2D.average`. + Otherwise is similar to `AligningAverager2D.average`. """ + + rotations, shifts, _ = self.align(classes, reflections, coefs) + n_classes, n_nbor = classes.shape b_avgs = np.empty((n_classes, self.composite_basis.count), dtype=self.src.dtype) @@ -870,7 +871,7 @@ def _averaged_diagnostic(self, ia, a, ib, b, sb, rb): plt.show() -class BFSReddyChatterjiAlign2D(ReddyChatterjiAlign2D): +class BFSReddyChatterjiAverager2D(ReddyChatterjiAverager2D): """ Brute Force Shifts (Translations) - ReddyChatterji (Log-Polar) Rotations @@ -899,7 +900,7 @@ def __init__( ): """ :param alignment_basis: Basis to be used during alignment. - For current implementation of ReddyChatterjiAlign2D this should be `None`. + For current implementation of ReddyChatterjiAverager2D this should be `None`. Instead see `alignment_source`. :param source: Source of original images. :param composite_basis: Basis to be used during class average composition. @@ -928,7 +929,7 @@ def __init__( # Assign search radius self.radius = radius or source.L // 8 - def _align(self, classes, reflections, basis_coefficients): + def align(self, classes, reflections, basis_coefficients): """ Performs the actual rotational alignment estimation, returning parameters needed for averaging. @@ -979,34 +980,30 @@ def _align(self, classes, reflections, basis_coefficients): shifts = np.where(shifts, _shifts, shifts) logger.debug(f"Shift {s} has improved {np.sum(improved)} results") - return classes, reflections, rotations, shifts, correlations + return rotations, shifts, correlations def average( self, classes, reflections, - rotations, - shifts=None, coefs=None, ): """ - See AveragedAlign2D.average. + See Averager2D.average. """ - # ReddyChatterjiAlign2D does rotations then shifts. + # ReddyChatterjiAverager2D does rotations then shifts. # For brute force, we'd like shifts then rotations, - # as is done in gerneral via AveragedAlign2D. - return AveragedAlign2D.average( - self, classes, reflections, rotations, shifts, coefs - ) + # as is done in general in AligningAverager2D + return Averager2D.average(self, classes, reflections, coefs) -class EMAlign2D(Align2D): +class EMAverager2D(Averager2D): """ Citation needed. """ -class FTKAlign2D(Align2D): +class FTKAverager2D(Averager2D): """ Factorization of the translation kernel for fast rigid image alignment. Rangan, A.V., Spivak, M., Anden, J., & Barnett, A.H. (2019). diff --git a/src/aspire/classification/rir_class2d.py b/src/aspire/classification/rir_class2d.py index 13ba624eab..7772dd9f43 100644 --- a/src/aspire/classification/rir_class2d.py +++ b/src/aspire/classification/rir_class2d.py @@ -6,7 +6,7 @@ from tqdm import tqdm from aspire.basis import FSPCABasis -from aspire.classification import BFSReddyChatterjiAlign2D, Class2D +from aspire.classification import BFSReddyChatterjiAverager2D, Class2D from aspire.classification.legacy_implementations import bispec_2drot_large, pca_y from aspire.numeric import ComplexPCA from aspire.utils.random import rand @@ -29,7 +29,7 @@ def __init__( large_pca_implementation="legacy", nn_implementation="legacy", bispectrum_implementation="legacy", - aligner=None, + averager=None, dtype=None, seed=None, ): @@ -47,7 +47,7 @@ def __init__( for Viewing Direction Classification in Cryo-EM. (2014) :param src: Source instance. Note it is possible to use one `source` for classification (ie CWF), - and a different `source` for stacking in the `aligner`. + and a different `source` for stacking in the `averager`. :param pca_basis: Optional FSPCA Basis instance :param fspca_components: Components (top eigvals) to keep from full FSCPA, default truncates to 400. :param alpha: Amplitude Power Scale, default 1/3 (eq 20 from RIIR paper). @@ -59,7 +59,7 @@ def __init__( :param large_pca_implementation: See `pca`. :param nn_implementation: See `nn_classification`. :param bispectrum_implementation: See `bispectrum`. - :param aligner: An Align2D subclass. Defaults to BFSReddyChatterjiAlign2D. + :param averager: An Averager2D subclass. Defaults to BFSReddyChatterjiAverager2D. :param dtype: Optional dtype, otherwise taken from src. :param seed: Optional RNG seed to be passed to random methods, (example Random NN). :return: RIRClass2D instance to be used to compute bispectrum-like rotationally invariant 2D classification. @@ -100,7 +100,7 @@ def __init__( self.alpha = alpha self.bispectrum_components = bispectrum_components self.bispectrum_freq_cutoff = bispectrum_freq_cutoff - self.aligner = aligner + self.averager = averager if self.src.n < self.bispectrum_components: raise RuntimeError( @@ -169,10 +169,10 @@ def classify(self, diagnostics=False): # For convenience, assign the fb_basis used in the pca_basis. self.fb_basis = self.pca_basis.basis - # When not provided by a user, the aligner is instantiated after + # When not provided by a user, the averager is instantiated after # we are certain our pca_basis has been constructed. - if self.aligner is None: - self.aligner = BFSReddyChatterjiAlign2D( + if self.averager is None: + self.averager = BFSReddyChatterjiAverager2D( None, self.src, self.fb_basis, dtype=self.dtype ) @@ -208,16 +208,16 @@ def averages(self, classes, reflections, distances): logger.info(f"Select {self.n_classes} Classes from Nearest Neighbors") classes, reflections = self.select_classes(classes, reflections) - # # Stage 4: Align + # # Stage 4: Averager logger.info( - f"Begin Rotational Alignment of {classes.shape[0]} Classes using {self.aligner}." + f"Begin Averaging of {classes.shape[0]} Classes using {self.averager}." ) - return self.aligner.align(classes, reflections, self.fspca_coef) + return self.averager.average(classes, reflections, self.fspca_coef) def select_classes(self, classes, reflections): """ - Select the `n_classes` to align from the (n_images) population of classes. + Select the `n_classes` to average from the (n_images) population of classes. """ # Generate indices for random sample (can do something smarter, or build this out later). # For testing/poc just take the first n_classes so it matches earlier plots for manual comparison @@ -315,7 +315,7 @@ def _sk_nn_classification(self, coeff_b, coeff_b_r): def _legacy_nn_classification(self, coeff_b, coeff_b_r, batch_size=2000): """ - Perform nearest neighbor classification and alignment. + Perform nearest neighbor classification. """ # Note kept ordering from legacy code (n_features, n_img) diff --git a/tests/test_align2d.py b/tests/test_align2d.py index 612e76a381..eb356e69b1 100644 --- a/tests/test_align2d.py +++ b/tests/test_align2d.py @@ -6,7 +6,7 @@ import pytest from aspire.basis import DiracBasis, FFBBasis2D -from aspire.classification import AveragedAlign2D, BFRAlign2D, BFSRAlign2D +from aspire.classification import Averager2D, BFRAverager2D, BFSRAverager2D from aspire.source import Simulation from aspire.utils import Rotation from aspire.volume import Volume @@ -19,9 +19,9 @@ # Ignore Gimbal lock warning for our in plane rotations. @pytest.mark.filterwarnings("ignore:Gimbal lock detected") -class Align2DTestCase(TestCase): - # Subclasses should override `aligner` with a different class. - aligner = AveragedAlign2D +class Averager2DTestCase(TestCase): + # Subclasses should override `averager` with a different class. + averager = Averager2D def setUp(self): @@ -33,7 +33,7 @@ def setUp(self): self.n_img = 3 self.dtype = np.float64 - # Create a Basis to use in alignment. + # Create a Basis to use in averager. self.basis = FFBBasis2D((self.resolution, self.resolution), dtype=self.dtype) # This sets up a trivial class, where there is one group having all images. @@ -49,20 +49,23 @@ def tearDown(self): pass def _getSrc(self): - # Base Align2D does not require anything from source. + # Base Averager2D does not require anything from source. # Subclasses implement specific src return None def testTypeMismatch(self): - # Intentionally mismatch Basis and Aligner dtypes + # Work around ABC, which won't let us test the unimplemented base case. + self.averager.__abstractmethods__ = set() + + # Intentionally mismatch Basis and Averager dtypes if self.dtype == np.float32: test_dtype = np.float64 else: test_dtype = np.float32 with self._caplog.at_level(logging.WARN): - self.aligner(self.basis, self._getSrc(), dtype=test_dtype) + self.averager(self.basis, self._getSrc(), dtype=test_dtype) assert "does not match dtype" in self._caplog.text def _construct_rotations(self): @@ -97,9 +100,9 @@ def r(theta): @pytest.mark.filterwarnings("ignore:Gimbal lock detected") -class BFRAlign2DTestCase(Align2DTestCase): +class BFRAverager2DTestCase(Averager2DTestCase): - aligner = BFRAlign2D + averager = BFRAverager2D def setUp(self): @@ -141,23 +144,19 @@ def testNoRot(self): # and that should raise an error during instantiation. with pytest.raises(RuntimeError, match=r".* must provide a `rotate` method."): - _ = self.aligner(basis, self._getSrc()) + _ = self.averager(basis, self._getSrc()) - def testAlign(self): + def testAverager(self): """ Construct a stack of images with known rotations. - Rotationally align the stack and compare output with known rotations. + Rotationally averager the stack and compare output with known rotations. """ - # Construction the Aligner and then call the main `align` method - algnr = self.aligner(self.basis, self._getSrc(), n_angles=self.n_search_angles) - _, _classes, _reflections, _rotations, _shifts, _ = algnr.align( - self.classes, self.reflections, self.coefs - ) + # Construction the Averager and then call the `align` method + avgr = self.averager(self.basis, self._getSrc(), n_angles=self.n_search_angles) + _rotations, _shifts, _ = avgr.align(self.classes, self.reflections, self.coefs) - self.assertTrue(np.all(_classes == self.classes)) - self.assertTrue(np.all(_reflections == self.reflections)) self.assertIsNone(_shifts) # Crude check that we are closer to known angle than the next rotation @@ -170,20 +169,20 @@ def testAlign(self): @pytest.mark.filterwarnings("ignore:Gimbal lock detected") -class BFSRAlign2DTestCase(BFRAlign2DTestCase): +class BFSRAverager2DTestCase(BFRAverager2DTestCase): - aligner = BFSRAlign2D + averager = BFSRAverager2D def setUp(self): # Inherit basic params from the base class - super(BFRAlign2DTestCase, self).setUp() + super(BFRAverager2DTestCase, self).setUp() # Setup shifts, don't shift the base image self.shifts = np.zeros((self.n_img, 2)) self.shifts[1:, 0] = 2 self.shifts[1:, 1] = 4 - # Execute the remaining setup from BFRAlign2DTestCase + # Execute the remaining setup from BFRAverager2DTestCase super().setUp() def testNoShift(self): @@ -200,29 +199,24 @@ def testNoShift(self): # and that should raise an error during instantiation. with pytest.raises(RuntimeError, match=r".* must provide a `shift` method."): - _ = self.aligner(basis, self._getSrc()) + _ = self.averager(basis, self._getSrc()) - def testAlign(self): + def testAverager(self): """ Construct a stack of images with known rotations. - Rotationally align the stack and compare output with known rotations. + Rotationally averager the stack and compare output with known rotations. """ - # Construction the Aligner and then call the main `align` method - algnr = self.aligner( + # Construction the Averager and then call the main `align` method + avgr = self.averager( self.basis, self._getSrc(), n_angles=self.n_search_angles, n_x_shifts=1, n_y_shifts=1, ) - _, _classes, _reflections, _rotations, _shifts, _ = algnr.align( - self.classes, self.reflections, self.coefs - ) - - self.assertTrue(np.all(_classes == self.classes)) - self.assertTrue(np.all(_reflections == self.reflections)) + _rotations, _shifts, _ = avgr.align(self.classes, self.reflections, self.coefs) # Crude check that we are closer to known angle than the next rotation self.assertTrue(np.all((_rotations - self.thetas) <= (self.step / 2))) diff --git a/tests/test_class2D.py b/tests/test_class2D.py index 69cd20a38b..41da61017d 100644 --- a/tests/test_class2D.py +++ b/tests/test_class2D.py @@ -7,7 +7,7 @@ from sklearn import datasets from aspire.basis import FFBBasis2D, FSPCABasis -from aspire.classification import BFSRAlign2D, RIRClass2D +from aspire.classification import BFRAverager2D, RIRClass2D from aspire.classification.legacy_implementations import bispec_2drot_large, pca_y from aspire.operators import ScalarFilter from aspire.source import Simulation @@ -218,12 +218,11 @@ def testRIRsk(self): large_pca_implementation="sklearn", nn_implementation="sklearn", bispectrum_implementation="devel", - aligner=BFSRAlign2D( + averager=BFRAverager2D( self.noisy_fspca_basis, self.noisy_src, self.basis, n_angles=100, - n_x_shifts=0, ), ) From 59ff967a3d0c7811bc74844358bf6c1bd000323a Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Mon, 14 Feb 2022 14:30:57 -0500 Subject: [PATCH 29/40] Update tutorial after Refactor aligner~>averager --- gallery/tutorials/class_averaging.py | 16 +--------------- src/aspire/basis/steerable.py | 2 +- src/aspire/classification/averager2d.py | 2 +- src/aspire/classification/rir_class2d.py | 2 +- 4 files changed, 4 insertions(+), 18 deletions(-) diff --git a/gallery/tutorials/class_averaging.py b/gallery/tutorials/class_averaging.py index 5bebccbe5d..6d449a24d2 100644 --- a/gallery/tutorials/class_averaging.py +++ b/gallery/tutorials/class_averaging.py @@ -170,9 +170,7 @@ ) classes, reflections, dists = noisy_rir.classify() -avgs, classes, reflections, rotations, shifts, corrs = noisy_rir.averages( - classes, reflections, dists -) +avgs = noisy_rir.averages(classes, reflections, dists) # %% # Display Classes @@ -198,17 +196,5 @@ # Report the identified neighbors Image(noisy_src.images(0, np.inf)[classes[review_class]]).show() -# Report their associated rots_refls -rots_refls = ["index, Rotation, Reflection"] -for i in range(classes.shape[1]): - rots_refls.append( - f"{i}, {rotations[review_class, i] * 180 / np.pi}, {reflections[review_class, i]}" - ) -rots_refls = "\n".join(rots_refls) - -logger.info( - f"Class {review_class}'s estimated Rotations and Reflections:\n{rots_refls}" -) - # Display the averaged result avgs.images(review_class, 1).show() diff --git a/src/aspire/basis/steerable.py b/src/aspire/basis/steerable.py index 85f0527289..5297351979 100644 --- a/src/aspire/basis/steerable.py +++ b/src/aspire/basis/steerable.py @@ -188,4 +188,4 @@ def shift(self, coef, shifts): f" received {shifts.shape}." ) - return self.evaluate_t(self.evaluate(coef).shift(shifts)) + return self.expand(self.evaluate(coef).shift(shifts)) diff --git a/src/aspire/classification/averager2d.py b/src/aspire/classification/averager2d.py index 4f6ea862cd..6c2161cf46 100644 --- a/src/aspire/classification/averager2d.py +++ b/src/aspire/classification/averager2d.py @@ -994,7 +994,7 @@ def average( # ReddyChatterjiAverager2D does rotations then shifts. # For brute force, we'd like shifts then rotations, # as is done in general in AligningAverager2D - return Averager2D.average(self, classes, reflections, coefs) + return AligningAverager2D.average(self, classes, reflections, coefs) class EMAverager2D(Averager2D): diff --git a/src/aspire/classification/rir_class2d.py b/src/aspire/classification/rir_class2d.py index 7772dd9f43..bbb7200d08 100644 --- a/src/aspire/classification/rir_class2d.py +++ b/src/aspire/classification/rir_class2d.py @@ -213,7 +213,7 @@ def averages(self, classes, reflections, distances): f"Begin Averaging of {classes.shape[0]} Classes using {self.averager}." ) - return self.averager.average(classes, reflections, self.fspca_coef) + return self.averager.average(classes, reflections) def select_classes(self, classes, reflections): """ From 91bc6872804fe72af5d33826cf394a80ca1762d9 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Mon, 14 Feb 2022 17:02:44 -0500 Subject: [PATCH 30/40] Update tests after Refactor aligner~>averager --- .../experimental_abinitio_pipeline.py | 4 +- .../simulated_abinitio_pipeline.py | 4 +- src/aspire/classification/averager2d.py | 126 ++++++++++-------- src/aspire/classification/rir_class2d.py | 2 +- tests/{test_align2d.py => test_averager2d.py} | 54 +++++++- tests/test_class2D.py | 3 +- 6 files changed, 123 insertions(+), 70 deletions(-) rename tests/{test_align2d.py => test_averager2d.py} (80%) diff --git a/gallery/experiments/experimental_abinitio_pipeline.py b/gallery/experiments/experimental_abinitio_pipeline.py index 7e57fdbef8..1851ce1ca4 100644 --- a/gallery/experiments/experimental_abinitio_pipeline.py +++ b/gallery/experiments/experimental_abinitio_pipeline.py @@ -121,9 +121,7 @@ # Use regular `src` for the alignment and composition (averaging). composite_basis = FFBBasis2D((src.L,) * 2, dtype=src.dtype) - custom_averager = BFSReddyChatterjiAverager2D( - None, src, composite_basis, dtype=src.dtype - ) + custom_averager = BFSReddyChatterjiAverager2D(composite_basis, src, dtype=src.dtype) # %% diff --git a/gallery/experiments/simulated_abinitio_pipeline.py b/gallery/experiments/simulated_abinitio_pipeline.py index a93807687b..d9906fedaa 100644 --- a/gallery/experiments/simulated_abinitio_pipeline.py +++ b/gallery/experiments/simulated_abinitio_pipeline.py @@ -158,9 +158,7 @@ def noise_function(x, y): # Use regular `src` for the alignment and composition (averaging). composite_basis = FFBBasis2D((src.L,) * 2, dtype=src.dtype) - custom_averager = BFSReddyChatterjiAverager2D( - None, src, composite_basis, dtype=src.dtype - ) + custom_averager = BFSReddyChatterjiAverager2D(composite_basis, src, dtype=src.dtype) # %% diff --git a/src/aspire/classification/averager2d.py b/src/aspire/classification/averager2d.py index 6c2161cf46..42eb715ede 100644 --- a/src/aspire/classification/averager2d.py +++ b/src/aspire/classification/averager2d.py @@ -21,19 +21,14 @@ class Averager2D(ABC): Base class for 2D Image Averaging methods. """ - def __init__( - self, alignment_basis, source, composite_basis=None, batch_size=512, dtype=None - ): + def __init__(self, composite_basis, source, batch_size=512, dtype=None): """ - :param alignment_basis: Basis to be used during alignment (eg FSPCA) - :param source: Source of original images. :param composite_basis: Basis to be used during class average composition (eg hi res Cartesian/FFB2D) + :param source: Source of original images. :param dtype: Numpy dtype to be used during alignment. """ - self.alignment_basis = alignment_basis - # if composite_basis is None, use alignment_basis - self.composite_basis = composite_basis or self.alignment_basis + self.composite_basis = composite_basis self.src = source self.batch_size = batch_size if dtype is None: @@ -74,7 +69,7 @@ def average( :param classes: class indices (refering to src). (n_img, n_nbor) :param reflections: Bool representing whether to reflect image in `classes` - :coefs: Optional Fourier bessel coefs (avoids recomputing). + :coefs: Optional basis coefs (could avoid recomputing). :return: Stack of Synthetic Class Average images as Image instance. """ @@ -105,6 +100,34 @@ class AligningAverager2D(Averager2D): Subclass supporting averagers which perfom an aligning stage. """ + def __init__( + self, composite_basis, source, alignment_basis=None, batch_size=512, dtype=None + ): + """ + :param composite_basis: Basis to be used during class average composition (eg hi res Cartesian/FFB2D) + :param source: Source of original images. + :param alignment_basis: Optional, basis to be used during alignment (eg FSPCA) + :param dtype: Numpy dtype to be used during alignment. + """ + + super().__init__( + composite_basis=composite_basis, + source=source, + batch_size=batch_size, + dtype=dtype, + ) + # If alignment_basis is None, use composite_basis + self.alignment_basis = alignment_basis or self.composite_basis + + if not hasattr(self.alignment_basis, "rotate"): + raise RuntimeError( + f"{self.__class__.__name__}'s alignment_basis {self.alignment_basis} must provide a `rotate` method." + ) + if not hasattr(self.alignment_basis, "shift"): + raise RuntimeError( + f"{self.__class__.__name__}'s alignment_basis {self.alignment_basis} must provide a `shift` method." + ) + @abstractmethod def align(self, classes, reflections, basis_coefficients): """ @@ -129,7 +152,7 @@ def align(self, classes, reflections, basis_coefficients): :param reflections: (n_classes, n_nbor) bool array of corresponding reflections :param basis_coefficients: (n_img, self.pca_basis.count) compressed basis coefficients - :returns: (reflections, rotations, shifts) + :returns: (rotations, shifts, correlations) """ def average( @@ -159,7 +182,7 @@ def average( if shifts is not None: neighbors_imgs.shift(shifts[i]) - neighbors_coefs = self.composite_basis.expand(neighbors_imgs) + neighbors_coefs = self.composite_basis.evaluate_t(neighbors_imgs) else: # Get the neighbors neighbors_ids = classes[i] @@ -192,27 +215,22 @@ class BFRAverager2D(AligningAverager2D): def __init__( self, - alignment_basis, + composite_basis, source, - composite_basis=None, + alignment_basis=None, n_angles=359, batch_size=512, dtype=None, ): """ - :params alignment_basis: Basis providing a `rotate` method. - :param source: Source of original images. + See AligningAverager2D, adds: + :params n_angles: Number of brute force rotations to attempt, defaults 359. """ - super().__init__(alignment_basis, source, composite_basis, batch_size, dtype) + super().__init__(composite_basis, source, alignment_basis, batch_size, dtype) self.n_angles = n_angles - if not hasattr(self.alignment_basis, "rotate"): - raise RuntimeError( - f"BFRAverager2D's alignment_basis {self.alignment_basis} must provide a `rotate` method." - ) - def align(self, classes, reflections, basis_coefficients): """ Performs the actual rotational alignment estimation, @@ -236,7 +254,13 @@ def align(self, classes, reflections, basis_coefficients): for k in trange(n_classes): # Get the coefs for these neighbors - nbr_coef = basis_coefficients[classes[k]] + if basis_coefficients is None: + # Retrieve relavent images + neighbors_imgs = Image(self._cls_images(classes[k])) + # Evaluate_T into basis + nbr_coef = self.composite_basis.evaluate_t(neighbors_imgs) + else: + nbr_coef = basis_coefficients[classes[k]] for i, angle in enumerate(test_angles): # Rotate the set of neighbors by angle, @@ -274,9 +298,9 @@ class BFSRAverager2D(BFRAverager2D): def __init__( self, - alignment_basis, + composite_basis, source, - composite_basis=None, + alignment_basis=None, n_angles=359, n_x_shifts=1, n_y_shifts=1, @@ -284,22 +308,23 @@ def __init__( dtype=None, ): """ - Note that n_x_shifts and n_y_shifts are the number of shifts to perform - in each direction. + See AligningAverager2D and BFRAverager2D, adds: `n_x_shifts`, `n_y_shifts`. + + Note that `n_x_shifts` and `n_y_shifts` are the number of shifts + to perform in each direction. Example: n_x_shifts=1, n_y_shifts=0 would test {-1,0,1} X {0}. n_x_shifts=n_y_shifts=0 is the same as calling BFRAverager2D. - :params alignment_basis: Basis providing a `shift` and `rotate` method. :params n_angles: Number of brute force rotations to attempt, defaults 359. :params n_x_shifts: +- Number of brute force xshifts to attempt, defaults 1. :params n_y_shifts: +- Number of brute force xshifts to attempt, defaults 1. """ super().__init__( - alignment_basis, - source, composite_basis, + source, + alignment_basis, n_angles, batch_size=batch_size, dtype=dtype, @@ -308,11 +333,6 @@ def __init__( self.n_x_shifts = n_x_shifts self.n_y_shifts = n_y_shifts - if not hasattr(self.alignment_basis, "shift"): - raise RuntimeError( - f"BFSRAverager2D's alignment_basis {self.alignment_basis} must provide a `shift` method." - ) - # Each shift will require calling the parent BFRAverager2D.align self._bfr_align = super().align @@ -343,6 +363,10 @@ def align(self, classes, reflections, basis_coefficients): correlations = np.ones(classes.shape, dtype=self.dtype) * -np.inf shifts = np.empty((*classes.shape, 2), dtype=int) + if basis_coefficients is None: + # Retrieve image coefficients, this is bad, but should be deleted anyway. + basis_coefficients = self.composite_basis.evaluate_t(self.src.images(0, np.inf)) + # We want to maintain the original coefs for the base images, # because we will mutate them with shifts in the loop. original_coef = basis_coefficients[classes[:, 0], :] @@ -407,20 +431,16 @@ class ReddyChatterjiAverager2D(AligningAverager2D): def __init__( self, - alignment_basis, + composite_basis, source, - composite_basis=None, alignment_source=None, diagnostics=False, batch_size=512, dtype=None, ): """ - :param alignment_basis: Basis to be used during alignment. - For current implementation of ReddyChatterjiAverager2D this should be `None`. - Instead see `alignment_source`. - :param source: Source of original images. :param composite_basis: Basis to be used during class average composition. + :param source: Source of original images. :param alignment_source: Optional, source to be used during class average alignment. Must be the same resolution as `source`. :param dtype: Numpy dtype to be used during alignment. @@ -440,15 +460,8 @@ def __init__( "Currently `alignment_src.dtype` must equal `source.dtype`" ) - # Sanity check. This API should be rethought once all basis and - # alignment methods have been incorporated. - assert alignment_basis is None # We use sources directly for alignment - assert ( - composite_basis is not None - ) # However, we require a basis for rotating etc. - super().__init__( - alignment_basis, source, composite_basis, batch_size=batch_size, dtype=dtype + composite_basis, source, composite_basis, batch_size=batch_size, dtype=dtype ) def _phase_cross_correlation(self, img0, img1): @@ -503,7 +516,7 @@ def align(self, classes, reflections, basis_coefficients): # # Get the array of images for this class, using the `alignment_src`. images = self._cls_images(classes[k], src=self.alignment_src) - rotations[k], correlations[k], shifts[k] = self._reddychatterji( + rotations[k], shifts[k], correlations[k] = self._reddychatterji( images, classes[k], reflections[k] ) @@ -692,7 +705,7 @@ def _reddychatterji(self, images, class_k, reflection_k): self.__cache.pop(id(warped_fixed_img_fs), None) self.__cache.pop(id(twfixed_img), None) - return rotations_k, correlations_k, shifts_k + return rotations_k, shifts_k, correlations_k def average( self, @@ -717,7 +730,7 @@ def average( if coefs is None: # Retrieve relavent images directly from source. neighbors_imgs = Image(self._cls_images(classes[i])) - neighbors_coefs = self.composite_basis.expand(neighbors_imgs) + neighbors_coefs = self.composite_basis.evaluate_t(neighbors_imgs) else: # Get the neighbors neighbors_ids = classes[i] @@ -889,9 +902,8 @@ class BFSReddyChatterjiAverager2D(ReddyChatterjiAverager2D): def __init__( self, - alignment_basis, + composite_basis, source, - composite_basis=None, alignment_source=None, radius=None, diagnostics=False, @@ -915,9 +927,8 @@ def __init__( """ super().__init__( - alignment_basis, - source, composite_basis, + source, alignment_source, diagnostics, batch_size=batch_size, @@ -947,7 +958,6 @@ def align(self, classes, reflections, basis_coefficients): rotations = np.zeros(classes.shape, dtype=self.dtype) _correlations = np.zeros(classes.shape, dtype=self.dtype) correlations = np.ones(classes.shape, dtype=self.dtype) * -np.inf - _shifts = np.zeros((*classes.shape, 2), dtype=int) shifts = np.zeros((*classes.shape, 2), dtype=int) # We'll brute force all shifts in a grid. @@ -968,7 +978,7 @@ def align(self, classes, reflections, basis_coefficients): # Don't shift the base image images[1:] = Image(unshifted_images[1:]).shift(s).asnumpy() - rotations[k], correlations[k], shifts[k] = self._reddychatterji( + rotations[k], _, correlations[k] = self._reddychatterji( images, classes[k], reflections[k] ) @@ -977,7 +987,7 @@ def align(self, classes, reflections, basis_coefficients): improved = _correlations > correlations correlations = np.where(improved, _correlations, correlations) rotations = np.where(improved, _rotations, rotations) - shifts = np.where(shifts, _shifts, shifts) + shifts = np.where(improved[..., np.newaxis], s, shifts) logger.debug(f"Shift {s} has improved {np.sum(improved)} results") return rotations, shifts, correlations diff --git a/src/aspire/classification/rir_class2d.py b/src/aspire/classification/rir_class2d.py index bbb7200d08..82e9f9b6e1 100644 --- a/src/aspire/classification/rir_class2d.py +++ b/src/aspire/classification/rir_class2d.py @@ -173,7 +173,7 @@ def classify(self, diagnostics=False): # we are certain our pca_basis has been constructed. if self.averager is None: self.averager = BFSReddyChatterjiAverager2D( - None, self.src, self.fb_basis, dtype=self.dtype + self.fb_basis, self.src, dtype=self.dtype ) # Get the expanded coefs in the compressed FSPCA space. diff --git a/tests/test_align2d.py b/tests/test_averager2d.py similarity index 80% rename from tests/test_align2d.py rename to tests/test_averager2d.py index eb356e69b1..2c54d07706 100644 --- a/tests/test_align2d.py +++ b/tests/test_averager2d.py @@ -6,7 +6,13 @@ import pytest from aspire.basis import DiracBasis, FFBBasis2D -from aspire.classification import Averager2D, BFRAverager2D, BFSRAverager2D +from aspire.classification import ( + Averager2D, + BFRAverager2D, + BFSRAverager2D, + BFSReddyChatterjiAverager2D, + ReddyChatterjiAverager2D, +) from aspire.source import Simulation from aspire.utils import Rotation from aspire.volume import Volume @@ -153,7 +159,7 @@ def testAverager(self): Rotationally averager the stack and compare output with known rotations. """ - # Construction the Averager and then call the `align` method + # Construct the Averager and then call the `align` method avgr = self.averager(self.basis, self._getSrc(), n_angles=self.n_search_angles) _rotations, _shifts, _ = avgr.align(self.classes, self.reflections, self.coefs) @@ -208,7 +214,7 @@ def testAverager(self): Rotationally averager the stack and compare output with known rotations. """ - # Construction the Averager and then call the main `align` method + # Construct the Averager and then call the main `align` method avgr = self.averager( self.basis, self._getSrc(), @@ -234,3 +240,45 @@ def testAverager(self): # non zero shift+rot improved corr. # Perhaps in the future should check more details. self.assertTrue(np.all(np.hypot(*_shifts[0][1:].T) >= 1)) + + +@pytest.mark.filterwarnings("ignore:Gimbal lock detected") +class ReddyChatterjiAverager2DTestCase(BFSRAverager2DTestCase): + + averager = ReddyChatterjiAverager2D + + def testAverager(self): + """ + Construct a stack of images with known rotations. + + Rotationally averager the stack and compare output with known rotations. + """ + + # Construct the Averager and then call the main `align` method + avgr = self.averager( + composite_basis=self.basis, + source=self._getSrc(), + dtype=self.dtype, + ) + _rotations, _shifts, _ = avgr.align(self.classes, self.reflections, self.coefs) + + # Crude check that we are closer to known angle than the next rotation + self.assertTrue(np.all((_rotations - self.thetas) <= (self.step / 2))) + + # Fine check that we are within one degree. + self.assertTrue(np.all((_rotations - self.thetas) <= (2 * np.pi / 360.0))) + + # Check that we are _not_ shifting the base image + self.assertTrue(np.all(_shifts[0][0] == 0)) + # Check that we produced estimated shifts away from origin + # Note that Simulation's rot+shift is generally not equal to shift+rot. + # Instead we check that some combination of + # non zero shift+rot improved corr. + # Perhaps in the future should check more details. + self.assertTrue(np.all(np.hypot(*_shifts[0][1:].T) >= 1)) + + +@pytest.mark.filterwarnings("ignore:Gimbal lock detected") +class BFSReddyChatterjiAverager2DTestCase(ReddyChatterjiAverager2DTestCase): + + averager = BFSReddyChatterjiAverager2D diff --git a/tests/test_class2D.py b/tests/test_class2D.py index 41da61017d..4e79d531da 100644 --- a/tests/test_class2D.py +++ b/tests/test_class2D.py @@ -219,9 +219,8 @@ def testRIRsk(self): nn_implementation="sklearn", bispectrum_implementation="devel", averager=BFRAverager2D( - self.noisy_fspca_basis, + self.noisy_fspca_basis.basis, # FFB basis self.noisy_src, - self.basis, n_angles=100, ), ) From 1b0c7250eaa211c29a75372234f7a863fb786cc2 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Mon, 14 Feb 2022 17:07:11 -0500 Subject: [PATCH 31/40] tox --- src/aspire/classification/averager2d.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/aspire/classification/averager2d.py b/src/aspire/classification/averager2d.py index 42eb715ede..a2927ddba8 100644 --- a/src/aspire/classification/averager2d.py +++ b/src/aspire/classification/averager2d.py @@ -365,7 +365,9 @@ def align(self, classes, reflections, basis_coefficients): if basis_coefficients is None: # Retrieve image coefficients, this is bad, but should be deleted anyway. - basis_coefficients = self.composite_basis.evaluate_t(self.src.images(0, np.inf)) + basis_coefficients = self.composite_basis.evaluate_t( + self.src.images(0, np.inf) + ) # We want to maintain the original coefs for the base images, # because we will mutate them with shifts in the loop. From d166099086acb520c52c56ebaec7fc7936e45c7e Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 15 Feb 2022 08:43:15 -0500 Subject: [PATCH 32/40] minor cleanup, strings, init 0 --- src/aspire/classification/averager2d.py | 20 ++++++++------------ tests/test_class2D.py | 1 + 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/src/aspire/classification/averager2d.py b/src/aspire/classification/averager2d.py index a2927ddba8..854552f376 100644 --- a/src/aspire/classification/averager2d.py +++ b/src/aspire/classification/averager2d.py @@ -23,7 +23,7 @@ class Averager2D(ABC): def __init__(self, composite_basis, source, batch_size=512, dtype=None): """ - :param composite_basis: Basis to be used during class average composition (eg hi res Cartesian/FFB2D) + :param composite_basis: Basis to be used during class average composition (eg FFB2D) :param source: Source of original images. :param dtype: Numpy dtype to be used during alignment. """ @@ -67,8 +67,8 @@ def average( Should return an Image source of synthetic class averages. - :param classes: class indices (refering to src). (n_img, n_nbor) - :param reflections: Bool representing whether to reflect image in `classes` + :param classes: class indices (refering to src). (n_img, n_nbor). + :param reflections: Bool representing whether to reflect image in `classes`. :coefs: Optional basis coefs (could avoid recomputing). :return: Stack of Synthetic Class Average images as Image instance. """ @@ -106,7 +106,7 @@ def __init__( """ :param composite_basis: Basis to be used during class average composition (eg hi res Cartesian/FFB2D) :param source: Source of original images. - :param alignment_basis: Optional, basis to be used during alignment (eg FSPCA) + :param alignment_basis: Optional, basis to be used only during alignment (eg FSPCA) :param dtype: Numpy dtype to be used during alignment. """ @@ -541,13 +541,9 @@ def _reddychatterji(self, images, class_k, reflection_k): # Result arrays M = len(images) - rotations_k = np.empty(M, dtype=self.dtype) - correlations_k = np.empty(M, dtype=self.dtype) - shifts_k = np.empty((M, 2), dtype=self.dtype) - # Initialize for Image 0, others will populate in loop. - rotations_k[0] = 0 - correlations_k[0] = 0 - shifts_k[0] = 0 + rotations_k = np.zeros(M, dtype=self.dtype) + correlations_k = np.zeros(M, dtype=self.dtype) + shifts_k = np.zeros((M, 2), dtype=int) # De-Mean, note images is mutated and should be a `copy`. images -= images.mean(axis=(-1, -2))[:, np.newaxis, np.newaxis] @@ -964,7 +960,7 @@ def align(self, classes, reflections, basis_coefficients): # We'll brute force all shifts in a grid. g = grid_2d(L, normalized=False) - disc = g["r"] <= L // 8 # make param later + disc = g["r"] <= self.radius X, Y = g["x"][disc], g["y"][disc] for k in trange(n_classes): diff --git a/tests/test_class2D.py b/tests/test_class2D.py index 4e79d531da..07935f217b 100644 --- a/tests/test_class2D.py +++ b/tests/test_class2D.py @@ -40,6 +40,7 @@ def setUp(self): vols=v, dtype=self.dtype, ) + self.src.cache() # Precompute image stack # Calculate some projection images self.imgs = self.src.images(0, self.src.n) From 478dff4a29336393108ee5433f805800c39f1f7a Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 15 Feb 2022 11:01:06 -0500 Subject: [PATCH 33/40] revert another problematic expand --- src/aspire/basis/steerable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aspire/basis/steerable.py b/src/aspire/basis/steerable.py index 5297351979..85f0527289 100644 --- a/src/aspire/basis/steerable.py +++ b/src/aspire/basis/steerable.py @@ -188,4 +188,4 @@ def shift(self, coef, shifts): f" received {shifts.shape}." ) - return self.expand(self.evaluate(coef).shift(shifts)) + return self.evaluate_t(self.evaluate(coef).shift(shifts)) From 53b826e85896c329e4f5ffc56a0cb5896c1e9d83 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 18 Feb 2022 07:38:28 -0500 Subject: [PATCH 34/40] Remove unused `batch_size` --- src/aspire/classification/averager2d.py | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/src/aspire/classification/averager2d.py b/src/aspire/classification/averager2d.py index 854552f376..1510189aa7 100644 --- a/src/aspire/classification/averager2d.py +++ b/src/aspire/classification/averager2d.py @@ -21,7 +21,7 @@ class Averager2D(ABC): Base class for 2D Image Averaging methods. """ - def __init__(self, composite_basis, source, batch_size=512, dtype=None): + def __init__(self, composite_basis, source, dtype=None): """ :param composite_basis: Basis to be used during class average composition (eg FFB2D) :param source: Source of original images. @@ -30,7 +30,6 @@ def __init__(self, composite_basis, source, batch_size=512, dtype=None): self.composite_basis = composite_basis self.src = source - self.batch_size = batch_size if dtype is None: if self.composite_basis: self.dtype = self.composite_basis.dtype @@ -100,9 +99,7 @@ class AligningAverager2D(Averager2D): Subclass supporting averagers which perfom an aligning stage. """ - def __init__( - self, composite_basis, source, alignment_basis=None, batch_size=512, dtype=None - ): + def __init__(self, composite_basis, source, alignment_basis=None, dtype=None): """ :param composite_basis: Basis to be used during class average composition (eg hi res Cartesian/FFB2D) :param source: Source of original images. @@ -113,7 +110,6 @@ def __init__( super().__init__( composite_basis=composite_basis, source=source, - batch_size=batch_size, dtype=dtype, ) # If alignment_basis is None, use composite_basis @@ -219,7 +215,6 @@ def __init__( source, alignment_basis=None, n_angles=359, - batch_size=512, dtype=None, ): """ @@ -227,7 +222,7 @@ def __init__( :params n_angles: Number of brute force rotations to attempt, defaults 359. """ - super().__init__(composite_basis, source, alignment_basis, batch_size, dtype) + super().__init__(composite_basis, source, alignment_basis, dtype) self.n_angles = n_angles @@ -304,7 +299,6 @@ def __init__( n_angles=359, n_x_shifts=1, n_y_shifts=1, - batch_size=512, dtype=None, ): """ @@ -326,7 +320,6 @@ def __init__( source, alignment_basis, n_angles, - batch_size=batch_size, dtype=dtype, ) @@ -437,7 +430,6 @@ def __init__( source, alignment_source=None, diagnostics=False, - batch_size=512, dtype=None, ): """ @@ -462,9 +454,7 @@ def __init__( "Currently `alignment_src.dtype` must equal `source.dtype`" ) - super().__init__( - composite_basis, source, composite_basis, batch_size=batch_size, dtype=dtype - ) + super().__init__(composite_basis, source, composite_basis, dtype=dtype) def _phase_cross_correlation(self, img0, img1): """ @@ -905,7 +895,6 @@ def __init__( alignment_source=None, radius=None, diagnostics=False, - batch_size=512, dtype=None, ): """ @@ -929,7 +918,6 @@ def __init__( source, alignment_source, diagnostics, - batch_size=batch_size, dtype=dtype, ) From afd498c3bf7398fc92e3420a7dd2a8ae81c6460a Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 18 Feb 2022 07:55:30 -0500 Subject: [PATCH 35/40] Docstring and alignment/composite_basis.rotate checks --- src/aspire/classification/averager2d.py | 42 ++++++++++++++++--------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/src/aspire/classification/averager2d.py b/src/aspire/classification/averager2d.py index 1510189aa7..4610d51dfd 100644 --- a/src/aspire/classification/averager2d.py +++ b/src/aspire/classification/averager2d.py @@ -66,10 +66,12 @@ def average( Should return an Image source of synthetic class averages. - :param classes: class indices (refering to src). (n_img, n_nbor). + :param classes: class indices, refering to src. (n_img, n_nbor). :param reflections: Bool representing whether to reflect image in `classes`. - :coefs: Optional basis coefs (could avoid recomputing). - :return: Stack of Synthetic Class Average images as Image instance. + (n_img, n_nbor) + :param coefs: Optional basis coefs (could avoid recomputing). + (n_img, coef_count) + :return: Stack of synthetic class average images as Image instance. """ def _cls_images(self, cls, src=None): @@ -78,9 +80,9 @@ def _cls_images(self, cls, src=None): preserving the class/nbor order. :param cls: An iterable (0/1-D array or list) that holds the indices of images to align. - In Class Averaging, this would be a class. - :param src: Optionally overridee the src, for example, if you want to use a different - source for a certain operation (ie aignment). + In class averaging, this would be a class. + :param src: Optionally override the src, for example, if you want to use a different + source for a certain operation (ie alignment). """ src = src or self.src @@ -101,9 +103,9 @@ class AligningAverager2D(Averager2D): def __init__(self, composite_basis, source, alignment_basis=None, dtype=None): """ - :param composite_basis: Basis to be used during class average composition (eg hi res Cartesian/FFB2D) + :param composite_basis: Basis to be used during class average composition (eg hi res Cartesian/FFB2D). :param source: Source of original images. - :param alignment_basis: Optional, basis to be used only during alignment (eg FSPCA) + :param alignment_basis: Optional, basis to be used only during alignment (eg FSPCA). :param dtype: Numpy dtype to be used during alignment. """ @@ -115,13 +117,13 @@ def __init__(self, composite_basis, source, alignment_basis=None, dtype=None): # If alignment_basis is None, use composite_basis self.alignment_basis = alignment_basis or self.composite_basis - if not hasattr(self.alignment_basis, "rotate"): + if not hasattr(self.composite_basis, "rotate"): raise RuntimeError( - f"{self.__class__.__name__}'s alignment_basis {self.alignment_basis} must provide a `rotate` method." + f"{self.__class__.__name__}'s composite_basis {self.composite_basis} must provide a `rotate` method." ) - if not hasattr(self.alignment_basis, "shift"): + if not hasattr(self.composite_basis, "shift"): raise RuntimeError( - f"{self.__class__.__name__}'s alignment_basis {self.alignment_basis} must provide a `shift` method." + f"{self.__class__.__name__}'s composite_basis {self.composite_basis} must provide a `shift` method." ) @abstractmethod @@ -226,6 +228,11 @@ def __init__( self.n_angles = n_angles + if not hasattr(self.alignment_basis, "rotate"): + raise RuntimeError( + f"{self.__class__.__name__}'s alignment_basis {self.alignment_basis} must provide a `rotate` method." + ) + def align(self, classes, reflections, basis_coefficients): """ Performs the actual rotational alignment estimation, @@ -329,6 +336,11 @@ def __init__( # Each shift will require calling the parent BFRAverager2D.align self._bfr_align = super().align + if not hasattr(self.alignment_basis, "shift"): + raise RuntimeError( + f"{self.__class__.__name__}'s alignment_basis {self.alignment_basis} must provide a `shift` method." + ) + def align(self, classes, reflections, basis_coefficients): """ See `AligningAverager2D.align` @@ -523,9 +535,9 @@ def _reddychatterji(self, images, class_k, reflection_k): This is a util function to help loop over `classes`. - :param images: Image data - :param class_k: Image indices - :param reflection_k: Image reflections + :param images: Image data (m_img, L, L) + :param class_k: Image indices (m_img,) + :param reflection_k: Image reflections (m_img,) :returns: (rotations_k, correlations_k, shifts_k) corresponding to `images` """ From f3f4ec31e4231d07114a5646c6ebf868a9207678 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 18 Feb 2022 08:45:56 -0500 Subject: [PATCH 36/40] More docstring and note changes --- src/aspire/classification/averager2d.py | 36 +++++++++++++------------ 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/src/aspire/classification/averager2d.py b/src/aspire/classification/averager2d.py index 4610d51dfd..ae101db4ba 100644 --- a/src/aspire/classification/averager2d.py +++ b/src/aspire/classification/averager2d.py @@ -66,11 +66,11 @@ def average( Should return an Image source of synthetic class averages. - :param classes: class indices, refering to src. (n_img, n_nbor). + :param classes: class indices, refering to src. (n_classes, n_nbor). :param reflections: Bool representing whether to reflect image in `classes`. - (n_img, n_nbor) + (n_clases, n_nbor) :param coefs: Optional basis coefs (could avoid recomputing). - (n_img, coef_count) + (n_classes, coef_count) :return: Stack of synthetic class average images as Image instance. """ @@ -134,21 +134,21 @@ def align(self, classes, reflections, basis_coefficients): `rotations` is an (n_classes, n_nbor) array of angles, which should represent the rotations needed to align images within - that class. `rotations` is measured in Radians. - - `correlations` is an (n_classes, n_nbor) array representing - a correlation like measure between classified images and their base - image (image index 0). + that class. `rotations` is measured in radians. `shifts` is None or an (n_classes, n_nbor) array of 2D shifts which should represent the translation needed to best align the images within that class. + `correlations` is an (n_classes, n_nbor) array representing + a correlation like measure between classified images and their base + image (image index 0). + Subclasses of should implement and extend this method. - :param classes: (n_classes, n_nbor) integer array of img indices - :param reflections: (n_classes, n_nbor) bool array of corresponding reflections - :param basis_coefficients: (n_img, self.pca_basis.count) compressed basis coefficients + :param classes: (n_classes, n_nbor) integer array of img indices. + :param reflections: (n_classes, n_nbor) bool array of corresponding reflections, + :param basis_coefficients: (n_img, self.alignment_basis.count) basis coefficients, :returns: (rotations, shifts, correlations) """ @@ -216,13 +216,13 @@ def __init__( composite_basis, source, alignment_basis=None, - n_angles=359, + n_angles=360, dtype=None, ): """ See AligningAverager2D, adds: - :params n_angles: Number of brute force rotations to attempt, defaults 359. + :params n_angles: Number of brute force rotations to attempt, defaults 360. """ super().__init__(composite_basis, source, alignment_basis, dtype) @@ -303,7 +303,7 @@ def __init__( composite_basis, source, alignment_basis=None, - n_angles=359, + n_angles=360, n_x_shifts=1, n_y_shifts=1, dtype=None, @@ -318,7 +318,7 @@ def __init__( n_x_shifts=n_y_shifts=0 is the same as calling BFRAverager2D. - :params n_angles: Number of brute force rotations to attempt, defaults 359. + :params n_angles: Number of brute force rotations to attempt, defaults 360. :params n_x_shifts: +- Number of brute force xshifts to attempt, defaults 1. :params n_y_shifts: +- Number of brute force xshifts to attempt, defaults 1. """ @@ -369,7 +369,9 @@ def align(self, classes, reflections, basis_coefficients): shifts = np.empty((*classes.shape, 2), dtype=int) if basis_coefficients is None: - # Retrieve image coefficients, this is bad, but should be deleted anyway. + # Retrieve image coefficients, this is bad, it load all images. + # TODO: Refactor this s.t. the following code blocks and super().align + # only require coefficients relating to their class. See _cls_images. basis_coefficients = self.composite_basis.evaluate_t( self.src.images(0, np.inf) ) @@ -382,7 +384,7 @@ def align(self, classes, reflections, basis_coefficients): # Loop over shift search space, updating best result for x, y in product(x_shifts, y_shifts): shift = np.array([x, y], dtype=int) - logger.debug(f"Computing Rotational alignment after shift ({x},{y}).") + logger.debug(f"Computing rotational alignment after shift ({x},{y}).") # Shift the coef representing the first (base) entry in each class # by the negation of the shift From 0a917ec8b3db6430e15d8d38ff7eab2fc596e579 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 18 Feb 2022 08:56:05 -0500 Subject: [PATCH 37/40] source ~~> src --- src/aspire/classification/averager2d.py | 64 ++++++++++++------------- tests/test_averager2d.py | 2 +- 2 files changed, 32 insertions(+), 34 deletions(-) diff --git a/src/aspire/classification/averager2d.py b/src/aspire/classification/averager2d.py index ae101db4ba..873bcf958b 100644 --- a/src/aspire/classification/averager2d.py +++ b/src/aspire/classification/averager2d.py @@ -21,15 +21,15 @@ class Averager2D(ABC): Base class for 2D Image Averaging methods. """ - def __init__(self, composite_basis, source, dtype=None): + def __init__(self, composite_basis, src, dtype=None): """ :param composite_basis: Basis to be used during class average composition (eg FFB2D) - :param source: Source of original images. + :param src: Source of original images. :param dtype: Numpy dtype to be used during alignment. """ self.composite_basis = composite_basis - self.src = source + self.src = src if dtype is None: if self.composite_basis: self.dtype = self.composite_basis.dtype @@ -101,17 +101,17 @@ class AligningAverager2D(Averager2D): Subclass supporting averagers which perfom an aligning stage. """ - def __init__(self, composite_basis, source, alignment_basis=None, dtype=None): + def __init__(self, composite_basis, src, alignment_basis=None, dtype=None): """ :param composite_basis: Basis to be used during class average composition (eg hi res Cartesian/FFB2D). - :param source: Source of original images. + :param src: Source of original images. :param alignment_basis: Optional, basis to be used only during alignment (eg FSPCA). :param dtype: Numpy dtype to be used during alignment. """ super().__init__( composite_basis=composite_basis, - source=source, + src=src, dtype=dtype, ) # If alignment_basis is None, use composite_basis @@ -214,7 +214,7 @@ class BFRAverager2D(AligningAverager2D): def __init__( self, composite_basis, - source, + src, alignment_basis=None, n_angles=360, dtype=None, @@ -224,7 +224,7 @@ def __init__( :params n_angles: Number of brute force rotations to attempt, defaults 360. """ - super().__init__(composite_basis, source, alignment_basis, dtype) + super().__init__(composite_basis, src, alignment_basis, dtype) self.n_angles = n_angles @@ -301,7 +301,7 @@ class BFSRAverager2D(BFRAverager2D): def __init__( self, composite_basis, - source, + src, alignment_basis=None, n_angles=360, n_x_shifts=1, @@ -324,7 +324,7 @@ def __init__( """ super().__init__( composite_basis, - source, + src, alignment_basis, n_angles, dtype=dtype, @@ -441,34 +441,32 @@ class ReddyChatterjiAverager2D(AligningAverager2D): def __init__( self, composite_basis, - source, - alignment_source=None, + src, + alignment_src=None, diagnostics=False, dtype=None, ): """ :param composite_basis: Basis to be used during class average composition. - :param source: Source of original images. - :param alignment_source: Optional, source to be used during class average alignment. - Must be the same resolution as `source`. + :param src: Source of original images. + :param alignment_src: Optional, source to be used during class average alignment. + Must be the same resolution as `src`. :param dtype: Numpy dtype to be used during alignment. """ self.__cache = dict() self.diagnostics = diagnostics self.do_cross_corr_translations = True - self.alignment_src = alignment_source or source + self.alignment_src = alignment_src or src # TODO, for accomodating different resolutions we minimally need to adapt shifting. # Outside of scope right now, but would make a nice PR later. - if self.alignment_src.L != source.L: - raise RuntimeError("Currently `alignment_src.L` must equal `source.L`") - if self.alignment_src.dtype != source.dtype: - raise RuntimeError( - "Currently `alignment_src.dtype` must equal `source.dtype`" - ) + if self.alignment_src.L != src.L: + raise RuntimeError("Currently `alignment_src.L` must equal `src.L`") + if self.alignment_src.dtype != src.dtype: + raise RuntimeError("Currently `alignment_src.dtype` must equal `src.dtype`") - super().__init__(composite_basis, source, composite_basis, dtype=dtype) + super().__init__(composite_basis, src, composite_basis, dtype=dtype) def _phase_cross_correlation(self, img0, img1): """ @@ -905,8 +903,8 @@ class BFSReddyChatterjiAverager2D(ReddyChatterjiAverager2D): def __init__( self, composite_basis, - source, - alignment_source=None, + src, + alignment_src=None, radius=None, diagnostics=False, dtype=None, @@ -914,13 +912,13 @@ def __init__( """ :param alignment_basis: Basis to be used during alignment. For current implementation of ReddyChatterjiAverager2D this should be `None`. - Instead see `alignment_source`. - :param source: Source of original images. + Instead see `alignment_src`. + :param src: Source of original images. :param composite_basis: Basis to be used during class average composition. - :param alignment_source: Optional, source to be used during class average alignment. - Must be the same resolution as `source`. + :param alignment_src: Optional, source to be used during class average alignment. + Must be the same resolution as `src`. :param radius: Brute force translation search radius. - Defaults to source.L//8. + Defaults to src.L//8. :param dtype: Numpy dtype to be used during alignment. :param diagnostics: Plot interactive diagnostic graphics (for debugging). @@ -929,8 +927,8 @@ def __init__( super().__init__( composite_basis, - source, - alignment_source, + src, + alignment_src, diagnostics, dtype=dtype, ) @@ -938,7 +936,7 @@ def __init__( # For brute force we disable the cross_corr translation code self.do_cross_corr_translations = False # Assign search radius - self.radius = radius or source.L // 8 + self.radius = radius or src.L // 8 def align(self, classes, reflections, basis_coefficients): """ diff --git a/tests/test_averager2d.py b/tests/test_averager2d.py index 2c54d07706..10b46e5ac5 100644 --- a/tests/test_averager2d.py +++ b/tests/test_averager2d.py @@ -257,7 +257,7 @@ def testAverager(self): # Construct the Averager and then call the main `align` method avgr = self.averager( composite_basis=self.basis, - source=self._getSrc(), + src=self._getSrc(), dtype=self.dtype, ) _rotations, _shifts, _ = avgr.align(self.classes, self.reflections, self.coefs) From 2b144d6bde54918b8b3fbcb7dde0ad9181daaabd Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 18 Feb 2022 09:02:18 -0500 Subject: [PATCH 38/40] fix __cache of fft in RC methods --- src/aspire/classification/averager2d.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/aspire/classification/averager2d.py b/src/aspire/classification/averager2d.py index 873bcf958b..1a206a5325 100644 --- a/src/aspire/classification/averager2d.py +++ b/src/aspire/classification/averager2d.py @@ -479,7 +479,10 @@ def _phase_cross_correlation(self, img0, img1): # Cache img0 transform, this saves n_classes*(n_nbor-1) transforms # Note we use the `id` because ndarray are unhashable - src_f = self.__cache.setdefault(id(img0), fft.fft2(img0)) + key = id(img0) + if key not in self.__cache: + self.__cache[key] = fft.fft2(img0) + src_f = self.__cache[key] target_f = fft.fft2(img1) From 638776ac8637b01062163bd7be5f826ae861a284 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 18 Feb 2022 09:22:27 -0500 Subject: [PATCH 39/40] Update class2d docstring and add comment about log polar --- src/aspire/classification/averager2d.py | 9 ++++++++- src/aspire/classification/class2d.py | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/aspire/classification/averager2d.py b/src/aspire/classification/averager2d.py index 1a206a5325..5db97f1aa3 100644 --- a/src/aspire/classification/averager2d.py +++ b/src/aspire/classification/averager2d.py @@ -621,10 +621,17 @@ def _reddychatterji(self, images, class_k, reflection_k): # Compute the Cross_Correlation to estimate rotation # Note that _phase_cross_correlation uses the mangnitudes (abs()), # ie it is using both freq and phase information. - cross_correlation, shift = self._phase_cross_correlation( + cross_correlation, _ = self._phase_cross_correlation( warped_fixed_img_fs, warped_regis_img_fs ) + # Rotating Cartesian space translates the angular log polar component. + # Scaling Cartesian space translates the radial log polar component. + # In common image resgistration problems, both components are used + # to simultaneously estimate scaling and rotation. + # Since we are not currently concerned with scaling transformation, + # disregard the second axis of the `cross_correlation` returned by + # `_phase_cross_correlation`. cross_correlation_score = cross_correlation[:, 0].ravel() self._rotation_cross_corr_diagnostic( diff --git a/src/aspire/classification/class2d.py b/src/aspire/classification/class2d.py index 9df27fa0d4..1542c12ac3 100644 --- a/src/aspire/classification/class2d.py +++ b/src/aspire/classification/class2d.py @@ -54,5 +54,5 @@ def classify(self): @abstractmethod def averages(self, classes, refl, distances): """ - Returns class averages using prescribed `aligner`. + Returns class averages. """ From d51b0ecb734682af7315f04463232c055a1db87c Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 18 Feb 2022 10:21:12 -0500 Subject: [PATCH 40/40] update simulated pipeline noise var after whitening change merged in --- gallery/experiments/simulated_abinitio_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gallery/experiments/simulated_abinitio_pipeline.py b/gallery/experiments/simulated_abinitio_pipeline.py index d9906fedaa..f493f2c4ba 100644 --- a/gallery/experiments/simulated_abinitio_pipeline.py +++ b/gallery/experiments/simulated_abinitio_pipeline.py @@ -52,7 +52,7 @@ num_imgs = 10000 # How many images in our source. n_classes = 1000 # How many class averages to compute. n_nbor = 10 # How many neighbors to stack -noise_variance = 1e-4 # Set a target noise variance +noise_variance = 5e-7 # Set a target noise variance # %%