1919import numpy as np
2020
2121from aspire .abinitio import CLSyncVoting
22- from aspire .basis import FFBBasis3D
23- from aspire .classification import RIRClass2D
22+ from aspire .basis import FFBBasis2D , FFBBasis3D
23+ from aspire .classification import BFSReddyChatterjiAlign2D , RIRClass2D
2424from aspire .denoising import DenoiserCov2D
2525from aspire .noise import AnisotropicNoiseEstimator
2626from aspire .operators import FunctionFilter , RadialCTFFilter
@@ -132,22 +132,32 @@ def noise_function(x, y):
132132if interactive :
133133 src .images (0 , 10 ).show ()
134134
135+ # # Optionally invert image contrast, depends on data.
135136# logger.info("Invert the global density contrast")
136137# src.invert_contrast()
137138
138- # # On Simulation data, better results so far were achieved without cov2d.
139+ # Cache to memory for some speedup
140+ src = ArrayImageSource (src .images (0 , num_imgs ).asnumpy (), angles = src .angles )
141+
142+ # On Simulation data, better results so far were achieved without cov2d
143+ # However, we can demonstrate using CWF denoised images for classification.
144+ classification_src = src
145+ custom_aligner = None
139146if do_cov2d :
140147 # Use CWF denoising
141148 cwf_denoiser = DenoiserCov2D (src )
142- src = cwf_denoiser . denoise ()
143-
144- # Peek, what do the denoised images look like...
145- if interactive :
146- src .images (0 , 10 ).show ()
149+ # Use denoised src for classification
150+ classification_src = cwf_denoiser . denoise ()
151+ # Peek, what do the denoised images look like...
152+ if interactive :
153+ classification_src .images (0 , 10 ).show ()
147154
155+ # Use regular `src` for the alignment and composition (averaging).
156+ composite_basis = FFBBasis2D ((src .L ,) * 2 , dtype = src .dtype )
157+ custom_aligner = BFSReddyChatterjiAlign2D (
158+ None , src , composite_basis , dtype = src .dtype
159+ )
148160
149- # Cache to memory for some speedup
150- src = ArrayImageSource (src .images (0 , num_imgs ).asnumpy (), angles = src .angles )
151161
152162# %%
153163# Class Averaging
@@ -158,14 +168,15 @@ def noise_function(x, y):
158168logger .info ("Begin Class Averaging" )
159169
160170rir = RIRClass2D (
161- src ,
171+ classification_src , # Source used for classification
162172 fspca_components = 400 ,
163173 bispectrum_components = 300 , # Compressed Features after last PCA stage.
164174 n_nbor = n_nbor ,
165175 n_classes = n_classes ,
166176 large_pca_implementation = "legacy" ,
167177 nn_implementation = "sklearn" ,
168178 bispectrum_implementation = "legacy" ,
179+ aligner = custom_aligner ,
169180)
170181
171182classes , reflections , distances = rir .classify ()
0 commit comments