diff --git a/src/aspire/commands/extract_particles.py b/src/aspire/commands/extract_particles.py index 572691d849..10e801121a 100644 --- a/src/aspire/commands/extract_particles.py +++ b/src/aspire/commands/extract_particles.py @@ -5,6 +5,7 @@ import click from click import UsageError +from aspire.noise import WhiteNoiseEstimator from aspire.source.coordinates import BoxesCoordinateSource, CentersCoordinateSource logger = logging.getLogger(__name__) @@ -37,6 +38,27 @@ is_flag=True, help="Set this flag if coordinate files contain (X,Y) particle centers", ) +@click.option( + "--downsample", + default=0, + type=int, + help="Downsample the images to this resolution prior to saving to starfile/.mrcs stack", +) +@click.option( + "--normalize_bg", + is_flag=True, + help="Normalize the images to have mean zero and variance one in the corners", +) +@click.option( + "--whiten", + is_flag=True, + help="Estimate the noise variance of the images and whiten", +) +@click.option( + "--invert_contrast", + is_flag=True, + help="Invert the contrast of the images to ensure that clean particles have positive intensity", +) @click.option( "--batch_size", default=512, help="Batch size to load images from .mrc files" ) @@ -54,6 +76,10 @@ def extract_particles( starfile_out, particle_size, centers, + downsample, + normalize_bg, + whiten, + invert_contrast, batch_size, save_mode, overwrite, @@ -109,6 +135,17 @@ def extract_particles( particle_size=particle_size, ) + # optional preprocessing steps + if 0 < downsample < src.L: + src.downsample(downsample) + if normalize_bg: + src.normalize_background() + if whiten: + estimator = WhiteNoiseEstimator(src) + src.whiten(estimator.filter) + if invert_contrast: + src.invert_contrast() + # saves to .mrcs and STAR file with column "_rlnImageName" src.save( starfile_out, batch_size=batch_size, save_mode=save_mode, overwrite=overwrite diff --git a/src/aspire/commands/preprocess.py b/src/aspire/commands/preprocess.py index b7ef198c1a..e91208a6d7 100644 --- a/src/aspire/commands/preprocess.py +++ b/src/aspire/commands/preprocess.py @@ -31,24 +31,28 @@ ) @click.option("--flip_phase", default=True, help="Perform phase flip or not") @click.option( - "--max_resolution", - default=16, + "--downsample", + default=0, type=int, - help="Resolution for downsampling images read from STAR file", + help="Downsample the images to this resolution prior to saving to starfile/.mrcs stack", ) @click.option( - "--normalize_background", + "--normalize_bg", default=True, - help="Whether to normalize images to background noise", + help="Normalize the images to have mean zero and variance one in the corners", +) +@click.option( + "--whiten", + default=True, + help="Estimate the noise variance of the images and whiten", ) -@click.option("--whiten_noise", default=True, help="Whiten background noise") @click.option( "--invert_contrast", default=True, - help="Invert the contrast of images so molecules are shown in white", + help="Invert the contrast of the images to ensure that clean particles have positive intensity", ) @click.option( - "--batch_size", default=512, help="Batch size to load images from MRC files." + "--batch_size", default=512, help="Batch size to load images from MRC files" ) @click.option( "--save_mode", @@ -67,9 +71,9 @@ def preprocess( pixel_size, max_rows, flip_phase, - max_resolution, - normalize_background, - whiten_noise, + downsample, + normalize_bg, + whiten, invert_contrast, batch_size, save_mode, @@ -88,15 +92,15 @@ def preprocess( logger.info("Perform phase flip to input images") source.phase_flip() - if max_resolution < source.L: - logger.info(f"Downsample resolution to {max_resolution} X {max_resolution}") - source.downsample(max_resolution) + if 0 < downsample < source.L: + logger.info(f"Downsample resolution to {downsample} X {downsample}") + source.downsample(downsample) - if normalize_background: + if normalize_bg: logger.info("Normalize images to noise background") source.normalize_background() - if whiten_noise: + if whiten: logger.info("Whiten noise of images") noise_estimator = WhiteNoiseEstimator(source) source.whiten(noise_estimator.filter) diff --git a/tests/test_coordinate_source.py b/tests/test_coordinate_source.py index a1fc26d433..9050e65e46 100644 --- a/tests/test_coordinate_source.py +++ b/tests/test_coordinate_source.py @@ -343,7 +343,20 @@ def testCommand(self): "--particle_size=256", ], ) + result_preprocess = runner.invoke( + extract_particles, + [ + f"--mrc_paths={self.data_folder}/*.mrc", + f"--coord_paths={self.data_folder}/sample*.box", + f"--starfile_out={self.data_folder}/saved_star_ds.star", + "--downsample=33", + "--normalize_bg", + "--whiten", + "--invert_contrast", + ], + ) # check that all commands completed successfully self.assertTrue(result_box.exit_code == 0) self.assertTrue(result_coord.exit_code == 0) self.assertTrue(result_star.exit_code == 0) + self.assertTrue(result_preprocess.exit_code == 0)