diff --git a/docs/source/conf.py b/docs/source/conf.py index bfa9c5edf5..67e476e087 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -49,7 +49,7 @@ 'gallery_dirs': ['auto_tutorials', 'auto_experiments'], # path to where to save gallery generated output 'download_all_examples': False, 'within_subsection_order': ExampleTitleSortKey, - 'filename_pattern': '/*.py', + 'filename_pattern': r'/tutorials/.*\.py', # Parse all gallery python files, but only execute tutorials. } # Add any paths that contain templates here, relative to this directory. diff --git a/gallery/experiments/README.rst b/gallery/experiments/README.rst index 5791749161..8531b9f7b0 100644 --- a/gallery/experiments/README.rst +++ b/gallery/experiments/README.rst @@ -1,4 +1,4 @@ -Experiments - **COMING SOON** +Experiments ============================= This gallery will be for demonstrating the functionality of ASPIRE tools using experimental data. diff --git a/gallery/experiments/experimental_abinitio_pipeline.py b/gallery/experiments/experimental_abinitio_pipeline.py new file mode 100644 index 0000000000..1851ce1ca4 --- /dev/null +++ b/gallery/experiments/experimental_abinitio_pipeline.py @@ -0,0 +1,190 @@ +""" +Abinitio Pipeline - Experimental Data +===================================== + +This notebook introduces a selection of +components corresponding to loading real Relion picked +particle Cryo-EM data and running key ASPIRE-Python +Abinitio model components as a pipeline. + +Specifically this pipeline uses the +EMPIAR 10028 picked particles data, available here: + +https://www.ebi.ac.uk/empiar/EMPIAR-10028 + +https://www.ebi.ac.uk/emdb/EMD-10028 +""" + +# %% +# Imports +# ------- +# First import some of the usual suspects. +# In addition, import some classes from +# the ASPIRE package that will be used throughout this experiment. + +import logging + +import matplotlib.pyplot as plt +import numpy as np + +from aspire.abinitio import CLSyncVoting +from aspire.basis import FFBBasis2D, FFBBasis3D +from aspire.classification import BFSReddyChatterjiAverager2D, RIRClass2D +from aspire.denoising import DenoiserCov2D +from aspire.noise import AnisotropicNoiseEstimator +from aspire.reconstruction import MeanEstimator +from aspire.source import RelionSource + +logger = logging.getLogger(__name__) + + +# %% +# Parameters +# --------------- +# Example simulation configuration. + +interactive = False # Draw blocking interactive plots? +do_cov2d = True # Use CWF coefficients +n_imgs = 20000 # Set to None for all images in starfile, can set smaller for tests. +img_size = 32 # Downsample the images/reconstruction to a desired resolution +n_classes = 1000 # How many class averages to compute. +n_nbor = 50 # How many neighbors to stack +starfile_in = "10028/data/shiny_2sets.star" +volume_filename_prefix_out = f"10028_recon_c{n_classes}_m{n_nbor}_{img_size}.mrc" +pixel_size = 1.34 + +# %% +# Source data and Preprocessing +# ----------------------------- +# +# `RelionSource` is used to access the experimental data via a `starfile`. +# Begin by downsampling to our chosen resolution, then preprocess +# to correct for CTF and noise. + +# Create a source object for the experimental images +src = RelionSource(starfile_in, pixel_size=pixel_size, max_rows=n_imgs) + +# Downsample the images +logger.info(f"Set the resolution to {img_size} X {img_size}") +src.downsample(img_size) + +# Peek +if interactive: + src.images(0, 10).show() + +# Use phase_flip to attempt correcting for CTF. 
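+# Phase flipping negates each image's Fourier components wherever its CTF is
+# negative, a cheap first-order correction that restores consistent contrast
+# without deconvolving. A rough standalone sketch of the idea (hypothetical
+# names; ASPIRE applies this per image using the source's own CTF filters):
+#
+#   img_f = np.fft.fft2(img)                           # image to Fourier space
+#   img = np.real(np.fft.ifft2(img_f * np.sign(ctf)))  # flip negative CTF lobes
+#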
+logger.info("Perform phase flip to input images.") +src.phase_flip() + +# Estimate the noise and `Whiten` based on the estimated noise +aiso_noise_estimator = AnisotropicNoiseEstimator(src) +src.whiten(aiso_noise_estimator.filter) + +# Plot the noise profile for inspection +if interactive: + plt.imshow(aiso_noise_estimator.filter.evaluate_grid(img_size)) + plt.show() + +# Peek, what do the whitened images look like... +if interactive: + src.images(0, 10).show() + +# # Optionally invert image contrast, depends on data convention. +# # This is not needed for 10028, but included anyway. +# logger.info("Invert the global density contrast") +# src.invert_contrast() + +# %% +# Optional: CWF Denoising +# ----------------------- +# +# Optionally generate an alternative source that is denoised with `cov2d`, +# then configure a customized averager. This allows the use of CWF denoised +# images for classification, but stacks the original images for averages +# used in the remainder of the reconstruction pipeline. +# +# In this example, this behavior is controlled by the `do_cov2d` boolean variable. +# When disabled, the original src and default averager is used. +# If you will not be using cov2d, +# you may remove this code block and associated variables. + +classification_src = src +custom_averager = None +if do_cov2d: + # Use CWF denoising + cwf_denoiser = DenoiserCov2D(src) + # Use denoised src for classification + classification_src = cwf_denoiser.denoise() + # Peek, what do the denoised images look like... + if interactive: + classification_src.images(0, 10).show() + + # Use regular `src` for the alignment and composition (averaging). + composite_basis = FFBBasis2D((src.L,) * 2, dtype=src.dtype) + custom_averager = BFSReddyChatterjiAverager2D(composite_basis, src, dtype=src.dtype) + + +# %% +# Class Averaging +# ---------------------- +# +# Now perform classification and averaging for each class. + +logger.info("Begin Class Averaging") + +rir = RIRClass2D( + classification_src, # Source used for classification + fspca_components=400, + bispectrum_components=300, # Compressed Features after last PCA stage. + n_nbor=n_nbor, + n_classes=n_classes, + large_pca_implementation="legacy", + nn_implementation="sklearn", + bispectrum_implementation="legacy", + averager=custom_averager, +) + +classes, reflections, distances = rir.classify() +avgs = rir.averages(classes, reflections, distances) +if interactive: + avgs.images(0, 10).show() + +# %% +# Common Line Estimation +# ---------------------- +# +# Next create a CL instance for estimating orientation of projections +# using the Common Line with Synchronization Voting method. + +logger.info("Begin Orientation Estimation") + +orient_est = CLSyncVoting(avgs, n_theta=36) +# Get the estimated rotations +orient_est.estimate_rotations() +rots_est = orient_est.rotations + +# %% +# Volume Reconstruction +# ---------------------- +# +# Using the estimated rotations, attempt to reconstruct a volume. + +logger.info("Begin Volume reconstruction") + +# Assign the estimated rotations to the class averages +avgs.rots = rots_est + +# Create a reasonable Basis for the 3d Volume +basis = FFBBasis3D((img_size,) * 3, dtype=src.dtype) + +# Setup an estimator to perform the back projection. +estimator = MeanEstimator(avgs, basis) + +# Perform the estimation and save the volume. 
+estimated_volume = estimator.estimate()
+estimated_volume.save(volume_filename_prefix_out, overwrite=True)
+
+# Peek at result
+if interactive:
+    plt.imshow(np.sum(estimated_volume[0], axis=-1))
+    plt.show()
diff --git a/gallery/experiments/simulated_abinitio_pipeline.py b/gallery/experiments/simulated_abinitio_pipeline.py
new file mode 100644
index 0000000000..f493f2c4ba
--- /dev/null
+++ b/gallery/experiments/simulated_abinitio_pipeline.py
@@ -0,0 +1,241 @@
+"""
+Abinitio Pipeline - Simulated Data
+==================================
+
+This notebook introduces a selection of
+components corresponding to generating realistic
+simulated Cryo-EM data and running key ASPIRE-Python
+Abinitio model components as a pipeline.
+"""
+
+# %%
+# Imports
+# -------
+# First import some of the usual suspects.
+# In addition, import some classes from
+# the ASPIRE package that will be used throughout this experiment.
+
+import logging
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+from aspire.abinitio import CLSyncVoting
+from aspire.basis import FFBBasis2D, FFBBasis3D
+from aspire.classification import BFSReddyChatterjiAverager2D, RIRClass2D
+from aspire.denoising import DenoiserCov2D
+from aspire.noise import AnisotropicNoiseEstimator
+from aspire.operators import FunctionFilter, RadialCTFFilter
+from aspire.reconstruction import MeanEstimator
+from aspire.source import Simulation
+from aspire.utils.coor_trans import (
+    get_aligned_rotations,
+    get_rots_mse,
+    register_rotations,
+)
+from aspire.volume import Volume
+
+logger = logging.getLogger(__name__)
+
+
+# %%
+# Parameters
+# ---------------
+# Some example simulation configurations.
+# Small sim: img_size 32, num_imgs 10000, n_classes 1000, n_nbor 10
+# Medium sim: img_size 64, num_imgs 20000, n_classes 2000, n_nbor 10
+# Large sim: img_size 129, num_imgs 30000, n_classes 2000, n_nbor 20
+
+interactive = True  # Draw blocking interactive plots?
+do_cov2d = False  # Use CWF coefficients
+img_size = 32  # Downsample the volume to a desired resolution
+num_imgs = 10000  # How many images in our source.
+n_classes = 1000  # How many class averages to compute.
+n_nbor = 10  # How many neighbors to stack
+noise_variance = 5e-7  # Set a target noise variance
+
+
+# %%
+# Simulation Data
+# ---------------
+# Start with a fairly high resolution volume available from EMPIAR/EMDB.
+# https://www.ebi.ac.uk/emdb/EMD-2660
+# https://ftp.ebi.ac.uk/pub/databases/emdb/structures/EMD-2660/map/emd_2660.map.gz
+og_v = Volume.load("emd_2660.map", dtype=np.float64)
+logger.info(f"Original volume map data shape: {og_v.shape} dtype: {og_v.dtype}")
+
+logger.info(f"Downsampling to {(img_size,)*3}")
+v = og_v.downsample(img_size)
+L = v.resolution
+
+
+# Then create a filter based on that variance.
+# This is an example of a custom noise profile.
+def noise_function(x, y):
+    alpha = 1
+    beta = 1
+    # White
+    f1 = noise_variance
+    # Violet-ish
+    f2 = noise_variance * (x * x + y * y) / (L * L)
+    return (alpha * f1 + beta * f2) / 2.0
+
+
+custom_noise_filter = FunctionFilter(noise_function)
+
+logger.info("Initialize CTF filters.")
+# Create some CTF effects
+pixel_size = 5 * 65 / img_size  # Pixel size of the images (in angstroms)
+voltage = 200  # Voltage (in kV)
+defocus_min = 1.5e4  # Minimum defocus value (in angstroms)
+defocus_max = 2.5e4  # Maximum defocus value (in angstroms)
+defocus_ct = 7  # Number of defocus groups.
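+# The defocus groups below are spread evenly over [defocus_min, defocus_max];
+# np.linspace(1.5e4, 2.5e4, 7) yields 15000, 16667, 18333, 20000, 21667,
+# 23333, 25000 angstroms, so each simulated image is assigned one of 7 CTFs.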
+Cs = 2.0 # Spherical aberration +alpha = 0.1 # Amplitude contrast + +# Create filters +ctf_filters = [ + RadialCTFFilter(pixel_size, voltage, defocus=d, Cs=2.0, alpha=0.1) + for d in np.linspace(defocus_min, defocus_max, defocus_ct) +] + +# Finally create the Simulation +src = Simulation( + L=v.resolution, + n=num_imgs, + vols=v, + noise_filter=custom_noise_filter, + unique_filters=ctf_filters, +) +# Peek +if interactive: + src.images(0, 10).show() + +# Use phase_flip to attempt correcting for CTF. +logger.info("Perform phase flip to input images.") +src.phase_flip() + +# Estimate the noise and `Whiten` based on the estimated noise +aiso_noise_estimator = AnisotropicNoiseEstimator(src) +src.whiten(aiso_noise_estimator.filter) + +# Plot the noise profile for inspection +if interactive: + plt.imshow(aiso_noise_estimator.filter.evaluate_grid(L)) + plt.show() + +# Peek, what do the whitened images look like... +if interactive: + src.images(0, 10).show() + +# Cache to memory for some speedup +src.cache() + +# %% +# Optional: CWF Denoising +# ----------------------- +# +# Optionally generate an alternative source that is denoised with `cov2d`, +# then configure a customized aligner. This allows the use of CWF denoised +# images for classification, but stacks the original images for averages +# used in the remainder of the reconstruction pipeline. +# +# In this example, this behavior is controlled by the `do_cov2d` boolean variable. +# When disabled, the original src and default aligner is used. +# If you will not be using cov2d, +# you may remove this code block and associated variables. + +classification_src = src +custom_averager = None +if do_cov2d: + # Use CWF denoising + cwf_denoiser = DenoiserCov2D(src) + # Use denoised src for classification + classification_src = cwf_denoiser.denoise() + # Peek, what do the denoised images look like... + if interactive: + classification_src.images(0, 10).show() + + # Use regular `src` for the alignment and composition (averaging). + composite_basis = FFBBasis2D((src.L,) * 2, dtype=src.dtype) + custom_averager = BFSReddyChatterjiAverager2D(composite_basis, src, dtype=src.dtype) + + +# %% +# Class Averaging +# ---------------------- +# +# Now perform classification and averaging for each class. + +logger.info("Begin Class Averaging") + +rir = RIRClass2D( + classification_src, # Source used for classification + fspca_components=400, + bispectrum_components=300, # Compressed Features after last PCA stage. + n_nbor=n_nbor, + n_classes=n_classes, + large_pca_implementation="legacy", + nn_implementation="sklearn", + bispectrum_implementation="legacy", + averager=custom_averager, +) + +classes, reflections, distances = rir.classify() +avgs = rir.averages(classes, reflections, distances) +if interactive: + avgs.images(0, 10).show() + +# %% +# Common Line Estimation +# ---------------------- +# +# Next create a CL instance for estimating orientation of projections +# using the Common Line with Synchronization Voting method. + +logger.info("Begin Orientation Estimation") + +# Stash true rotations for later comparison, +# note this line only works with naive class selection... 
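+# (With the naive `select_classes` in RIRClass2D, the first n_classes classes
+# are kept and class k's base image is source image k, so src.rots[:n_classes]
+# lines up with the class averages; a smarter class selection would need to
+# index src.rots by each class's base image instead.)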
+true_rotations = src.rots[:n_classes] + +orient_est = CLSyncVoting(avgs, n_theta=36) +# Get the estimated rotations +orient_est.estimate_rotations() +rots_est = orient_est.rotations + +logger.info("Compare with known rotations") +# Compare with known true rotations +Q_mat, flag = register_rotations(rots_est, true_rotations) +regrot = get_aligned_rotations(rots_est, Q_mat, flag) +mse_reg = get_rots_mse(regrot, true_rotations) +logger.info( + f"MSE deviation of the estimated rotations using register_rotations : {mse_reg}\n" +) + +# %% +# Volume Reconstruction +# ---------------------- +# +# Using the estimated rotations, attempt to reconstruct a volume. + +logger.info("Begin Volume reconstruction") + +# Assign the estimated rotations to the class averages +avgs.rots = rots_est + +# Create a reasonable Basis for the 3d Volume +basis = FFBBasis3D((v.resolution,) * 3, dtype=v.dtype) + +# Setup an estimator to perform the back projection. +estimator = MeanEstimator(avgs, basis) + +# Perform the estimation and save the volume. +estimated_volume = estimator.estimate() +fn = f"estimated_volume_n{num_imgs}_c{n_classes}_m{n_nbor}_{img_size}.mrc" +estimated_volume.save(fn, overwrite=True) + +# Peek at result +if interactive: + plt.imshow(np.sum(estimated_volume[0], axis=-1)) + plt.show() diff --git a/gallery/tutorials/class_averaging.py b/gallery/tutorials/class_averaging.py index dcf1272b60..6d449a24d2 100644 --- a/gallery/tutorials/class_averaging.py +++ b/gallery/tutorials/class_averaging.py @@ -116,13 +116,13 @@ bispectrum_implementation="legacy", ) -classes, reflections, rotations, shifts, corr = rir.classify() +classes, reflections, dists = rir.classify() +avgs = rir.averages(classes, reflections, dists) # %% # Display Classes # ^^^^^^^^^^^^^^^ -avgs = rir.output(classes, reflections, rotations) avgs.images(0, 10).show() # %% @@ -169,13 +169,13 @@ bispectrum_implementation="legacy", ) -classes, reflections, rotations, shifts, corr = noisy_rir.classify() +classes, reflections, dists = noisy_rir.classify() +avgs = noisy_rir.averages(classes, reflections, dists) # %% # Display Classes # ^^^^^^^^^^^^^^^ -avgs = noisy_rir.output(classes, reflections, rotations) avgs.images(0, 10).show() @@ -196,17 +196,5 @@ # Report the identified neighbors Image(noisy_src.images(0, np.inf)[classes[review_class]]).show() -# Report their associated rots_refls -rots_refls = ["index, Rotation, Reflection"] -for i in range(classes.shape[1]): - rots_refls.append( - f"{i}, {rotations[review_class, i] * 180 / np.pi}, {reflections[review_class, i]}" - ) -rots_refls = "\n".join(rots_refls) - -logger.info( - f"Class {review_class}'s estimated Rotations and Reflections:\n{rots_refls}" -) - # Display the averaged result avgs.images(review_class, 1).show() diff --git a/setup.py b/setup.py index e2f1f25e75..a3507366d5 100644 --- a/setup.py +++ b/setup.py @@ -36,6 +36,7 @@ def read(fname): "pillow", "scipy==1.7.3", "scikit-learn", + "scikit-image", "setuptools>=0.41", "tqdm", ], diff --git a/src/aspire/basis/basis.py b/src/aspire/basis/basis.py index ddafca620e..0b2fe01f5d 100644 --- a/src/aspire/basis/basis.py +++ b/src/aspire/basis/basis.py @@ -4,8 +4,10 @@ from scipy.sparse.linalg import LinearOperator, cg from aspire.basis.basis_utils import num_besselj_zeros +from aspire.image import Image from aspire.utils import ensure, mdim_mat_fun_conj from aspire.utils.matlab_compat import m_reshape +from aspire.volume import Volume logger = logging.getLogger(__name__) @@ -174,6 +176,10 @@ def expand(self, x): those first dimensions of 
`x`. """ + + if isinstance(x, Image) or isinstance(x, Volume): + x = x.asnumpy() + # ensure the first dimensions with size of self.sz sz_roll = x.shape[: -self.ndim] @@ -206,5 +212,5 @@ def expand(self, x): raise RuntimeError("Unable to converge!") # return v coefficients with the last dimension of self.count - v = v.reshape((-1, *sz_roll)) + v = v.reshape((*sz_roll, -1)) return v diff --git a/src/aspire/classification/__init__.py b/src/aspire/classification/__init__.py index f11ad1d8e0..b1aea6ea5a 100644 --- a/src/aspire/classification/__init__.py +++ b/src/aspire/classification/__init__.py @@ -1,3 +1,12 @@ -from .align2d import Align2D, BFRAlign2D, BFSRAlign2D, EMAlign2D, FTKAlign2D +from .averager2d import ( + AligningAverager2D, + Averager2D, + BFRAverager2D, + BFSRAverager2D, + BFSReddyChatterjiAverager2D, + EMAverager2D, + FTKAverager2D, + ReddyChatterjiAverager2D, +) from .class2d import Class2D from .rir_class2d import RIRClass2D diff --git a/src/aspire/classification/align2d.py b/src/aspire/classification/align2d.py deleted file mode 100644 index bfada60c87..0000000000 --- a/src/aspire/classification/align2d.py +++ /dev/null @@ -1,259 +0,0 @@ -import logging -from itertools import product - -import numpy as np -from tqdm import trange - -logger = logging.getLogger(__name__) - - -class Align2D: - """ - Base class for 2D Image Alignment methods. - """ - - def __init__(self, basis, dtype): - """ - :param basis: Basis to be used for any methods during alignment. - :param dtype: Numpy dtype to be used during alignment. - """ - - self.basis = basis - if dtype is None: - self.dtype = self.basis.dtype - else: - self.dtype = np.dtype(dtype) - if self.dtype != self.basis.dtype: - logger.warning( - f"Align2D basis.dtype {self.basis.dtype} does not match self.dtype {self.dtype}." - ) - - def align(self, classes, reflections, basis_coefficients): - """ - Any align2D alignment method should take in the following arguments - and return the described tuple. - - Generally, the returned `classes` and `reflections` should be same as - the input. They are passed through for convience, - considering they would all be required for image output. - - Returned `rotations` is an (n_classes, n_nbor) array of angles, - which should represent the rotations needed to align images within - that class. `rotations` is measured in Radians. - - Returned `correlations` is an (n_classes, n_nbor) array representing - a correlation like measure between classified images and their base - image (image index 0). - - Returned `shifts` is None or an (n_classes, n_nbor) array of 2D shifts - which should represent the translation needed to best align the images - within that class. - - Subclasses of `align` should extend this method with optional arguments. - - :param classes: (n_classes, n_nbor) integer array of img indices - :param refl: (n_classes, n_nbor) bool array of corresponding reflections - :param coef: (n_img, self.pca_basis.count) compressed basis coefficients - - :returns: (classes, reflections, rotations, shifts, correlations) - """ - raise NotImplementedError("Subclasses must implement align.") - - -class BFRAlign2D(Align2D): - """ - This perfoms a Brute Force Rotational alignment. - - For each class, - constructs n_angles rotations of all class members, - and then identifies angle yielding largest correlation(dot). - """ - - def __init__(self, basis, n_angles=359, dtype=None): - """ - :params basis: Basis providing a `rotate` method. - :params n_angles: Number of brute force rotations to attempt, defaults 359. 
- """ - super().__init__(basis, dtype) - - self.n_angles = n_angles - - if not hasattr(self.basis, "rotate"): - raise RuntimeError( - f"BFRAlign2D's basis {self.basis} must provide a `rotate` method." - ) - - def align(self, classes, reflections, basis_coefficients): - """ - See `Align2D.align` - """ - # Admit simple case of single case alignment - classes = np.atleast_2d(classes) - reflections = np.atleast_2d(reflections) - - n_classes, n_nbor = classes.shape - - # Construct array of angles to brute force. - test_angles = np.linspace(0, 2 * np.pi, self.n_angles, endpoint=False) - - # Instantiate matrices for results - rotations = np.empty(classes.shape, dtype=self.dtype) - correlations = np.empty(classes.shape, dtype=self.dtype) - results = np.empty((n_nbor, self.n_angles)) - - for k in trange(n_classes): - - # Get the coefs for these neighbors - nbr_coef = basis_coefficients[classes[k]] - - for i, angle in enumerate(test_angles): - # Rotate the set of neighbors by angle, - rotated_nbrs = self.basis.rotate(nbr_coef, angle, reflections[k]) - - # then store dot between class base image (0) and each nbor - for j, nbor in enumerate(rotated_nbrs): - results[j, i] = np.dot(nbr_coef[0], nbor) - - # Now along each class, find the index of the angle reporting highest correlation - angle_idx = np.argmax(results, axis=1) - - # Store that angle as our rotation for this image - rotations[k, :] = test_angles[angle_idx] - - # Also store the correlations for each neighbor - for j in range(n_nbor): - correlations[k, j] = results[j, angle_idx[j]] - - # None is placeholder for shifts - return classes, reflections, rotations, None, correlations - - -class BFSRAlign2D(BFRAlign2D): - """ - This perfoms a Brute Force Shift and Rotational alignment. - It is potentially expensive to brute force this search space. - - For each pair of x_shifts and y_shifts, - Perform BFR - - Return the rotation and shift yielding the best results. - """ - - def __init__(self, basis, n_angles=359, n_x_shifts=1, n_y_shifts=1, dtype=None): - """ - Note that n_x_shifts and n_y_shifts are the number of shifts to perform - in each direction. - - Example: n_x_shifts=1, n_y_shifts=0 would test {-1,0,1} X {0}. - - n_x_shifts=n_y_shifts=0 is the same as calling BFRAlign2D. - - :params basis: Basis providing a `shift` and `rotate` method. - :params n_angles: Number of brute force rotations to attempt, defaults 359. - :params n_x_shifts: +- Number of brute force xshifts to attempt, defaults 1. - :params n_y_shifts: +- Number of brute force xshifts to attempt, defaults 1. - """ - super().__init__(basis, n_angles, dtype) - - self.n_x_shifts = n_x_shifts - self.n_y_shifts = n_y_shifts - - if not hasattr(self.basis, "shift"): - raise RuntimeError( - f"BFSRAlign2D's basis {self.basis} must provide a `shift` method." - ) - - # Each shift will require calling the parent BFRAlign2D.align - self._bfr_align = super().align - - def align(self, classes, reflections, basis_coefficients): - """ - See `Align2D.align` - """ - - # Admit simple case of single case alignment - classes = np.atleast_2d(classes) - reflections = np.atleast_2d(reflections) - - n_classes = classes.shape[0] - - # Compute the shifts. Roll array so 0 is first. - x_shifts = np.roll( - np.arange(-self.n_x_shifts, self.n_x_shifts + 1), -self.n_x_shifts - ) - y_shifts = np.roll( - np.arange(-self.n_y_shifts, self.n_y_shifts + 1), -self.n_y_shifts - ) - # Above rolls should force initial pair of shifts to (0,0). - # This is done primarily in case of a tie later we would take unshifted. 
- assert (x_shifts[0], y_shifts[0]) == (0, 0) - - # These arrays will incrementally store our best alignment. - rotations = np.empty(classes.shape, dtype=self.dtype) - correlations = np.ones(classes.shape, dtype=self.dtype) * -np.inf - shifts = np.empty((*classes.shape, 2), dtype=int) - - # We want to maintain the original coefs for the base images, - # because we will mutate them with shifts in the loop. - original_coef = basis_coefficients[classes[:, 0], :] - assert original_coef.shape == (n_classes, self.basis.count) - - # Loop over shift search space, updating best result - for x, y in product(x_shifts, y_shifts): - shift = np.array([x, y], dtype=int) - logger.info(f"Computing Rotational alignment after shift ({x},{y}).") - - # Shift the coef representing the first (base) entry in each class - # by the negation of the shift - # Shifting one image is more efficient than shifting every neighbor - basis_coefficients[classes[:, 0], :] = self.basis.shift( - original_coef, -shift - ) - - _, _, _rotations, _, _correlations = self._bfr_align( - classes, reflections, basis_coefficients - ) - - # Each class-neighbor pair may have a best shift-rot from a different shift. - # Test and update - improved_indices = _correlations > correlations - rotations[improved_indices] = _rotations[improved_indices] - correlations[improved_indices] = _correlations[improved_indices] - shifts[improved_indices] = shift - - # Restore unshifted base coefs - basis_coefficients[classes[:, 0], :] = original_coef - - if (x, y) == (0, 0): - logger.info("Initial rotational alignment complete (shift (0,0))") - assert np.sum(improved_indices) == np.size( - classes - ), f"{np.sum(improved_indices)} =?= {np.size(classes)}" - else: - logger.info( - f"Shift ({x},{y}) complete. Improved {np.sum(improved_indices)} alignments." - ) - - return classes, reflections, rotations, shifts, correlations - - -class EMAlign2D(Align2D): - """ - Citation needed. - """ - - def __init__(self, basis, dtype=None): - super().__init__(basis, dtype) - - -class FTKAlign2D(Align2D): - """ - Factorization of the translation kernel for fast rigid image alignment. - Rangan, A.V., Spivak, M., Anden, J., & Barnett, A.H. (2019). - """ - - def __init__(self, basis, dtype=None): - super().__init__(basis, dtype) - - def align(self, classes, reflections, basis_coefficients): - raise NotImplementedError diff --git a/src/aspire/classification/averager2d.py b/src/aspire/classification/averager2d.py new file mode 100644 index 0000000000..5db97f1aa3 --- /dev/null +++ b/src/aspire/classification/averager2d.py @@ -0,0 +1,1028 @@ +import logging +from abc import ABC, abstractmethod +from itertools import product + +import matplotlib.pyplot as plt +import numpy as np +from skimage.filters import difference_of_gaussians, window +from skimage.transform import rotate, warp_polar +from tqdm import tqdm, trange + +from aspire.image import Image +from aspire.numeric import fft +from aspire.source import ArrayImageSource +from aspire.utils.coor_trans import grid_2d + +logger = logging.getLogger(__name__) + + +class Averager2D(ABC): + """ + Base class for 2D Image Averaging methods. + """ + + def __init__(self, composite_basis, src, dtype=None): + """ + :param composite_basis: Basis to be used during class average composition (eg FFB2D) + :param src: Source of original images. + :param dtype: Numpy dtype to be used during alignment. 
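+
+        A typical construction, sketched for a concrete subclass (assumes
+        `FFBBasis2D` imported from `aspire.basis` and a populated `src`):
+
+            composite_basis = FFBBasis2D((src.L,) * 2, dtype=src.dtype)
+            averager = BFRAverager2D(composite_basis, src)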
+ """ + + self.composite_basis = composite_basis + self.src = src + if dtype is None: + if self.composite_basis: + self.dtype = self.composite_basis.dtype + elif self.src: + self.dtype = self.src.dtype + else: + raise RuntimeError("You must supply a basis/src/dtype.") + else: + self.dtype = np.dtype(dtype) + + if self.src and self.dtype != self.src.dtype: + logger.warning( + f"{self.__class__.__name__} dtype {dtype}" + "does not match dtype of source {self.src.dtype}." + ) + if self.composite_basis and self.dtype != self.composite_basis.dtype: + logger.warning( + f"{self.__class__.__name__} dtype {dtype}" + "does not match dtype of basis {self.composite_basis.dtype}." + ) + + @abstractmethod + def average( + self, + classes, + reflections, + coefs=None, + ): + """ + Combines images using stacking in `self.composite_basis`. + + Subclasses should implement this. + (Example EM algos use radically different averaging). + + Should return an Image source of synthetic class averages. + + :param classes: class indices, refering to src. (n_classes, n_nbor). + :param reflections: Bool representing whether to reflect image in `classes`. + (n_clases, n_nbor) + :param coefs: Optional basis coefs (could avoid recomputing). + (n_classes, coef_count) + :return: Stack of synthetic class average images as Image instance. + """ + + def _cls_images(self, cls, src=None): + """ + Util to return images as an array for class k (provided as array `cls` ), + preserving the class/nbor order. + + :param cls: An iterable (0/1-D array or list) that holds the indices of images to align. + In class averaging, this would be a class. + :param src: Optionally override the src, for example, if you want to use a different + source for a certain operation (ie alignment). + """ + src = src or self.src + + n_nbor = cls.shape[-1] # Includes zero'th neighbor + + images = np.empty((n_nbor, src.L, src.L), dtype=self.dtype) + + for i, index in enumerate(cls): + images[i] = src.images(index, 1).asnumpy() + + return images + + +class AligningAverager2D(Averager2D): + """ + Subclass supporting averagers which perfom an aligning stage. + """ + + def __init__(self, composite_basis, src, alignment_basis=None, dtype=None): + """ + :param composite_basis: Basis to be used during class average composition (eg hi res Cartesian/FFB2D). + :param src: Source of original images. + :param alignment_basis: Optional, basis to be used only during alignment (eg FSPCA). + :param dtype: Numpy dtype to be used during alignment. + """ + + super().__init__( + composite_basis=composite_basis, + src=src, + dtype=dtype, + ) + # If alignment_basis is None, use composite_basis + self.alignment_basis = alignment_basis or self.composite_basis + + if not hasattr(self.composite_basis, "rotate"): + raise RuntimeError( + f"{self.__class__.__name__}'s composite_basis {self.composite_basis} must provide a `rotate` method." + ) + if not hasattr(self.composite_basis, "shift"): + raise RuntimeError( + f"{self.__class__.__name__}'s composite_basis {self.composite_basis} must provide a `shift` method." + ) + + @abstractmethod + def align(self, classes, reflections, basis_coefficients): + """ + During this process `rotations`, `reflections`, `shifts` and + `correlations` properties will be computed for aligners. + + `rotations` is an (n_classes, n_nbor) array of angles, + which should represent the rotations needed to align images within + that class. `rotations` is measured in radians. 
+
+        `shifts` is None or an (n_classes, n_nbor) array of 2D shifts
+        which should represent the translation needed to best align the images
+        within that class.
+
+        `correlations` is an (n_classes, n_nbor) array representing
+        a correlation like measure between classified images and their base
+        image (image index 0).
+
+        Subclasses should implement and extend this method.
+
+        :param classes: (n_classes, n_nbor) integer array of img indices.
+        :param reflections: (n_classes, n_nbor) bool array of corresponding reflections.
+        :param basis_coefficients: (n_img, self.alignment_basis.count) basis coefficients.
+
+        :returns: (rotations, shifts, correlations)
+        """
+
+    def average(
+        self,
+        classes,
+        reflections,
+        coefs=None,
+    ):
+        """
+        This subclass assumes we get alignment details from the `align` method.
+        Otherwise, see Averager2D.average.
+        """
+
+        rotations, shifts, _ = self.align(classes, reflections, coefs)
+
+        n_classes, n_nbor = classes.shape
+
+        b_avgs = np.empty((n_classes, self.composite_basis.count), dtype=self.src.dtype)
+
+        for i in tqdm(range(n_classes)):
+
+            # Get coefs in composite_basis if not provided as an argument.
+            if coefs is None:
+                # Retrieve relevant images directly from source.
+                neighbors_imgs = Image(self._cls_images(classes[i]))
+
+                # Do shifts
+                if shifts is not None:
+                    neighbors_imgs = neighbors_imgs.shift(shifts[i])
+
+                neighbors_coefs = self.composite_basis.evaluate_t(neighbors_imgs)
+            else:
+                # Get the neighbors
+                neighbors_ids = classes[i]
+                neighbors_coefs = coefs[neighbors_ids]
+                if shifts is not None:
+                    neighbors_coefs = self.composite_basis.shift(
+                        neighbors_coefs, shifts[i]
+                    )
+
+            # Rotate in composite_basis
+            neighbors_coefs = self.composite_basis.rotate(
+                neighbors_coefs, rotations[i], reflections[i]
+            )
+
+            # Averaging in composite_basis
+            b_avgs[i] = np.mean(neighbors_coefs, axis=0)
+
+        # Now we convert the averaged images from Basis to Cartesian.
+        return ArrayImageSource(self.composite_basis.evaluate(b_avgs))
+
+
+class BFRAverager2D(AligningAverager2D):
+    """
+    This performs a Brute Force Rotational alignment.
+
+    For each class,
+    constructs n_angles rotations of all class members,
+    and then identifies the angle yielding the largest correlation (dot).
+    """
+
+    def __init__(
+        self,
+        composite_basis,
+        src,
+        alignment_basis=None,
+        n_angles=360,
+        dtype=None,
+    ):
+        """
+        See AligningAverager2D, adds:
+
+        :params n_angles: Number of brute force rotations to attempt, defaults 360.
+        """
+        super().__init__(composite_basis, src, alignment_basis, dtype)
+
+        self.n_angles = n_angles
+
+        if not hasattr(self.alignment_basis, "rotate"):
+            raise RuntimeError(
+                f"{self.__class__.__name__}'s alignment_basis {self.alignment_basis} must provide a `rotate` method."
+            )
+
+    def align(self, classes, reflections, basis_coefficients):
+        """
+        Performs the actual rotational alignment estimation,
+        returning parameters needed for averaging.
+        """
+
+        # Admit simple case of single class alignment
+        classes = np.atleast_2d(classes)
+        reflections = np.atleast_2d(reflections)
+
+        n_classes, n_nbor = classes.shape
+
+        # Construct array of angles to brute force.
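+        # (n_angles=360 gives a one degree angular grid; finer rotational
+        # estimates cost linearly more evaluations per class.)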
+        test_angles = np.linspace(0, 2 * np.pi, self.n_angles, endpoint=False)
+
+        # Instantiate matrices for results
+        rotations = np.empty(classes.shape, dtype=self.dtype)
+        correlations = np.empty(classes.shape, dtype=self.dtype)
+        results = np.empty((n_nbor, self.n_angles))
+
+        for k in trange(n_classes):
+
+            # Get the coefs for these neighbors
+            if basis_coefficients is None:
+                # Retrieve relevant images
+                neighbors_imgs = Image(self._cls_images(classes[k]))
+                # Evaluate_t into basis
+                nbr_coef = self.composite_basis.evaluate_t(neighbors_imgs)
+            else:
+                nbr_coef = basis_coefficients[classes[k]]
+
+            for i, angle in enumerate(test_angles):
+                # Rotate the set of neighbors by angle,
+                rotated_nbrs = self.alignment_basis.rotate(
+                    nbr_coef, angle, reflections[k]
+                )
+
+                # then store dot between class base image (0) and each nbor
+                for j, nbor in enumerate(rotated_nbrs):
+                    results[j, i] = np.dot(nbr_coef[0], nbor)
+
+            # Now along each class, find the index of the angle reporting highest correlation
+            angle_idx = np.argmax(results, axis=1)
+
+            # Store that angle as our rotation for this image
+            rotations[k, :] = test_angles[angle_idx]
+
+            # Also store the correlations for each neighbor
+            for j in range(n_nbor):
+                correlations[k, j] = results[j, angle_idx[j]]
+
+        return rotations, None, correlations
+
+
+class BFSRAverager2D(BFRAverager2D):
+    """
+    This performs a Brute Force Shift and Rotational alignment.
+    It is potentially expensive to brute force this search space.
+
+    For each pair of x_shifts and y_shifts,
+    perform BFR.
+
+    Return the rotation and shift yielding the best results.
+    """
+
+    def __init__(
+        self,
+        composite_basis,
+        src,
+        alignment_basis=None,
+        n_angles=360,
+        n_x_shifts=1,
+        n_y_shifts=1,
+        dtype=None,
+    ):
+        """
+        See AligningAverager2D and BFRAverager2D, adds: `n_x_shifts`, `n_y_shifts`.
+
+        Note that `n_x_shifts` and `n_y_shifts` are the number of shifts
+        to perform in each direction.
+
+        Example: n_x_shifts=1, n_y_shifts=0 would test {-1,0,1} X {0}.
+
+        n_x_shifts=n_y_shifts=0 is the same as calling BFRAverager2D.
+
+        :params n_angles: Number of brute force rotations to attempt, defaults 360.
+        :params n_x_shifts: +- Number of brute force xshifts to attempt, defaults 1.
+        :params n_y_shifts: +- Number of brute force yshifts to attempt, defaults 1.
+        """
+        super().__init__(
+            composite_basis,
+            src,
+            alignment_basis,
+            n_angles,
+            dtype=dtype,
+        )
+
+        self.n_x_shifts = n_x_shifts
+        self.n_y_shifts = n_y_shifts
+
+        # Each shift will require calling the parent BFRAverager2D.align
+        self._bfr_align = super().align
+
+        if not hasattr(self.alignment_basis, "shift"):
+            raise RuntimeError(
+                f"{self.__class__.__name__}'s alignment_basis {self.alignment_basis} must provide a `shift` method."
+            )
+
+    def align(self, classes, reflections, basis_coefficients):
+        """
+        See `AligningAverager2D.align`
+        """
+
+        # Admit simple case of single class alignment
+        classes = np.atleast_2d(classes)
+        reflections = np.atleast_2d(reflections)
+
+        n_classes = classes.shape[0]
+
+        # Compute the shifts. Roll array so 0 is first.
+        x_shifts = np.roll(
+            np.arange(-self.n_x_shifts, self.n_x_shifts + 1), -self.n_x_shifts
+        )
+        y_shifts = np.roll(
+            np.arange(-self.n_y_shifts, self.n_y_shifts + 1), -self.n_y_shifts
+        )
+        # Above rolls should force initial pair of shifts to (0,0).
+        # This is done primarily in case of a tie later; we would take unshifted.
+        assert (x_shifts[0], y_shifts[0]) == (0, 0)
+
+        # These arrays will incrementally store our best alignment.
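+        # (Note the brute force search space per class pair is
+        # (2*n_x_shifts+1) * (2*n_y_shifts+1) * n_angles; the defaults
+        # (1, 1, 360) already evaluate 9 * 360 = 3240 candidate alignments.)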
+ rotations = np.empty(classes.shape, dtype=self.dtype) + correlations = np.ones(classes.shape, dtype=self.dtype) * -np.inf + shifts = np.empty((*classes.shape, 2), dtype=int) + + if basis_coefficients is None: + # Retrieve image coefficients, this is bad, it load all images. + # TODO: Refactor this s.t. the following code blocks and super().align + # only require coefficients relating to their class. See _cls_images. + basis_coefficients = self.composite_basis.evaluate_t( + self.src.images(0, np.inf) + ) + + # We want to maintain the original coefs for the base images, + # because we will mutate them with shifts in the loop. + original_coef = basis_coefficients[classes[:, 0], :] + assert original_coef.shape == (n_classes, self.alignment_basis.count) + + # Loop over shift search space, updating best result + for x, y in product(x_shifts, y_shifts): + shift = np.array([x, y], dtype=int) + logger.debug(f"Computing rotational alignment after shift ({x},{y}).") + + # Shift the coef representing the first (base) entry in each class + # by the negation of the shift + # Shifting one image is more efficient than shifting every neighbor + basis_coefficients[classes[:, 0], :] = self.alignment_basis.shift( + original_coef, -shift + ) + + _rotations, _, _correlations = self._bfr_align( + classes, reflections, basis_coefficients + ) + + # Each class-neighbor pair may have a best shift-rot from a different shift. + # Test and update + improved_indices = _correlations > correlations + rotations[improved_indices] = _rotations[improved_indices] + correlations[improved_indices] = _correlations[improved_indices] + shifts[improved_indices] = shift + + # Restore unshifted base coefs + basis_coefficients[classes[:, 0], :] = original_coef + + if (x, y) == (0, 0): + logger.debug("Initial rotational alignment complete (shift (0,0))") + assert np.sum(improved_indices) == np.size( + classes + ), f"{np.sum(improved_indices)} =?= {np.size(classes)}" + else: + logger.debug( + f"Shift ({x},{y}) complete. Improved {np.sum(improved_indices)} alignments." + ) + + return rotations, shifts, correlations + + +class ReddyChatterjiAverager2D(AligningAverager2D): + """ + Attempts rotational estimation using Reddy Chatterji log polar Fourier cross correlation. + Then attempts shift (translational) estimation using cross correlation. + + When averaging, performs rotations then shifts. + + Note, it may be possible to iterate this algorithm... + + Adopted from Reddy Chatterji (1996) + An FFT-Based Technique for Translation, + Rotation, and Scale-Invariant Image Registration + IEEE TRANSACTIONS ON IMAGE PROCESSING, VOL. 5, NO. 8, AUGUST 1996 + + This method intentionally does not use any of ASPIRE's basis + so that it may be used as a reference for more ASPIRE approaches. + """ + + def __init__( + self, + composite_basis, + src, + alignment_src=None, + diagnostics=False, + dtype=None, + ): + """ + :param composite_basis: Basis to be used during class average composition. + :param src: Source of original images. + :param alignment_src: Optional, source to be used during class average alignment. + Must be the same resolution as `src`. + :param dtype: Numpy dtype to be used during alignment. + """ + + self.__cache = dict() + self.diagnostics = diagnostics + self.do_cross_corr_translations = True + self.alignment_src = alignment_src or src + + # TODO, for accomodating different resolutions we minimally need to adapt shifting. + # Outside of scope right now, but would make a nice PR later. 
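+        # (Shifts are estimated below in whole pixels of `alignment_src`, so
+        # a resolution mismatch with `src` would leave those shifts in the
+        # wrong units when applied during averaging; hence the guards here.)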
+        if self.alignment_src.L != src.L:
+            raise RuntimeError("Currently `alignment_src.L` must equal `src.L`")
+        if self.alignment_src.dtype != src.dtype:
+            raise RuntimeError("Currently `alignment_src.dtype` must equal `src.dtype`")
+
+        super().__init__(composite_basis, src, composite_basis, dtype=dtype)
+
+    def _phase_cross_correlation(self, img0, img1):
+        """
+        Adapted from skimage.registration.phase_cross_correlation.
+
+        :param img0: Fixed image.
+        :param img1: Translated image.
+        :returns: (cross-correlation magnitudes (2D array), shifts)
+        """
+
+        # Cache img0 transform, this saves n_classes*(n_nbor-1) transforms
+        # Note we use the `id` because ndarray are unhashable
+        key = id(img0)
+        if key not in self.__cache:
+            self.__cache[key] = fft.fft2(img0)
+        src_f = self.__cache[key]
+
+        target_f = fft.fft2(img1)
+
+        # Whole-pixel shifts - Compute cross-correlation by an IFFT
+        shape = src_f.shape
+        image_product = src_f * target_f.conj()
+        cross_correlation = fft.ifft2(image_product)
+
+        # Locate maximum
+        maxima = np.unravel_index(
+            np.argmax(np.abs(cross_correlation)), cross_correlation.shape
+        )
+        midpoints = np.array([np.fix(axis_size / 2) for axis_size in shape])
+
+        shifts = np.array(maxima, dtype=np.float64)
+        shifts[shifts > midpoints] -= np.array(shape)[shifts > midpoints]
+
+        return np.abs(cross_correlation), shifts
+
+    def align(self, classes, reflections, basis_coefficients):
+        """
+        Performs the actual rotational alignment estimation,
+        returning parameters needed for averaging.
+        """
+
+        # Admit simple case of single class alignment
+        classes = np.atleast_2d(classes)
+        reflections = np.atleast_2d(reflections)
+
+        n_classes = classes.shape[0]
+
+        # Instantiate matrices for results
+        rotations = np.zeros(classes.shape, dtype=self.dtype)
+        correlations = np.zeros(classes.shape, dtype=self.dtype)
+        shifts = np.zeros((*classes.shape, 2), dtype=int)
+
+        for k in trange(n_classes):
+            # Get the array of images for this class, using the `alignment_src`.
+            images = self._cls_images(classes[k], src=self.alignment_src)
+
+            rotations[k], shifts[k], correlations[k] = self._reddychatterji(
+                images, classes[k], reflections[k]
+            )
+
+        return rotations, shifts, correlations
+
+    def _reddychatterji(self, images, class_k, reflection_k):
+        """
+        Compute the Reddy Chatterji method registering images[1:] to images[0].
+
+        This differs from papers and published scikit implementations by
+        computing the fixed base image[0] pipeline once, then reusing it.
+
+        This is a util function to help loop over `classes`.
+
+        :param images: Image data (m_img, L, L)
+        :param class_k: Image indices (m_img,)
+        :param reflection_k: Image reflections (m_img,)
+        :returns: (rotations_k, shifts_k, correlations_k) corresponding to `images`
+        """
+
+        # Result arrays
+        M = len(images)
+        rotations_k = np.zeros(M, dtype=self.dtype)
+        correlations_k = np.zeros(M, dtype=self.dtype)
+        shifts_k = np.zeros((M, 2), dtype=int)
+
+        # De-mean; note `images` is mutated and should be a `copy`.
+        images -= images.mean(axis=(-1, -2))[:, np.newaxis, np.newaxis]
+
+        # Precompute fixed_img data used repeatedly in the loop below.
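+        # The fixed (base) image pipeline, computed once per class:
+        #   band-pass (difference of Gaussians) -> Hann window -> |FFT|^2
+        #   -> log-polar warp, after which an in-plane rotation of the input
+        #   appears as a translation along the angular axis and can be found
+        #   with the phase cross correlation above.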
+ fixed_img = images[0] + # Difference of Gaussians (Band Filter) + fixed_img_dog = difference_of_gaussians(fixed_img, 1, 4) + # Window Images (Fix spectral boundary) + wfixed_img = fixed_img_dog * window("hann", fixed_img.shape) + # Transform image to Fourier space + fixed_img_fs = np.abs(fft.fftshift(fft.fft2(wfixed_img))) ** 2 + # Compute Log Polar Transform + radius = fixed_img_fs.shape[0] // 8 # Low Pass + warped_fixed_img_fs = warp_polar( + fixed_img_fs, + radius=radius, + output_shape=fixed_img_fs.shape, + scaling="log", + ) + # Only use half of FFT, because it's symmetrical + warped_fixed_img_fs = warped_fixed_img_fs[: fixed_img_fs.shape[0] // 2, :] + + # Now prepare for rotating original images, + # and searching for translations. + # We start back at the raw fixed_img. + twfixed_img = fixed_img * window("hann", fixed_img.shape) + + # Register image `m` against image[0] + for m in range(1, len(images)): + # Get the image to register + regis_img = images[m] + + # Reflect images when necessary + if reflection_k[m]: + regis_img = np.flipud(regis_img) + + # Difference of Gaussians (Band Filter) + regis_img_dog = difference_of_gaussians(regis_img, 1, 4) + + # Window Images (Fix spectral boundary) + wregis_img = regis_img_dog * window("hann", regis_img.shape) + + self._input_images_diagnostic( + class_k[0], wfixed_img, class_k[m], wregis_img + ) + + # Transform image to Fourier space + regis_img_fs = np.abs(fft.fftshift(fft.fft2(wregis_img))) ** 2 + + self._windowed_psd_diagnostic( + class_k[0], fixed_img_fs, class_k[m], regis_img_fs + ) + + # Compute Log Polar Transform + warped_regis_img_fs = warp_polar( + regis_img_fs, + radius=radius, # Low Pass + output_shape=fixed_img_fs.shape, + scaling="log", + ) + + self._log_polar_diagnostic( + class_k[0], warped_fixed_img_fs, class_k[m], warped_regis_img_fs + ) + + # Only use half of FFT, because it's symmetrical + warped_regis_img_fs = warped_regis_img_fs[: fixed_img_fs.shape[0] // 2, :] + + # Compute the Cross_Correlation to estimate rotation + # Note that _phase_cross_correlation uses the mangnitudes (abs()), + # ie it is using both freq and phase information. + cross_correlation, _ = self._phase_cross_correlation( + warped_fixed_img_fs, warped_regis_img_fs + ) + + # Rotating Cartesian space translates the angular log polar component. + # Scaling Cartesian space translates the radial log polar component. + # In common image resgistration problems, both components are used + # to simultaneously estimate scaling and rotation. + # Since we are not currently concerned with scaling transformation, + # disregard the second axis of the `cross_correlation` returned by + # `_phase_cross_correlation`. + cross_correlation_score = cross_correlation[:, 0].ravel() + + self._rotation_cross_corr_diagnostic( + cross_correlation, cross_correlation_score + ) + + # Recover the angle from index representing maximal cross_correlation + recovered_angle_degrees = (360 / regis_img_fs.shape[0]) * np.argmax( + cross_correlation_score + ) + + if recovered_angle_degrees > 90: + r = 180 - recovered_angle_degrees + else: + r = -recovered_angle_degrees + + # For now, try the hack below, attempting two cases ... + # Some papers mention running entire algos /twice/, + # when admitting reflections, so this hack is not + # the worst you could do :). 
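+            # (The registration above works on power spectra, which for real
+            # images are centrosymmetric, so rotations r and r+180 degrees
+            # are indistinguishable at that stage; correlate both candidates
+            # in real space and keep the better one.)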
+ # Hack + regis_img_estimated = rotate(regis_img, r) + regis_img_rotated_p180 = rotate(regis_img, r + 180) + da = np.dot(fixed_img.flatten(), regis_img_estimated.flatten()) + db = np.dot(fixed_img.flatten(), regis_img_rotated_p180.flatten()) + if db > da: + regis_img_estimated = regis_img_rotated_p180 + r += 180 + + self._rotated_diagnostic( + class_k[0], + fixed_img, + class_k[m], + regis_img_estimated, + reflection_k[m], + r, + ) + + # Assign estimated rotations results + rotations_k[m] = -r * np.pi / 180 # Reverse rot and convert to radians + + if self.do_cross_corr_translations: + # Prepare for searching over translations using cross-correlation with the rotated image. + twregis_img = regis_img_estimated * window("hann", regis_img.shape) + cross_correlation, shift = self._phase_cross_correlation( + twfixed_img, twregis_img + ) + + self._translation_cross_corr_diagnostic(cross_correlation) + + # Compute the shifts as integer number of pixels, + shift_x, shift_y = int(shift[1]), int(shift[0]) + # then apply the shifts + regis_img_estimated = np.roll(regis_img_estimated, shift_y, axis=0) + regis_img_estimated = np.roll(regis_img_estimated, shift_x, axis=1) + # Assign estimated shift to results + shifts_k[m] = shift[::-1].astype(int) + + self._averaged_diagnostic( + class_k[0], + fixed_img, + class_k[m], + regis_img_estimated, + reflection_k[m], + r, + ) + else: + shift = None # For logger line + + # Estimated `corr` metric + corr = np.dot(fixed_img.flatten(), regis_img_estimated.flatten()) + correlations_k[m] = corr + + logger.debug( + f"ref {class_k[0]}, Neighbor {m} Index {class_k[m]}" + f" Estimates: {r}*, Shift: {shift}," + f" Corr: {corr}, Refl?: {reflection_k[m]}" + ) + + # Cleanup some cached stuff for this class + self.__cache.pop(id(warped_fixed_img_fs), None) + self.__cache.pop(id(twfixed_img), None) + + return rotations_k, shifts_k, correlations_k + + def average( + self, + classes, + reflections, + coefs=None, + ): + """ + This averages classes performing rotations then shifts. + Otherwise is similar to `AligningAverager2D.average`. + """ + + rotations, shifts, _ = self.align(classes, reflections, coefs) + + n_classes, n_nbor = classes.shape + + b_avgs = np.empty((n_classes, self.composite_basis.count), dtype=self.src.dtype) + + for i in tqdm(range(n_classes)): + + # Get coefs in Composite_Basis if not provided as an argument. + if coefs is None: + # Retrieve relavent images directly from source. + neighbors_imgs = Image(self._cls_images(classes[i])) + neighbors_coefs = self.composite_basis.evaluate_t(neighbors_imgs) + else: + # Get the neighbors + neighbors_ids = classes[i] + neighbors_coefs = coefs[neighbors_ids] + + # Rotate in composite_basis + neighbors_coefs = self.composite_basis.rotate( + neighbors_coefs, rotations[i], reflections[i] + ) + + # Note shifts are after rotation for this approach! + if shifts is not None: + neighbors_coefs = self.composite_basis.shift(neighbors_coefs, shifts[i]) + + # Averaging in composite_basis + b_avgs[i] = np.mean(neighbors_coefs, axis=0) + + # Now we convert the averaged images from Basis to Cartesian. 
+ return ArrayImageSource(self.composite_basis.evaluate(b_avgs)) + + def _input_images_diagnostic(self, ia, a, ib, b): + if not self.diagnostics: + return + fig, axes = plt.subplots(1, 2) + ax = axes.ravel() + ax[0].set_title(f"Image {ia}") + ax[0].imshow(a) + ax[1].set_title(f"Image {ib}") + ax[1].imshow(b) + plt.show() + + def _windowed_psd_diagnostic(self, ia, a, ib, b): + if not self.diagnostics: + return + fig, axes = plt.subplots(1, 2) + ax = axes.ravel() + ax[0].set_title(f"Image {ia} PSD") + ax[0].imshow(np.log(a)) + ax[1].set_title(f"Image {ib} PSD") + ax[1].imshow(np.log(b)) + plt.show() + + def _log_polar_diagnostic(self, ia, a, ib, b): + if not self.diagnostics: + return + labels = np.arange(0, 360, 60) + y = labels / (360 / a.shape[0]) + + fig, axes = plt.subplots(1, 2) + ax = axes.ravel() + ax[0].set_title(f"Image {ia}") + ax[0].imshow(a) + ax[0].set_yticks(y, minor=False) + ax[0].set_yticklabels(labels) + ax[0].set_ylabel("Theta (Degrees)") + + ax[1].set_title(f"Image {ib}") + ax[1].imshow(b) + ax[1].set_yticks(y, minor=False) + ax[1].set_yticklabels(labels) + plt.show() + + def _rotation_cross_corr_diagnostic( + self, cross_correlation, cross_correlation_score + ): + if not self.diagnostics: + return + labels = [0, 30, 60, 90, -60, -30] + x = y = np.arange(0, 180, 30) / (180 / cross_correlation.shape[0]) + plt.title("Rotation Cross Correlation Map") + plt.imshow(cross_correlation) + plt.xlabel("Scale") + plt.yticks(y, labels, rotation="vertical") + plt.ylabel("Theta (Degrees)") + plt.show() + + plt.plot(cross_correlation_score) + plt.title("Angle vs Cross Correlation Score") + plt.xticks(x, labels) + plt.xlabel("Theta (Degrees)") + plt.ylabel("Cross Correlation Score") + plt.grid() + plt.show() + + def _rotated_diagnostic(self, ia, a, ib, b, sb, rb): + """ + Plot the image after estimated rotation and reflection. + + :param ia: index image `a` + :param a: image `a` + :param ib: index image `b` + :param b: image `b` after reflection `sb` and rotion `rb` + :param sb: Reflection, Boolean + :param rb: Estimated rotation, degrees + """ + + if not self.diagnostics: + return + + fig, axes = plt.subplots(1, 2) + ax = axes.ravel() + ax[0].set_title(f"Image {ia}") + ax[0].imshow(a) + ax[0].grid() + ax[1].set_title(f"Image {ib} Refl: {str(sb)[0]} Rotated {rb:.1f}") + ax[1].imshow(b) + ax[1].grid() + plt.show() + + def _translation_cross_corr_diagnostic(self, cross_correlation): + if not self.diagnostics: + return + plt.title("Translation Cross Correlation Map") + plt.imshow(cross_correlation) + plt.xlabel("x shift (pixels)") + plt.ylabel("y shift (pixels)") + L = self.alignment_src.L + labels = [0, 10, 20, 30, 0, -10, -20, -30] + tick_location = [0, 10, 20, 30, L, L - 10, L - 20, L - 30] + plt.xticks(tick_location, labels) + plt.yticks(tick_location, labels) + plt.show() + + def _averaged_diagnostic(self, ia, a, ib, b, sb, rb): + """ + Plot the stacked average image after + estimated rotation and reflections. + + Compare in a three way plot. 
+ + :param ia: index image `a` + :param a: image `a` + :param ib: index image `b` + :param b: image `b` after reflection `sb` and rotion `rb` + :param sb: Reflection, Boolean + :param rb: Estimated rotation, degrees + """ + if not self.diagnostics: + return + fig, axes = plt.subplots(1, 3) + ax = axes.ravel() + ax[0].set_title(f"{ia}") + ax[0].imshow(a) + ax[0].grid() + ax[1].set_title(f"{ib} Refl: {str(sb)[0]} Rot: {rb:.1f}") + ax[1].imshow(b) + ax[1].grid() + ax[2].set_title("Stacked Avg") + plt.imshow((a + b) / 2.0) + ax[2].grid() + plt.show() + + +class BFSReddyChatterjiAverager2D(ReddyChatterjiAverager2D): + """ + Brute Force Shifts (Translations) - ReddyChatterji (Log-Polar) Rotations + + For each shift within `radius`, attempts rotational match using ReddyChatterji. + When averaging, performs shift before rotations, + + Adopted from Reddy Chatterji (1996) + An FFT-Based Technique for Translation, + Rotation, and Scale-Invariant Image Registration + IEEE TRANSACTIONS ON IMAGE PROCESSING, VOL. 5, NO. 8, AUGUST 1996 + + This method intentionally does not use any of ASPIRE's basis + so that it may be used as a reference for more ASPIRE approaches. + """ + + def __init__( + self, + composite_basis, + src, + alignment_src=None, + radius=None, + diagnostics=False, + dtype=None, + ): + """ + :param alignment_basis: Basis to be used during alignment. + For current implementation of ReddyChatterjiAverager2D this should be `None`. + Instead see `alignment_src`. + :param src: Source of original images. + :param composite_basis: Basis to be used during class average composition. + :param alignment_src: Optional, source to be used during class average alignment. + Must be the same resolution as `src`. + :param radius: Brute force translation search radius. + Defaults to src.L//8. + :param dtype: Numpy dtype to be used during alignment. + + :param diagnostics: Plot interactive diagnostic graphics (for debugging). + :param dtype: Numpy dtype to be used during alignment. + """ + + super().__init__( + composite_basis, + src, + alignment_src, + diagnostics, + dtype=dtype, + ) + + # For brute force we disable the cross_corr translation code + self.do_cross_corr_translations = False + # Assign search radius + self.radius = radius or src.L // 8 + + def align(self, classes, reflections, basis_coefficients): + """ + Performs the actual rotational alignment estimation, + returning parameters needed for averaging. + """ + + # Admit simple case of single case alignment + classes = np.atleast_2d(classes) + reflections = np.atleast_2d(reflections) + + n_classes, n_nbor = classes.shape + L = self.alignment_src.L + + # Instantiate matrices for inner loop, and best results. + _rotations = np.zeros(classes.shape, dtype=self.dtype) + rotations = np.zeros(classes.shape, dtype=self.dtype) + _correlations = np.zeros(classes.shape, dtype=self.dtype) + correlations = np.ones(classes.shape, dtype=self.dtype) * -np.inf + shifts = np.zeros((*classes.shape, 2), dtype=int) + + # We'll brute force all shifts in a grid. 
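+        # (The disc holds roughly pi * radius**2 candidate shifts; with the
+        # default radius = L // 8 and L = 32 that is ~50 shifts, each running
+        # a full ReddyChatterji rotational registration per class.)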
+        g = grid_2d(L, normalized=False)
+        disc = g["r"] <= self.radius
+        X, Y = g["x"][disc], g["y"][disc]
+
+        for k in trange(n_classes):
+            unshifted_images = self._cls_images(classes[k])
+
+            for xs, ys in zip(X, Y):
+                s = np.array([xs, ys])
+
+                # Get the array of images for this class.
+                # Note we mutate `images` here with shifting,
+                # then later in `_reddychatterji`.
+                images = unshifted_images.copy()
+                # Don't shift the base image
+                images[1:] = Image(unshifted_images[1:]).shift(s).asnumpy()
+
+                _rotations[k], _, _correlations[k] = self._reddychatterji(
+                    images, classes[k], reflections[k]
+                )
+
+                # Where corr has improved,
+                # update our rolling best results for this class.
+                improved = _correlations[k] > correlations[k]
+                correlations[k] = np.where(improved, _correlations[k], correlations[k])
+                rotations[k] = np.where(improved, _rotations[k], rotations[k])
+                shifts[k] = np.where(improved[..., np.newaxis], s, shifts[k])
+                logger.debug(f"Shift {s} has improved {np.sum(improved)} results")
+
+        return rotations, shifts, correlations
+
+    def average(
+        self,
+        classes,
+        reflections,
+        coefs=None,
+    ):
+        """
+        See Averager2D.average.
+        """
+        # ReddyChatterjiAverager2D does rotations then shifts.
+        # For brute force, we'd like shifts then rotations,
+        # as is done in general in AligningAverager2D.
+        return AligningAverager2D.average(self, classes, reflections, coefs)
+
+
+class EMAverager2D(Averager2D):
+    """
+    Citation needed.
+    """
+
+
+class FTKAverager2D(Averager2D):
+    """
+    Factorization of the translation kernel for fast rigid image alignment.
+    Rangan, A.V., Spivak, M., Anden, J., & Barnett, A.H. (2019).
+    """
diff --git a/src/aspire/classification/class2d.py b/src/aspire/classification/class2d.py
index 6544b78ffc..1542c12ac3 100644
--- a/src/aspire/classification/class2d.py
+++ b/src/aspire/classification/class2d.py
@@ -1,11 +1,12 @@
 import logging
+from abc import ABC, abstractmethod
 
 import numpy as np
 
 logger = logging.getLogger(__name__)
 
 
-class Class2D:
+class Class2D(ABC):
     """
     Base class for 2D Image Classification methods.
     """
@@ -41,3 +42,17 @@ def __init__(
         self.n_nbor = n_nbor
         self.n_classes = n_classes
         self.seed = seed
+
+    @abstractmethod
+    def classify(self):
+        """
+        Classify the images from Source into classes with similar viewing angles.
+
+        Returns classes and associated metadata (classes, reflections, distances).
+        """
+
+    @abstractmethod
+    def averages(self, classes, refl, distances):
+        """
+        Returns class averages.
+        """
diff --git a/src/aspire/classification/rir_class2d.py b/src/aspire/classification/rir_class2d.py
index a799ec7c27..82e9f9b6e1 100644
--- a/src/aspire/classification/rir_class2d.py
+++ b/src/aspire/classification/rir_class2d.py
@@ -6,11 +6,9 @@
 from tqdm import tqdm
 
 from aspire.basis import FSPCABasis
-from aspire.classification import BFRAlign2D, Class2D
+from aspire.classification import BFSReddyChatterjiAverager2D, Class2D
 from aspire.classification.legacy_implementations import bispec_2drot_large, pca_y
-from aspire.image import Image
 from aspire.numeric import ComplexPCA
-from aspire.source import ArrayImageSource
 from aspire.utils.random import rand
 
 logger = logging.getLogger(__name__)
@@ -31,7 +29,7 @@ def __init__(
         large_pca_implementation="legacy",
         nn_implementation="legacy",
         bispectrum_implementation="legacy",
-        aligner=None,
+        averager=None,
         dtype=None,
         seed=None,
     ):
@@ -48,7 +46,8 @@
         Z. Zhao, Y. Shkolnisky, A. Singer, Rotationally Invariant Image Representation for Viewing Direction Classification in Cryo-EM.
         (2014)
 
-    :param src: Source instance
+    :param src: Source instance. Note it is possible to use one `source` for classification (i.e. CWF),
+        and a different `source` for stacking in the `averager`.
     :param pca_basis: Optional FSPCA Basis instance
     :param fspca_components: Components (top eigvals) to keep from full FSCPA, default truncates to 400.
     :param alpha: Amplitude Power Scale, default 1/3 (eq 20 from RIIR paper).
@@ -60,7 +59,7 @@ def __init__(
     :param large_pca_implementation: See `pca`.
     :param nn_implementation: See `nn_classification`.
     :param bispectrum_implementation: See `bispectrum`.
-    :param aligner: An Align2D subclass. Defaults to BFRAlign2D.
+    :param averager: An Averager2D subclass. Defaults to BFSReddyChatterjiAverager2D.
     :param dtype: Optional dtype, otherwise taken from src.
     :param seed: Optional RNG seed to be passed to random methods, (example Random NN).
    :return: RIRClass2D instance to be used to compute bispectrum-like rotationally invariant 2D classification.
@@ -101,7 +100,7 @@ def __init__(
         self.alpha = alpha
         self.bispectrum_components = bispectrum_components
         self.bispectrum_freq_cutoff = bispectrum_freq_cutoff
-        self.aligner = aligner
+        self.averager = averager
 
         if self.src.n < self.bispectrum_components:
             raise RuntimeError(
@@ -165,16 +164,17 @@ def classify(self, diagnostics=False):
         # default of 400 components was taken from legacy code.
         # Instantiate a new compressed (truncated) basis.
         if self.pca_basis is None:
-            # self.pca_basis = self.pca_basis.compress(self.fspca_components)
             self.pca_basis = FSPCABasis(self.src, components=self.fspca_components)
 
         # For convenience, assign the fb_basis used in the pca_basis.
         self.fb_basis = self.pca_basis.basis
 
-        # When not provided by a user, the aligner is instantiated after
+        # When not provided by a user, the averager is instantiated after
         # we are certain our pca_basis has been constructed.
-        if self.aligner is None:
-            self.aligner = BFRAlign2D(self.pca_basis, dtype=self.dtype)
+        if self.averager is None:
+            self.averager = BFSReddyChatterjiAverager2D(
+                self.fb_basis, self.src, dtype=self.dtype
+            )
 
         # Get the expanded coefs in the compressed FSPCA space.
         self.fspca_coef = self.pca_basis.spca_coef
@@ -184,7 +184,7 @@ def classify(self, diagnostics=False):
 
         # # Stage 2: Compute Nearest Neighbors
         logger.info("Calculate Nearest Neighbors")
-        classes, refl, distances = self.nn_classification(coef_b, coef_b_r)
+        classes, reflections, distances = self.nn_classification(coef_b, coef_b_r)
 
         if diagnostics:
             # Lets peek at the distribution of distances
@@ -193,24 +193,39 @@ def classify(self, diagnostics=False):
             plt.show()
 
         # Report some information about reflections
-        logger.info(f"Count reflected: {np.sum(refl)}" f" {100 * np.mean(refl) } %")
+        logger.info(
+            f"Count reflected: {np.sum(reflections)}"
+            f" {100 * np.mean(reflections)} %"
+        )
+
+        return classes, reflections, distances
 
+    def averages(self, classes, reflections, distances):
         # # Stage 3: Class Selection
-        logger.info(f"Select {self.n_classes} Classes from Nearest Neighbors")
 
         # This is an area open to active research.
         # Currently we take a naive approach by selecting the
         # first n_classes assuming they are quasi random.
-        classes = classes[: self.n_classes]
+        logger.info(f"Select {self.n_classes} Classes from Nearest Neighbors")
+        classes, reflections = self.select_classes(classes, reflections)
 
-        # # Stage 4: Align
+        # # Stage 4: Averager
         logger.info(
-            f"Begin Rotational Alignment of {classes.shape[0]} Classes using {self.aligner}."
+            f"Begin Averaging of {classes.shape[0]} Classes using {self.averager}."
         )
-        if not self.aligner.basis == self.pca_basis:
-            raise RuntimeError(
-                f"Aligner {self.aligner} basis does not match FSPCA basis."
-            )
-        return self.aligner.align(classes, refl, self.fspca_coef)
+
+        return self.averager.average(classes, reflections)
+
+    def select_classes(self, classes, reflections):
+        """
+        Select the `n_classes` to average from the (n_images) population of classes.
+        """
+        # Generate indices for the selection (something smarter can be built out later).
+        # For testing/PoC we take the first n_classes so results match earlier plots
+        # for manual comparison. If the image stack is reasonably randomly ordered,
+        # this is a defensible choice. A random selection over the whole dataset
+        # would also be reasonable, in case the head of a dataset is too
+        # self-similar or has artifacts.
+        selection = np.arange(self.n_classes)
+        return classes[selection], reflections[selection]
 
     def pca(self, M):
         """
@@ -298,64 +313,9 @@ def _sk_nn_classification(self, coeff_b, coeff_b_r):
 
         return indices
 
-    def output(
-        self,
-        classes,
-        classes_refl,
-        rot,
-        shifts=None,
-        coefs=None,
-    ):
-        """
-        Return class averages.
-
-        :param classes: class indices (refering to src). (n_img, n_nbor)
-        :param classes_refl: Bool representing whether to reflect image in `classes`
-        :param rot: Array of in-plane rotation angles (Radians) of image in `classes`
-        :param shifts: Optional array of shifts for image in `classes`.
-        :coefs: Optional Fourier bessel coefs (avoids recomputing).
-        :return: Stack of Synthetic Class Average images as Image instance.
-        """
-
-        logger.info(f"Select {self.n_classes} Classes from Nearest Neighbors")
-        # generate indices for random sample (can do something smart with corr later).
-        # For testing just take the first n_classes so it matches earlier plots for manual comparison
-        # This is assumed to be reasonably random.
-        selection = np.arange(self.n_classes)
-
-        imgs = self.src.images(0, self.src.n)
-        fb_avgs = np.empty((self.n_classes, self.fb_basis.count), dtype=self.src.dtype)
-
-        for i in tqdm(range(self.n_classes)):
-            j = selection[i]
-            # Get the neighbors
-            neighbors_ids = classes[j]
-
-            # Get coefs in Fourier Bessel Basis if not provided as an argument.
-            if coefs is None:
-                neighbors_imgs = Image(imgs[neighbors_ids])
-                if shifts is not None:
-                    neighbors_imgs.shift(shifts[i])
-                neighbors_coefs = self.fb_basis.evaluate_t(neighbors_imgs)
-            else:
-                neighbors_coefs = coefs[neighbors_ids]
-                if shifts is not None:
-                    neighbors_coefs = self.fb_basis.shift(neighbors_coefs, shifts[i])
-
-            # Rotate in Fourier Bessel
-            neighbors_coefs = self.fb_basis.rotate(
-                neighbors_coefs, rot[j], classes_refl[j]
-            )
-
-            # Averaging in FB
-            fb_avgs[i] = np.mean(neighbors_coefs, axis=0)
-
-        # Now we convert the averaged images from FB to Cartesian.
-        return ArrayImageSource(self.fb_basis.evaluate(fb_avgs))
-
     def _legacy_nn_classification(self, coeff_b, coeff_b_r, batch_size=2000):
         """
-        Perform nearest neighbor classification and alignment.
+        Perform nearest neighbor classification.
         """
 
         # Note kept ordering from legacy code (n_features, n_img)
@@ -390,7 +350,7 @@ def _legacy_nn_classification(self, coeff_b, coeff_b_r, batch_size=2000):
 
             # Check with Joakim about preference.
            # I (GBW) think class[i] should have class[i][0] be the original image index.
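            # Sorting the negated correlations orders each row by decreasing
            # correlation, so the first n_nbor column indices per row form
            # that image's nearest neighbor class.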
            classes[start:finish] = np.argsort(-corr, axis=1)[:, :n_nbor]
-            # Store the corr values for the n_nhors in this batch
+            # Store the corr values for the n_nbor neighbors in this batch
             distances[start:finish] = np.take_along_axis(
                 corr, classes[start:finish], axis=1
             )
diff --git a/src/aspire/covariance/covar.py b/src/aspire/covariance/covar.py
index 181b6840e6..397dcb5366 100644
--- a/src/aspire/covariance/covar.py
+++ b/src/aspire/covariance/covar.py
@@ -61,7 +61,7 @@ def compute_kernel(self):
             weights[:, 0, :] = 0
 
             # TODO: This is where this differs from MeanEstimator
-            pts_rot = np.moveaxis(pts_rot, -1, 0).reshape(-1, 3, L**2)
+            pts_rot = np.moveaxis(pts_rot[::-1], 1, 0).reshape(-1, 3, L**2)
             weights = weights.T.reshape((-1, L**2))
 
             batch_n = weights.shape[0]
diff --git a/src/aspire/covariance/covar2d.py b/src/aspire/covariance/covar2d.py
index 95cb734150..b8599f27fb 100644
--- a/src/aspire/covariance/covar2d.py
+++ b/src/aspire/covariance/covar2d.py
@@ -284,7 +284,7 @@ def identity(x):
 
         for k in np.unique(ctf_idx[:]):
 
-            coeff_k = coeffs[ctf_idx == k]
+            coeff_k = coeffs[ctf_idx == k].astype(self.dtype)
             weight = coeff_k.shape[0] / coeffs.shape[0]
             ctf_fb_k = ctf_fb[k]
diff --git a/src/aspire/denoising/__init__.py b/src/aspire/denoising/__init__.py
index 30aaf0ccb0..492f1f6aec 100644
--- a/src/aspire/denoising/__init__.py
+++ b/src/aspire/denoising/__init__.py
@@ -1,3 +1,4 @@
 from .adaptive_support import adaptive_support
+from .denoised_src import DenoisedImageSource
 from .denoiser import Denoiser
-from .denoiser_cov2d import src_wiener_coords
+from .denoiser_cov2d import DenoiserCov2D, src_wiener_coords
diff --git a/src/aspire/denoising/denoised_src.py b/src/aspire/denoising/denoised_src.py
index db061a0c81..83e2ec3162 100644
--- a/src/aspire/denoising/denoised_src.py
+++ b/src/aspire/denoising/denoised_src.py
@@ -45,6 +45,9 @@ def _images(self, start=0, num=np.inf, indices=None, batch_size=512):
         nimgs = len(indices)
         im = np.empty((nimgs, self.L, self.L))
 
+        # If fewer images than a whole batch are requested, clamp batch_size so we don't crash.
+        batch_size = min(nimgs, batch_size)
+
         logger.info(f"Loading {nimgs} images complete")
         for batch_start in range(start, end + 1, batch_size):
             imgs_denoised = self.denoiser.images(batch_start, batch_size)
diff --git a/src/aspire/denoising/denoiser_cov2d.py b/src/aspire/denoising/denoiser_cov2d.py
index 71ef7e97e5..9e3ac9554a 100644
--- a/src/aspire/denoising/denoiser_cov2d.py
+++ b/src/aspire/denoising/denoiser_cov2d.py
@@ -103,7 +103,7 @@ class DenoiserCov2D(Denoiser):
     Define a derived class for denoising 2D images using Cov2D method
     """
 
-    def __init__(self, src, basis, var_noise=None):
+    def __init__(self, src, basis=None, var_noise=None):
         """
         Initialize an object for denoising 2D images using Cov2D method
 
@@ -122,8 +122,12 @@ def __init__(self, src, basis=None, var_noise=None):
         logger.info(f"Estimated Noise Variance: {var_noise}")
         self.var_noise = var_noise
 
+        if basis is None:
+            basis = FFBBasis2D((self.src.L, self.src.L), dtype=src.dtype)
+
         if not isinstance(basis, FFBBasis2D):
             raise NotImplementedError("Currently only fast FB method is supported")
+
         self.basis = basis
         self.cov2d = None
         self.mean_est = None
diff --git a/tests/test_align2d.py b/tests/test_averager2d.py
similarity index 61%
rename from tests/test_align2d.py
rename to tests/test_averager2d.py
index e4f3bc9085..10b46e5ac5 100644
--- a/tests/test_align2d.py
+++ b/tests/test_averager2d.py
@@ -6,7 +6,13 @@
 import pytest
 
 from aspire.basis import DiracBasis, FFBBasis2D
-from aspire.classification import Align2D, BFRAlign2D, BFSRAlign2D
+from aspire.classification import (
+    Averager2D,
+    BFRAverager2D,
+    BFSRAverager2D,
+    BFSReddyChatterjiAverager2D,
+    ReddyChatterjiAverager2D,
+)
 from aspire.source import Simulation
 from aspire.utils import Rotation
 from aspire.volume import Volume
@@ -19,9 +25,9 @@
 
 # Ignore Gimbal lock warning for our in plane rotations.
 @pytest.mark.filterwarnings("ignore:Gimbal lock detected")
-class Align2DTestCase(TestCase):
-    # Subclasses should override `aligner` with a different class.
-    aligner = Align2D
+class Averager2DTestCase(TestCase):
+    # Subclasses should override `averager` with a different class.
+    averager = Averager2D
 
     def setUp(self):
 
@@ -33,7 +39,7 @@ def setUp(self):
         self.n_img = 3
         self.dtype = np.float64
 
-        # Create a Basis to use in alignment.
+        # Create a Basis to use in the averager.
         self.basis = FFBBasis2D((self.resolution, self.resolution), dtype=self.dtype)
 
         # This sets up a trivial class, where there is one group having all images.
@@ -48,17 +54,25 @@ def inject_fixtures(self, caplog):
 
     def tearDown(self):
         pass
 
+    def _getSrc(self):
+        # Base Averager2D does not require anything from source.
+        # Subclasses implement a specific src.
+        return None
+
     def testTypeMismatch(self):
-        # Intentionally mismatch Basis and Aligner dtypes
+        # Work around ABC, which won't let us instantiate the unimplemented base class.
+        self.averager.__abstractmethods__ = set()
+
+        # Intentionally mismatch Basis and Averager dtypes
         if self.dtype == np.float32:
             test_dtype = np.float64
         else:
             test_dtype = np.float32
 
         with self._caplog.at_level(logging.WARN):
-            self.aligner(self.basis, dtype=test_dtype)
-            assert " does not match self.dtype" in self._caplog.text
+            self.averager(self.basis, self._getSrc(), dtype=test_dtype)
+            assert "does not match dtype" in self._caplog.text
 
     def _construct_rotations(self):
         """
@@ -91,9 +105,10 @@ def r(theta):
         self.rots = Rotation.from_matrix(_rots)
 
 
-class BFRAlign2DTestCase(Align2DTestCase):
+@pytest.mark.filterwarnings("ignore:Gimbal lock detected")
+class BFRAverager2DTestCase(Averager2DTestCase):
 
-    aligner = BFRAlign2D
+    averager = BFRAverager2D
 
     def setUp(self):
 
@@ -135,22 +150,19 @@ def testNoRot(self):
 
         # and that should raise an error during instantiation.
         with pytest.raises(RuntimeError, match=r".* must provide a `rotate` method."):
-            _ = self.aligner(basis)
+            _ = self.averager(basis, self._getSrc())
 
-    def testAlign(self):
+    def testAverager(self):
         """
         Construct a stack of images with known rotations.
 
-        Rotationally align the stack and compare output with known rotations.
+        Rotationally average the stack and compare output with known rotations.
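+
+        BFR performs no translation search, so `align` should return no shifts,
+        and each estimated angle should land within half of the angular search
+        step of the known rotation.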
""" - # Construction the Aligner and then call the main `align` method - _classes, _reflections, _rotations, _shifts, _ = self.aligner( - self.basis, n_angles=self.n_search_angles - ).align(self.classes, self.reflections, self.coefs) + # Construct the Averager and then call the `align` method + avgr = self.averager(self.basis, self._getSrc(), n_angles=self.n_search_angles) + _rotations, _shifts, _ = avgr.align(self.classes, self.reflections, self.coefs) - self.assertTrue(np.all(_classes == self.classes)) - self.assertTrue(np.all(_reflections == self.reflections)) self.assertIsNone(_shifts) # Crude check that we are closer to known angle than the next rotation @@ -162,20 +174,21 @@ def testAlign(self): ) -class BFSRAlign2DTestCase(BFRAlign2DTestCase): +@pytest.mark.filterwarnings("ignore:Gimbal lock detected") +class BFSRAverager2DTestCase(BFRAverager2DTestCase): - aligner = BFSRAlign2D + averager = BFSRAverager2D def setUp(self): # Inherit basic params from the base class - super(BFRAlign2DTestCase, self).setUp() + super(BFRAverager2DTestCase, self).setUp() # Setup shifts, don't shift the base image self.shifts = np.zeros((self.n_img, 2)) self.shifts[1:, 0] = 2 self.shifts[1:, 1] = 4 - # Execute the remaining setup from BFRAlign2DTestCase + # Execute the remaining setup from BFRAverager2DTestCase super().setUp() def testNoShift(self): @@ -192,22 +205,24 @@ def testNoShift(self): # and that should raise an error during instantiation. with pytest.raises(RuntimeError, match=r".* must provide a `shift` method."): - _ = self.aligner(basis) + _ = self.averager(basis, self._getSrc()) - def testAlign(self): + def testAverager(self): """ Construct a stack of images with known rotations. - Rotationally align the stack and compare output with known rotations. + Rotationally averager the stack and compare output with known rotations. """ - # Construction the Aligner and then call the main `align` method - _classes, _reflections, _rotations, _shifts, _ = self.aligner( - self.basis, n_angles=self.n_search_angles, n_x_shifts=1, n_y_shifts=1 - ).align(self.classes, self.reflections, self.coefs) - - self.assertTrue(np.all(_classes == self.classes)) - self.assertTrue(np.all(_reflections == self.reflections)) + # Construct the Averager and then call the main `align` method + avgr = self.averager( + self.basis, + self._getSrc(), + n_angles=self.n_search_angles, + n_x_shifts=1, + n_y_shifts=1, + ) + _rotations, _shifts, _ = avgr.align(self.classes, self.reflections, self.coefs) # Crude check that we are closer to known angle than the next rotation self.assertTrue(np.all((_rotations - self.thetas) <= (self.step / 2))) @@ -225,3 +240,45 @@ def testAlign(self): # non zero shift+rot improved corr. # Perhaps in the future should check more details. self.assertTrue(np.all(np.hypot(*_shifts[0][1:].T) >= 1)) + + +@pytest.mark.filterwarnings("ignore:Gimbal lock detected") +class ReddyChatterjiAverager2DTestCase(BFSRAverager2DTestCase): + + averager = ReddyChatterjiAverager2D + + def testAverager(self): + """ + Construct a stack of images with known rotations. + + Rotationally averager the stack and compare output with known rotations. 
+ """ + + # Construct the Averager and then call the main `align` method + avgr = self.averager( + composite_basis=self.basis, + src=self._getSrc(), + dtype=self.dtype, + ) + _rotations, _shifts, _ = avgr.align(self.classes, self.reflections, self.coefs) + + # Crude check that we are closer to known angle than the next rotation + self.assertTrue(np.all((_rotations - self.thetas) <= (self.step / 2))) + + # Fine check that we are within one degree. + self.assertTrue(np.all((_rotations - self.thetas) <= (2 * np.pi / 360.0))) + + # Check that we are _not_ shifting the base image + self.assertTrue(np.all(_shifts[0][0] == 0)) + # Check that we produced estimated shifts away from origin + # Note that Simulation's rot+shift is generally not equal to shift+rot. + # Instead we check that some combination of + # non zero shift+rot improved corr. + # Perhaps in the future should check more details. + self.assertTrue(np.all(np.hypot(*_shifts[0][1:].T) >= 1)) + + +@pytest.mark.filterwarnings("ignore:Gimbal lock detected") +class BFSReddyChatterjiAverager2DTestCase(ReddyChatterjiAverager2DTestCase): + + averager = BFSReddyChatterjiAverager2D diff --git a/tests/test_class2D.py b/tests/test_class2D.py index 0dac11e5a6..07935f217b 100644 --- a/tests/test_class2D.py +++ b/tests/test_class2D.py @@ -7,7 +7,7 @@ from sklearn import datasets from aspire.basis import FFBBasis2D, FSPCABasis -from aspire.classification import BFSRAlign2D, Class2D, RIRClass2D +from aspire.classification import BFRAverager2D, RIRClass2D from aspire.classification.legacy_implementations import bispec_2drot_large, pca_y from aspire.operators import ScalarFilter from aspire.source import Simulation @@ -40,6 +40,7 @@ def setUp(self): vols=v, dtype=self.dtype, ) + self.src.cache() # Precompute image stack # Calculate some projection images self.imgs = self.src.images(0, self.src.n) @@ -139,14 +140,6 @@ def setUp(self): # Ceate another fspca_basis, use autogeneration FFB2D Basis self.noisy_fspca_basis = FSPCABasis(self.noisy_src) - def testClass2DBase(self): - """ - Make sure the base class doesn't crash when using arguments. - """ - _ = Class2D(self.clean_src) # Default dtype - _ = Class2D(self.clean_src, dtype=self.dtype) # Consistent dtype - _ = Class2D(self.clean_src, dtype=np.float16) # Different dtype - def testIncorrectComponents(self): """ Check we raise with inconsistent configuration of FSPCA components. @@ -191,8 +184,8 @@ def testRIRLegacy(self): bispectrum_implementation="legacy", ) - result = rir.classify() - _ = rir.output(*result[:3]) + classification_results = rir.classify() + _ = rir.averages(*classification_results) def testRIRDevelBisp(self): """ @@ -207,8 +200,8 @@ def testRIRDevelBisp(self): bispectrum_implementation="devel", ) - result = rir.classify() - _ = rir.output(*result[:3]) + classification_results = rir.classify() + _ = rir.averages(*classification_results) def testRIRsk(self): """ @@ -226,11 +219,15 @@ def testRIRsk(self): large_pca_implementation="sklearn", nn_implementation="sklearn", bispectrum_implementation="devel", - aligner=BFSRAlign2D(self.noisy_fspca_basis, n_angles=100, n_x_shifts=0), + averager=BFRAverager2D( + self.noisy_fspca_basis.basis, # FFB basis + self.noisy_src, + n_angles=100, + ), ) - result = rir.classify() - _ = rir.output(*result[:4]) + classification_results = rir.classify() + _ = rir.averages(*classification_results) def testEigenImages(self): """