Skip to content

Commit d7df6e3

Browse files
committed
Fixup the CWF experiment example (some dtypes issues)
1 parent 3d54c22 commit d7df6e3

File tree

4 files changed

+26
-19
lines changed

4 files changed

+26
-19
lines changed

gallery/experiments/simulated_abinitio_pipeline.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
import numpy as np
2020

2121
from aspire.abinitio import CLSyncVoting
22-
from aspire.basis import FFBBasis3D
23-
from aspire.classification import RIRClass2D
22+
from aspire.basis import FFBBasis2D, FFBBasis3D
23+
from aspire.classification import BFSReddyChatterjiAlign2D, RIRClass2D
2424
from aspire.denoising import DenoiserCov2D
2525
from aspire.noise import AnisotropicNoiseEstimator
2626
from aspire.operators import FunctionFilter, RadialCTFFilter
@@ -132,22 +132,32 @@ def noise_function(x, y):
132132
if interactive:
133133
src.images(0, 10).show()
134134

135+
# # Optionally invert image contrast, depends on data.
135136
# logger.info("Invert the global density contrast")
136137
# src.invert_contrast()
137138

138-
# # On Simulation data, better results so far were achieved without cov2d.
139+
# Cache to memory for some speedup
140+
src = ArrayImageSource(src.images(0, num_imgs).asnumpy(), angles=src.angles)
141+
142+
# On Simulation data, better results so far were achieved without cov2d
143+
# However, we can demonstrate using CWF denoised images for classification.
144+
classification_src = src
145+
custom_aligner = None
139146
if do_cov2d:
140147
# Use CWF denoising
141148
cwf_denoiser = DenoiserCov2D(src)
142-
src = cwf_denoiser.denoise()
143-
144-
# Peek, what do the denoised images look like...
145-
if interactive:
146-
src.images(0, 10).show()
149+
# Use denoised src for classification
150+
classification_src = cwf_denoiser.denoise()
151+
# Peek, what do the denoised images look like...
152+
if interactive:
153+
classification_src.images(0, 10).show()
147154

155+
# Use regular `src` for the alignment and composition (averaging).
156+
composite_basis = FFBBasis2D((src.L,) * 2, dtype=src.dtype)
157+
custom_aligner = BFSReddyChatterjiAlign2D(
158+
None, src, composite_basis, dtype=src.dtype
159+
)
148160

149-
# Cache to memory for some speedup
150-
src = ArrayImageSource(src.images(0, num_imgs).asnumpy(), angles=src.angles)
151161

152162
# %%
153163
# Class Averaging
@@ -158,14 +168,15 @@ def noise_function(x, y):
158168
logger.info("Begin Class Averaging")
159169

160170
rir = RIRClass2D(
161-
src,
171+
classification_src, # Source used for classification
162172
fspca_components=400,
163173
bispectrum_components=300, # Compressed Features after last PCA stage.
164174
n_nbor=n_nbor,
165175
n_classes=n_classes,
166176
large_pca_implementation="legacy",
167177
nn_implementation="sklearn",
168178
bispectrum_implementation="legacy",
179+
aligner=custom_aligner,
169180
)
170181

171182
classes, reflections, distances = rir.classify()

src/aspire/classification/align2d.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -421,9 +421,7 @@ def __init__(
421421
Instead see `alignment_source`.
422422
:param source: Source of original images.
423423
:param composite_basis: Basis to be used during class average composition.
424-
For current implementation of ReddyChatterjiAlign2D this should be `None`.
425-
Instead this method uses `source` for composition of the averaged stack.
426-
:param alignment_source: Basis to be used during class average composition.
424+
:param alignment_source: Optional, source to be used during class average alignment.
427425
Must be the same resolution as `source`.
428426
:param dtype: Numpy dtype to be used during alignment.
429427
"""
@@ -888,9 +886,7 @@ def __init__(
888886
Instead see `alignment_source`.
889887
:param source: Source of original images.
890888
:param composite_basis: Basis to be used during class average composition.
891-
For current implementation of ReddyChatterjiAlign2D this should be `None`.
892-
Instead this method uses `source` for composition of the averaged stack.
893-
:param alignment_source: Basis to be used during class average composition.
889+
:param alignment_source: Optional, source to be used during class average alignment.
894890
Must be the same resolution as `source`.
895891
:param radius: Brute force translation search radius.
896892
Defaults to source.L//8.

src/aspire/covariance/covar2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ def identity(x):
284284

285285
for k in np.unique(ctf_idx[:]):
286286

287-
coeff_k = coeffs[ctf_idx == k]
287+
coeff_k = coeffs[ctf_idx == k].astype(self.dtype)
288288
weight = coeff_k.shape[0] / coeffs.shape[0]
289289

290290
ctf_fb_k = ctf_fb[k]

src/aspire/denoising/denoiser_cov2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def __init__(self, src, basis=None, var_noise=None):
123123
self.var_noise = var_noise
124124

125125
if basis is None:
126-
basis = FFBBasis2D((self.src.L, self.src.L))
126+
basis = FFBBasis2D((self.src.L, self.src.L), dtype=src.dtype)
127127

128128
if not isinstance(basis, FFBBasis2D):
129129
raise NotImplementedError("Currently only fast FB method is supported")

0 commit comments

Comments
 (0)