diff --git a/gallery/experiments/save_simulation_relion_reconstruct.py b/gallery/experiments/save_simulation_relion_reconstruct.py new file mode 100644 index 000000000..767abf6bf --- /dev/null +++ b/gallery/experiments/save_simulation_relion_reconstruct.py @@ -0,0 +1,86 @@ +""" +Simulated Stack → RELION Reconstruction +======================================= + +This experiment shows how to: + +1. build a synthetic dataset with ASPIRE, +2. write the stack via ``ImageSource.save`` so RELION can consume it, and +3. call :code:`relion_reconstruct` on the saved STAR file. +""" + +# %% +# Imports +# ------- + +import logging +from pathlib import Path + +import numpy as np + +from aspire.downloader import emdb_2660 +from aspire.noise import WhiteNoiseAdder +from aspire.operators import RadialCTFFilter +from aspire.source import Simulation + +logger = logging.getLogger(__name__) + + +# %% +# Configuration +# ------------- +# We set a few parameters to initialize the Simulation. +# You can safely alter ``n_particles`` (or change the voltages, etc.) when +# trying this interactively; the defaults here are chosen for demonstrative purposes. + +output_dir = Path("relion_save_demo") +output_dir.mkdir(exist_ok=True) + +n_particles = 512 +snr = 0.25 +voltages = np.linspace(200, 300, 3) # kV settings for the radial CTF filters +star_path = output_dir / f"sim_n{n_particles}.star" + + +# %% +# Volume and Filters +# ------------------ +# Start from the EMDB-2660 ribosome map and build a small set of radial CTF filters +# that RELION will recover as optics groups. + +vol = emdb_2660() +ctf_filters = [RadialCTFFilter(voltage=kv) for kv in voltages] + + +# %% +# Simulate, Add Noise, Save +# ------------------------- +# Initialize the Simulation: +# mix the CTFs across the stack, add white noise at a target SNR, +# and write the particles and metadata to a RELION-compatible STAR/MRC stack. + +sim = Simulation( + n=n_particles, + vols=vol, + unique_filters=ctf_filters, + noise_adder=WhiteNoiseAdder.from_snr(snr), +) +sim.save(star_path, overwrite=True) + + +# %% +# Running ``relion_reconstruct`` +# ------------------------------ +# ``relion_reconstruct`` is an external RELION command, so we just show the call. +# Run this in a RELION-enabled shell after generating the STAR file above. + +relion_cmd = [ + "relion_reconstruct", + "--i", + str(star_path), + "--o", + str(output_dir / "relion_recon.mrc"), + "--ctf", +] + +logger.info(" ".join(relion_cmd)) diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index f1dff7440..aa1b85a1b 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -3,7 +3,7 @@ import logging import os.path from abc import ABC, abstractmethod -from collections import OrderedDict +from collections import OrderedDict, defaultdict from collections.abc import Iterable import mrcfile @@ -483,15 +483,16 @@ def offsets(self, values): @property def amplitudes(self): - return np.atleast_1d( - self.get_metadata( - "_rlnAmplitude", default_value=np.array(1.0, dtype=self.dtype) - ) + values = self.get_metadata( + "_aspireAmplitude", + default_value=np.array(1.0, dtype=np.float64), ) + return np.atleast_1d(np.asarray(values, dtype=np.float64)) @amplitudes.setter def amplitudes(self, values): - return self.set_metadata("_rlnAmplitude", np.array(values, dtype=self.dtype)) + values = np.asarray(values, dtype=np.float64) + self.set_metadata("_aspireAmplitude", values) @property def angles(self): @@ -1289,6 +1290,86 @@ def _populate_local_metadata(self): """ return [] + @staticmethod + def _prepare_relion_optics_blocks(metadata): + """ + Split metadata into RELION>=3.1 style `data_optics` and `data_particles` blocks. + + The optics block has one row per optics group with: + `_rlnOpticsGroup`, `_rlnOpticsGroupName`, and optics metadata columns. + The particle block keeps the remaining columns and includes a per-particle + `_rlnOpticsGroup` that references the optics block. + """ + # Columns that belong in RELION's optics table. + all_optics_fields = [ + "_rlnImagePixelSize", + "_rlnMicrographPixelSize", + "_rlnSphericalAberration", + "_rlnVoltage", + "_rlnAmplitudeContrast", + "_rlnImageSize", + "_rlnImageDimensionality", + ] + + # Some optics group fields might not always be present, but are necessary + # for reading the file in Relion. We ensure these fields exist and populate + # with a dummy value if not. + n_rows = len(metadata["_rlnImageName"]) + + missing_fields = [] + + def _ensure_column(field, value): + if field not in metadata: + missing_fields.append(field) + logger.warning( + f"Optics field {field} not found, populating with default value {value}" + ) + metadata[field] = np.full(n_rows, value) + + _ensure_column("_rlnSphericalAberration", 0) + _ensure_column("_rlnVoltage", 0) + _ensure_column("_rlnAmplitudeContrast", 0) + + if missing_fields: + metadata["_aspireMetadata"] = np.full(n_rows, "no_ctf", dtype=object) + + # Restrict to the optics columns that are actually present on this source. + optics_value_fields = [ + field for field in all_optics_fields if field in metadata + ] + + # Map each unique optics tuple to a 1-based group ID in order encountered. + group_lookup = OrderedDict() + optics_groups = np.empty(n_rows, dtype=int) + + for idx in range(n_rows): + signature = tuple(metadata[field][idx] for field in optics_value_fields) + if signature not in group_lookup: + group_lookup[signature] = len(group_lookup) + 1 + optics_groups[idx] = group_lookup[signature] + + metadata["_rlnOpticsGroup"] = optics_groups + + # Build the optics block rows and assign group names. + optics_block = defaultdict(list) + + for signature, group_id in group_lookup.items(): + optics_block["_rlnOpticsGroup"].append(group_id) + optics_block["_rlnOpticsGroupName"].append(f"opticsGroup{group_id}") + for field, value in zip(optics_value_fields, signature): + optics_block[field].append(value) + + # Everything not lifted into the optics block stays with the particle metadata. + particle_block = OrderedDict() + if "_rlnOpticsGroup" in metadata: + particle_block["_rlnOpticsGroup"] = metadata["_rlnOpticsGroup"] + for key, value in metadata.items(): + if key in optics_value_fields or key == "_rlnOpticsGroup": + continue + particle_block[key] = value + + return optics_block, particle_block + def save_metadata(self, starfile_filepath, batch_size=512, save_mode=None): """ Save updated metadata to a STAR file @@ -1324,12 +1405,27 @@ def save_metadata(self, starfile_filepath, batch_size=512, save_mode=None): for x in np.char.split(metadata["_rlnImageName"].astype(np.str_), sep="@") ] + # Populate _rlnImageSize, _rlnImageDimensionality columns, required for optics_block below + if "_rlnImageSize" not in metadata: + metadata["_rlnImageSize"] = np.full(self.n, self.L, dtype=int) + + if "_rlnImageDimensionality" not in metadata: + metadata["_rlnImageDimensionality"] = np.full(self.n, 2, dtype=int) + + # Separate metadata into optics and particle blocks + optics_block, particle_block = self._prepare_relion_optics_blocks(metadata) + # initialize the star file object and save it odict = OrderedDict() - # since our StarFile only has one block, the convention is to save it with the header "data_", i.e. its name is blank - # if we had a block called "XYZ" it would be saved as "XYZ" - # thus we index the metadata block with "" - odict[""] = metadata + + # StarFile uses the `odict` keys to label the starfile block headers "data_(key)". Following RELION>=3.1 + # convention we label the blocks "data_optics" and "data_particles". + if optics_block is None: + odict["particles"] = particle_block + else: + odict["optics"] = optics_block + odict["particles"] = particle_block + out_star = StarFile(blocks=odict) out_star.write(starfile_filepath) return filename_indices @@ -1400,6 +1496,8 @@ def save_images( # for large arrays. stats.update_header(mrc) + # Add pixel size to header + mrc.voxel_size = self.pixel_size else: # save all images into multiple mrc files in batch size for i_start in np.arange(0, self.n, batch_size): @@ -1413,6 +1511,7 @@ def save_images( f"Saving ImageSource[{i_start}-{i_end-1}] to {mrcs_filepath}" ) im = self.images[i_start:i_end] + im.pixel_size = self.pixel_size im.save(mrcs_filepath, overwrite=overwrite) def estimate_signal_mean_energy( diff --git a/src/aspire/source/relion.py b/src/aspire/source/relion.py index c3620c9bd..c063c3b7b 100644 --- a/src/aspire/source/relion.py +++ b/src/aspire/source/relion.py @@ -125,6 +125,12 @@ def __init__( for key in offset_keys: del self._metadata[key] + # Detect ASPIRE-generated dummy variables + aspire_metadata = metadata.get("_aspireMetadata") + dummy_ctf = isinstance(aspire_metadata, (list, np.ndarray)) and np.all( + np.asarray(aspire_metadata) == "no_ctf" + ) + # CTF estimation parameters coming from Relion CTF_params = [ "_rlnVoltage", @@ -162,6 +168,14 @@ def __init__( # self.unique_filters of the filter that should be applied self.filter_indices = filter_indices + # If we detect ASPIRE added dummy variables, log and initialize identity filter + elif dummy_ctf: + logger.info( + "Detected ASPIRE-generated dummy optics; initializing identity filters." + ) + self.unique_filters = [IdentityFilter()] + self.filter_indices = np.zeros(self.n, dtype=int) + # We have provided some, but not all the required params elif any(param in metadata for param in CTF_params): logger.warning( diff --git a/src/aspire/utils/relion_interop.py b/src/aspire/utils/relion_interop.py index 996d32136..c807c7d86 100644 --- a/src/aspire/utils/relion_interop.py +++ b/src/aspire/utils/relion_interop.py @@ -20,7 +20,9 @@ "_rlnDetectorPixelSize": float, "_rlnCtfFigureOfMerit": float, "_rlnMagnification": float, + "_rlnImageDimensionality": int, "_rlnImagePixelSize": float, + "_rlnImageSize": int, "_rlnAmplitudeContrast": float, "_rlnImageName": str, "_rlnOriginalName": str, diff --git a/tests/test_array_image_source.py b/tests/test_array_image_source.py index 8dc2a28fb..a2c3ee4f0 100644 --- a/tests/test_array_image_source.py +++ b/tests/test_array_image_source.py @@ -323,10 +323,10 @@ def test_dtype_passthrough(dtype): # Check dtypes np.testing.assert_equal(src.dtype, dtype) np.testing.assert_equal(src.images[:].dtype, dtype) - np.testing.assert_equal(src.amplitudes.dtype, dtype) - # offsets are always stored as doubles + # offsets and amplitudes are always stored as doubles np.testing.assert_equal(src.offsets.dtype, np.float64) + np.testing.assert_equal(src.amplitudes.dtype, np.float64) def test_stack_1d_only(): diff --git a/tests/test_coordinate_source.py b/tests/test_coordinate_source.py index d2f93d5f6..8baf7d133 100644 --- a/tests/test_coordinate_source.py +++ b/tests/test_coordinate_source.py @@ -526,7 +526,7 @@ def testSave(self): # load saved particle stack saved_star = StarFile(star_path) # we want to read the saved mrcs file from the STAR file - image_name_column = saved_star.get_block_by_index(0)["_rlnImageName"] + image_name_column = saved_star.get_block_by_index(1)["_rlnImageName"] # we're reading a string of the form 0000X@mrcs_path.mrcs _particle, mrcs_path = image_name_column[0].split("@") saved_mrcs_stack = mrcfile.open(os.path.join(self.data_folder, mrcs_path)).data @@ -535,15 +535,31 @@ def testSave(self): self.assertTrue(np.array_equal(imgs.asnumpy()[i], saved_mrcs_stack[i])) # assert that the star file has the correct metadata self.assertEqual( - list(saved_star[""].keys()), + list(saved_star["particles"].keys()), [ - "_rlnImagePixelSize", + "_rlnOpticsGroup", "_rlnSymmetryGroup", "_rlnImageName", "_rlnCoordinateX", "_rlnCoordinateY", + "_aspireMetadata", + ], + ) + + self.assertEqual( + list(saved_star["optics"].keys()), + [ + "_rlnOpticsGroup", + "_rlnOpticsGroupName", + "_rlnImagePixelSize", + "_rlnSphericalAberration", + "_rlnVoltage", + "_rlnAmplitudeContrast", + "_rlnImageSize", + "_rlnImageDimensionality", ], ) + # assert that all the correct coordinates were saved for i in range(10): self.assertEqual( diff --git a/tests/test_relion_source.py b/tests/test_relion_source.py index 009ecd321..64703ed49 100644 --- a/tests/test_relion_source.py +++ b/tests/test_relion_source.py @@ -5,7 +5,7 @@ import numpy as np import pytest -from aspire.source import RelionSource, Simulation +from aspire.source import ImageSource, RelionSource, Simulation from aspire.utils import RelionStarFile from aspire.volume import SymmetryGroup @@ -61,6 +61,41 @@ def test_symmetry_group(caplog): assert str(src_override_sym.symmetry_group) == "C6" +def test_prepare_relion_optics_blocks_warns(caplog): + """ + Test we warn when optics group metadata is missing. + """ + # metadata dict with no CTF values + metadata = { + "_rlnImagePixelSize": np.array([1.234]), + "_rlnImageSize": np.array([32]), + "_rlnImageDimensionality": np.array([2]), + "_rlnImageName": np.array(["000001@stack.mrcs"]), + } + + caplog.clear() + with caplog.at_level(logging.WARNING): + optics_block, particle_block = ImageSource._prepare_relion_optics_blocks( + metadata.copy() + ) + + # We should get and optics block + assert optics_block is not None + + # Verify defaults were injected. + np.testing.assert_allclose(optics_block["_rlnImagePixelSize"], [1.234]) + np.testing.assert_array_equal(optics_block["_rlnImageSize"], [32]) + np.testing.assert_array_equal(optics_block["_rlnImageDimensionality"], [2]) + np.testing.assert_allclose(optics_block["_rlnVoltage"], [0]) + np.testing.assert_allclose(optics_block["_rlnSphericalAberration"], [0]) + np.testing.assert_allclose(optics_block["_rlnAmplitudeContrast"], [0]) + + # Caplog should contain the warnings about the three missing fields. + assert "Optics field _rlnSphericalAberration not found" in caplog.text + assert "Optics field _rlnVoltage not found" in caplog.text + assert "Optics field _rlnAmplitudeContrast not found" in caplog.text + + def test_pixel_size(caplog): """ Instantiate RelionSource from starfiles containing the following pixel size diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 5859aae93..c053b2eb0 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -3,13 +3,14 @@ import tempfile from unittest import TestCase +import mrcfile import numpy as np import pytest from aspire.noise import WhiteNoiseAdder from aspire.operators import RadialCTFFilter from aspire.source import RelionSource, Simulation, _LegacySimulation -from aspire.utils import utest_tolerance +from aspire.utils import RelionStarFile, utest_tolerance from aspire.volume import LegacyVolume, SymmetryGroup, Volume from .test_utils import matplotlib_dry_run @@ -627,6 +628,111 @@ def testSimulationSaveFile(self): ) +def test_simulation_save_optics_block(tmp_path): + res = 32 + + # Radial CTF Filters. Should make 3 distinct optics blocks + kv_min, kv_max, kv_ct = 200, 300, 3 + voltages = np.linspace(kv_min, kv_max, kv_ct) + ctf_filters = [RadialCTFFilter(voltage=kv) for kv in voltages] + + # Generate and save Simulation + sim = Simulation( + n=9, L=res, C=1, unique_filters=ctf_filters, pixel_size=1.34 + ).cache() + starpath = tmp_path / "sim.star" + sim.save(starpath, overwrite=True) + + star = RelionStarFile(str(starpath)) + assert star.relion_version == "3.1" + assert star.blocks.keys() == {"optics", "particles"} + + optics = star["optics"] + expected_optics_fields = [ + "_rlnOpticsGroup", + "_rlnOpticsGroupName", + "_rlnImagePixelSize", + "_rlnSphericalAberration", + "_rlnVoltage", + "_rlnAmplitudeContrast", + "_rlnImageSize", + "_rlnImageDimensionality", + ] + + # Check all required fields are present + for field in expected_optics_fields: + assert field in optics + + # Optics group and group name should be 1-indexed + np.testing.assert_array_equal( + optics["_rlnOpticsGroup"], np.arange(1, kv_ct + 1, dtype=int) + ) + np.testing.assert_array_equal( + optics["_rlnOpticsGroupName"], + np.array([f"opticsGroup{i}" for i in range(1, kv_ct + 1)]), + ) + + # Check image size (res) and image dimensionality (2) + np.testing.assert_array_equal(optics["_rlnImageSize"], np.full(kv_ct, res)) + np.testing.assert_array_equal(optics["_rlnImageDimensionality"], np.full(kv_ct, 2)) + + # Due to Simulation random indexing, voltages will be unordered + np.testing.assert_allclose(np.sort(optics["_rlnVoltage"]), voltages) + + # Check that each row of the data_particles block has an associated optics group + particles = star["particles"] + assert "_rlnOpticsGroup" in particles + assert len(particles["_rlnOpticsGroup"]) == sim.n + np.testing.assert_array_equal( + np.sort(np.unique(particles["_rlnOpticsGroup"])), + np.arange(1, kv_ct + 1, dtype=int), + ) + + # Test phase_flip after save/load round trip to ensure correct optics group mapping + rln_src = RelionSource(starpath) + np.testing.assert_allclose( + sim.phase_flip().images[:], rln_src.phase_flip().images[:] + ) + + +def test_simulation_slice_save_roundtrip(tmp_path): + # Radial CTF Filters + kv_min, kv_max, kv_ct = 200, 300, 3 + voltages = np.linspace(kv_min, kv_max, kv_ct) + ctf_filters = [RadialCTFFilter(voltage=kv) for kv in voltages] + + # Generate and save slice of Simulation + sim = Simulation(n=9, L=16, C=1, unique_filters=ctf_filters, pixel_size=1.34) + sliced_sim = sim[::2] + save_path = tmp_path / "sliced_sim.star" + sliced_sim.save(save_path, overwrite=True) + + # Load saved slice and compare to original + reloaded = RelionSource(save_path) + + # Check images + np.testing.assert_allclose( + reloaded.images[:].asnumpy(), + sliced_sim.images[:].asnumpy(), + ) + + # Check metadata related to optics block + metadata_fields = [ + "_rlnVoltage", + "_rlnDefocusU", + "_rlnDefocusV", + "_rlnDefocusAngle", + "_rlnSphericalAberration", + "_rlnAmplitudeContrast", + "_rlnImagePixelSize", + ] + for field in metadata_fields: + np.testing.assert_allclose( + reloaded.get_metadata(field), + sliced_sim.get_metadata(field), + ) + + def test_default_symmetry_group(): # Check that default is "C1". sim = Simulation() @@ -809,6 +915,53 @@ def test_save_overwrite(caplog): check_metadata(sim2, sim2_loaded_renamed) +def test_save_load_dummy_ctf_values(tmp_path, caplog): + """ + Test we populate optics group field with dummy values when none + are present. These values should be detected upon reloading the source. + """ + star_path = tmp_path / "no_ctf.star" + sim = Simulation(n=8, L=16) # no unique_filters, ie. no CTF info + sim.save(star_path, overwrite=True) + + # STAR file should contain our fallback tag + star = RelionStarFile(star_path) + particles_block = star.get_block_by_index(1) + np.testing.assert_array_equal( + particles_block["_aspireMetadata"], np.full(sim.n, "no_ctf", dtype=object) + ) + + # Tag should survive round-trip + caplog.clear() + reloaded = RelionSource(star_path) + np.testing.assert_array_equal( + reloaded._metadata["_aspireMetadata"], + np.full(reloaded.n, "no_ctf", dtype=object), + ) + + # Check message is logged about detecting dummy variables + assert "Detected ASPIRE-generated dummy optics" in caplog.text + + +@pytest.mark.parametrize("batch_size", [1, 6]) +def test_simulation_save_sets_voxel_size(tmp_path, batch_size): + """ + Test we save with pixel_size appended to the mrcfile header. + """ + # Note, n=6 and batch_size=6 exercises save_mode=='single' branch. + sim = Simulation(n=6, L=24, pixel_size=1.37) + info = sim.save(tmp_path / "pixel_size.star", batch_size=batch_size, overwrite=True) + + for stack_name in info["mrcs"]: + stack_path = tmp_path / stack_name + with mrcfile.open(stack_path, permissive=True) as f: + vs = f.voxel_size + header_vals = np.array( + [float(vs.x), float(vs.y), float(vs.z)], dtype=np.float64 + ) + np.testing.assert_allclose(header_vals, sim.pixel_size) + + def check_metadata(sim_src, relion_src): """ Helper function to test if metadata fields in a Simulation match