diff --git a/pyproject.toml b/pyproject.toml index 835d0b5e..ad51e355 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,6 +107,7 @@ GitHub = "https://github.com/DiamondLightSource/python-murfey" "clem.register_preprocessing_result" = "murfey.workflows.clem.register_preprocessing_results:run" "pato" = "murfey.workflows.notifications:notification_setup" "picked_particles" = "murfey.workflows.spa.picking:particles_picked" +"picked_tomogram" = "murfey.workflows.tomo.picking:picked_tomogram" "spa.flush_spa_preprocess" = "murfey.workflows.spa.flush_spa_preprocess:flush_spa_preprocess" [tool.setuptools] diff --git a/src/murfey/cli/inject_spa_processing.py b/src/murfey/cli/inject_spa_processing.py index 7a7f66bf..69f350c6 100644 --- a/src/murfey/cli/inject_spa_processing.py +++ b/src/murfey/cli/inject_spa_processing.py @@ -13,12 +13,12 @@ from murfey.util.config import get_machine_config, get_microscope, get_security_config from murfey.util.db import ( AutoProcProgram, + ClassificationFeedbackParameters, ClientEnvironment, DataCollection, DataCollectionGroup, Movie, ProcessingJob, - SPAFeedbackParameters, SPARelionParameters, ) from murfey.util.processing_params import default_spa_parameters @@ -137,9 +137,9 @@ def run(): .where(ProcessingJob.recipe == "em-spa-preprocess") ).one() params = murfey_db.exec( - select(SPARelionParameters, SPAFeedbackParameters) + select(SPARelionParameters, ClassificationFeedbackParameters) .where(SPARelionParameters.pj_id == collected_ids[2].id) - .where(SPAFeedbackParameters.pj_id == SPARelionParameters.pj_id) + .where(ClassificationFeedbackParameters.pj_id == SPARelionParameters.pj_id) ).one() proc_params: dict | None = dict(params[0]) feedback_params = params[1] diff --git a/src/murfey/cli/spa_ispyb_messages.py b/src/murfey/cli/spa_ispyb_messages.py index 640b585f..a183616c 100644 --- a/src/murfey/cli/spa_ispyb_messages.py +++ b/src/murfey/cli/spa_ispyb_messages.py @@ -363,7 +363,7 @@ def run(): small_boxsize=metadata["small_boxsize"], mask_diameter=metadata["mask_diameter"], ) - feedback_params = db.SPAFeedbackParameters( + feedback_params = db.ClassificationFeedbackParameters( pj_id=collected_ids[2].id, estimate_particle_diameter=not bool(metadata["particle_diameter"]), hold_class2d=False, diff --git a/src/murfey/client/contexts/tomo.py b/src/murfey/client/contexts/tomo.py index 7e6e1386..3439759a 100644 --- a/src/murfey/client/contexts/tomo.py +++ b/src/murfey/client/contexts/tomo.py @@ -161,7 +161,13 @@ def register_tomography_data_collections( data=dc_data, ) - for recipe in ("em-tomo-preprocess", "em-tomo-align"): + recipes_to_assign_pjids = [ + "em-tomo-preprocess", + "em-tomo-align", + ] + if not self._tilt_series_with_pjids: + recipes_to_assign_pjids.append("em-tomo-class2d") + for recipe in recipes_to_assign_pjids: capture_post( base_url=str(environment.url.geturl()), router_name="workflow.router", diff --git a/src/murfey/server/api/session_info.py b/src/murfey/server/api/session_info.py index 62f3be1f..fa10faec 100644 --- a/src/murfey/server/api/session_info.py +++ b/src/murfey/server/api/session_info.py @@ -31,6 +31,7 @@ from murfey.util import sanitise from murfey.util.config import MachineConfig, get_machine_config from murfey.util.db import ( + ClassificationFeedbackParameters, ClientEnvironment, DataCollection, DataCollectionGroup, @@ -41,7 +42,6 @@ RsyncInstance, Session, SessionProcessingParameters, - SPAFeedbackParameters, SPARelionParameters, Tilt, TiltSeries, @@ -280,7 +280,7 @@ class ProcessingDetails(BaseModel): data_collections: List[DataCollection] processing_jobs: List[ProcessingJob] relion_params: SPARelionParameters - feedback_params: SPAFeedbackParameters + feedback_params: ClassificationFeedbackParameters @spa_router.get("/sessions/{session_id}/spa_processing_parameters") @@ -293,13 +293,13 @@ def get_spa_proc_param_details( DataCollection, ProcessingJob, SPARelionParameters, - SPAFeedbackParameters, + ClassificationFeedbackParameters, ) .where(DataCollectionGroup.session_id == session_id) .where(DataCollectionGroup.id == DataCollection.dcg_id) .where(DataCollection.id == ProcessingJob.dc_id) .where(SPARelionParameters.pj_id == ProcessingJob.id) - .where(SPAFeedbackParameters.pj_id == ProcessingJob.id) + .where(ClassificationFeedbackParameters.pj_id == ProcessingJob.id) ).all() if not params: return None diff --git a/src/murfey/server/api/workflow.py b/src/murfey/server/api/workflow.py index 8fec45f5..ddd60e09 100644 --- a/src/murfey/server/api/workflow.py +++ b/src/murfey/server/api/workflow.py @@ -44,6 +44,7 @@ from murfey.util.config import get_machine_config from murfey.util.db import ( AutoProcProgram, + ClassificationFeedbackParameters, DataCollection, DataCollectionGroup, FoilHole, @@ -54,7 +55,6 @@ SearchMap, Session, SessionProcessingParameters, - SPAFeedbackParameters, SPARelionParameters, Tilt, TiltSeries, @@ -409,9 +409,9 @@ async def request_spa_preprocessing( .where(ProcessingJob.recipe == "em-spa-preprocess") ).one() params = db.exec( - select(SPARelionParameters, SPAFeedbackParameters) + select(SPARelionParameters, ClassificationFeedbackParameters) .where(SPARelionParameters.pj_id == collected_ids[2].id) - .where(SPAFeedbackParameters.pj_id == SPARelionParameters.pj_id) + .where(ClassificationFeedbackParameters.pj_id == SPARelionParameters.pj_id) ).one() proc_params: Optional[dict] = dict(params[0]) feedback_params = params[1] diff --git a/src/murfey/server/demo_api.py b/src/murfey/server/demo_api.py index c59cde2c..79fab4c0 100644 --- a/src/murfey/server/demo_api.py +++ b/src/murfey/server/demo_api.py @@ -49,6 +49,7 @@ ) from murfey.util.db import ( AutoProcProgram, + ClassificationFeedbackParameters, ClientEnvironment, DataCollection, DataCollectionGroup, @@ -60,7 +61,6 @@ ProcessingJob, RsyncInstance, Session, - SPAFeedbackParameters, SPARelionParameters, Tilt, TiltSeries, @@ -244,7 +244,7 @@ class ProcessingDetails(BaseModel): data_collections: List[DataCollection] processing_jobs: List[ProcessingJob] relion_params: SPARelionParameters - feedback_params: SPAFeedbackParameters + feedback_params: ClassificationFeedbackParameters @router.get("/sessions/{session_id}/spa_processing_parameters") @@ -257,13 +257,13 @@ def get_spa_proc_param_details( DataCollection, ProcessingJob, SPARelionParameters, - SPAFeedbackParameters, + ClassificationFeedbackParameters, ) .where(DataCollectionGroup.session_id == session_id) .where(DataCollectionGroup.id == DataCollection.dcg_id) .where(DataCollection.id == ProcessingJob.dc_id) .where(SPARelionParameters.pj_id == ProcessingJob.id) - .where(SPAFeedbackParameters.pj_id == ProcessingJob.id) + .where(ClassificationFeedbackParameters.pj_id == ProcessingJob.id) ).all() if not params: return None @@ -560,9 +560,9 @@ def flush_spa_processing( .where(ProcessingJob.recipe == "em-spa-preprocess") ).one() params = db.exec( - select(SPARelionParameters, SPAFeedbackParameters) + select(SPARelionParameters, ClassificationFeedbackParameters) .where(SPARelionParameters.pj_id == collected_ids[2].id) - .where(SPAFeedbackParameters.pj_id == SPARelionParameters.pj_id) + .where(ClassificationFeedbackParameters.pj_id == SPARelionParameters.pj_id) ).one() proc_params = dict(params[0]) feedback_params = params[1] diff --git a/src/murfey/server/feedback.py b/src/murfey/server/feedback.py index 288f3926..2f0712ae 100644 --- a/src/murfey/server/feedback.py +++ b/src/murfey/server/feedback.py @@ -313,13 +313,15 @@ def _pj_id(app_id: int, _db, recipe: str = "") -> int: def _get_spa_params( app_id: int, _db -) -> Tuple[db.SPARelionParameters, db.SPAFeedbackParameters]: +) -> Tuple[db.SPARelionParameters, db.ClassificationFeedbackParameters]: pj_id = _pj_id(app_id, _db, recipe="em-spa-preprocess") relion_params = _db.exec( select(db.SPARelionParameters).where(db.SPARelionParameters.pj_id == pj_id) ).one() feedback_params = _db.exec( - select(db.SPAFeedbackParameters).where(db.SPAFeedbackParameters.pj_id == pj_id) + select(db.ClassificationFeedbackParameters).where( + db.ClassificationFeedbackParameters.pj_id == pj_id + ) ).one() _db.expunge(relion_params) _db.expunge(feedback_params) @@ -412,8 +414,8 @@ def _release_3d_hold(message: dict, _db): ) ).one() feedback_params = _db.exec( - select(db.SPAFeedbackParameters).where( - db.SPAFeedbackParameters.pj_id == pj_id_params + select(db.ClassificationFeedbackParameters).where( + db.ClassificationFeedbackParameters.pj_id == pj_id_params ) ).one() class3d_params = _db.exec( @@ -490,8 +492,8 @@ def _release_refine_hold(message: dict, _db): ) ).one() feedback_params = _db.exec( - select(db.SPAFeedbackParameters).where( - db.SPAFeedbackParameters.pj_id == pj_id_params + select(db.ClassificationFeedbackParameters).where( + db.ClassificationFeedbackParameters.pj_id == pj_id_params ) ).one() refine_params = _db.exec( @@ -582,8 +584,8 @@ def _register_incomplete_2d_batch(message: dict, _db, demo: bool = False): ) ).one() feedback_params = _db.exec( - select(db.SPAFeedbackParameters).where( - db.SPAFeedbackParameters.pj_id == pj_id_params + select(db.ClassificationFeedbackParameters).where( + db.ClassificationFeedbackParameters.pj_id == pj_id_params ) ).one() if feedback_params.hold_class2d: @@ -623,7 +625,9 @@ def _register_incomplete_2d_batch(message: dict, _db, demo: bool = False): ) _db.add(class2d_params) _db.commit() - murfey_ids = _murfey_id(message["program_id"], _db, number=50) + murfey_ids = _murfey_id( + message["program_id"], _db, number=default_spa_parameters.nr_classes_2d + ) _murfey_class2ds( murfey_ids, class2d_message["particles_file"], message["program_id"], _db ) @@ -706,8 +710,8 @@ def _register_complete_2d_batch(message: dict, _db, demo: bool = False): ) ).one() feedback_params = _db.exec( - select(db.SPAFeedbackParameters).where( - db.SPAFeedbackParameters.pj_id == pj_id_params + select(db.ClassificationFeedbackParameters).where( + db.ClassificationFeedbackParameters.pj_id == pj_id_params ) ).one() _db.expunge(relion_params) @@ -747,7 +751,9 @@ def _register_complete_2d_batch(message: dict, _db, demo: bool = False): _db.add(class2d_params) _db.commit() _db.close() - murfey_ids = _murfey_id(_app_id(pj_id, _db), _db, number=50) + murfey_ids = _murfey_id( + _app_id(pj_id, _db), _db, number=default_spa_parameters.nr_classes_2d + ) _murfey_class2ds( murfey_ids, class2d_message["particles_file"], _app_id(pj_id, _db), _db ) @@ -796,7 +802,13 @@ def _register_complete_2d_batch(message: dict, _db, demo: bool = False): else: class_uuids = { str(i + 1): m - for i, m in enumerate(_murfey_id(_app_id(pj_id, _db), _db, number=50)) + for i, m in enumerate( + _murfey_id( + _app_id(pj_id, _db), + _db, + number=default_spa_parameters.nr_classes_2d, + ) + ) } class2d_grp_uuid = _murfey_id(_app_id(pj_id, _db), _db)[0] zocalo_message: dict = { @@ -865,7 +877,13 @@ def _register_complete_2d_batch(message: dict, _db, demo: bool = False): else: class_uuids = { str(i + 1): m - for i, m in enumerate(_murfey_id(_app_id(pj_id, _db), _db, number=50)) + for i, m in enumerate( + _murfey_id( + _app_id(pj_id, _db), + _db, + number=default_spa_parameters.nr_classes_2d, + ) + ) } class2d_grp_uuid = _murfey_id(_app_id(pj_id, _db), _db)[0] zocalo_message = { @@ -913,7 +931,7 @@ def _flush_class2d( app_id: int, _db, relion_params: db.SPARelionParameters | None = None, - feedback_params: db.SPAFeedbackParameters | None = None, + feedback_params: db.ClassificationFeedbackParameters | None = None, ): instrument_name = ( _db.exec(select(db.Session).where(db.Session.id == session_id)) @@ -934,8 +952,8 @@ def _flush_class2d( _db.expunge(relion_params) if not feedback_params: feedback_params = _db.exec( - select(db.SPAFeedbackParameters).where( - db.SPAFeedbackParameters.pj_id == pj_id_params + select(db.ClassificationFeedbackParameters).where( + db.ClassificationFeedbackParameters.pj_id == pj_id_params ) ).one() _db.expunge(feedback_params) @@ -1012,8 +1030,8 @@ def _register_class_selection(message: dict, _db, demo: bool = False): ).all() # Add the class selection score to the database feedback_params = _db.exec( - select(db.SPAFeedbackParameters).where( - db.SPAFeedbackParameters.pj_id == pj_id_params + select(db.ClassificationFeedbackParameters).where( + db.ClassificationFeedbackParameters.pj_id == pj_id_params ) ).one() _db.expunge(feedback_params) @@ -1233,8 +1251,8 @@ def _register_3d_batch(message: dict, _db, demo: bool = False): ).one() relion_options = dict(relion_params) feedback_params = _db.exec( - select(db.SPAFeedbackParameters).where( - db.SPAFeedbackParameters.pj_id == pj_id_params + select(db.ClassificationFeedbackParameters).where( + db.ClassificationFeedbackParameters.pj_id == pj_id_params ) ).one() other_options = dict(feedback_params) @@ -1411,8 +1429,8 @@ def _register_initial_model(message: dict, _db, demo: bool = False): pj_id_params = _pj_id(message["program_id"], _db, recipe="em-spa-preprocess") # Add the initial model file to the database feedback_params = _db.exec( - select(db.SPAFeedbackParameters).where( - db.SPAFeedbackParameters.pj_id == pj_id_params + select(db.ClassificationFeedbackParameters).where( + db.ClassificationFeedbackParameters.pj_id == pj_id_params ) ).one() feedback_params.initial_model = message.get("initial_model") @@ -1578,8 +1596,8 @@ def _register_refinement(message: dict, _db, demo: bool = False): ).one() relion_options = dict(relion_params) feedback_params = _db.exec( - select(db.SPAFeedbackParameters).where( - db.SPAFeedbackParameters.pj_id == pj_id_params + select(db.ClassificationFeedbackParameters).where( + db.ClassificationFeedbackParameters.pj_id == pj_id_params ) ).one() other_options = dict(feedback_params) @@ -1726,8 +1744,8 @@ def _register_bfactors(message: dict, _db, demo: bool = False): ).one() relion_options = dict(relion_params) feedback_params = _db.exec( - select(db.SPAFeedbackParameters).where( - db.SPAFeedbackParameters.pj_id == pj_id_params + select(db.ClassificationFeedbackParameters).where( + db.ClassificationFeedbackParameters.pj_id == pj_id_params ) ).one() @@ -2289,7 +2307,7 @@ def feedback_callback(header: dict, message: dict, _db=murfey_db) -> None: eer_fractionation_file=message["eer_fractionation_file"], symmetry=message["symmetry"], ) - feedback_params = db.SPAFeedbackParameters( + feedback_params = db.ClassificationFeedbackParameters( pj_id=collected_ids[2].id, estimate_particle_diameter=True, hold_class2d=False, @@ -2346,7 +2364,18 @@ def feedback_callback(header: dict, message: dict, _db=murfey_db) -> None: gain_ref=message["gain_ref"], eer_fractionation_file=message["eer_fractionation_file"], ) + feedback_params = db.ClassificationFeedbackParameters( + pj_id=collected_ids[2].id, + estimate_particle_diameter=True, + hold_class2d=False, + hold_class3d=False, + class_selection_score=0, + star_combination_job=0, + initial_model="", + next_job=0, + ) _db.add(params) + _db.add(feedback_params) _db.commit() _db.close() if murfey.server._transport_object: diff --git a/src/murfey/util/db.py b/src/murfey/util/db.py index ef160613..065df4c5 100644 --- a/src/murfey/util/db.py +++ b/src/murfey/util/db.py @@ -459,12 +459,18 @@ class ProcessingJob(SQLModel, table=True): # type: ignore spa_parameters: List["SPARelionParameters"] = Relationship( back_populates="processing_job", sa_relationship_kwargs={"cascade": "delete"} ) - spa_feedback_parameters: List["SPAFeedbackParameters"] = Relationship( - back_populates="processing_job", sa_relationship_kwargs={"cascade": "delete"} + classification_feedback_parameters: List["ClassificationFeedbackParameters"] = ( + Relationship( + back_populates="processing_job", + sa_relationship_kwargs={"cascade": "delete"}, + ) ) ctf_parameters: List["CtfParameters"] = Relationship( back_populates="processing_job", sa_relationship_kwargs={"cascade": "delete"} ) + tomogram_picks: List["TomogramPicks"] = Relationship( + back_populates="processing_job", sa_relationship_kwargs={"cascade": "delete"} + ) class2d_parameters: List["Class2DParameters"] = Relationship( back_populates="processing_job", sa_relationship_kwargs={"cascade": "delete"} ) @@ -514,6 +520,7 @@ class TomographyProcessingParameters(SQLModel, table=True): # type: ignore frame_count: int tilt_axis: float voltage: int + particle_diameter: Optional[float] = None eer_fractionation_file: Optional[str] = None motion_corr_binning: int = 1 gain_ref: Optional[str] = None @@ -557,8 +564,10 @@ class MurfeyLedger(SQLModel, table=True): # type: ignore refine_parameters: Optional["RefineParameters"] = Relationship( back_populates="murfey_ledger", sa_relationship_kwargs={"cascade": "delete"} ) - spa_feedback_parameters: Optional["SPAFeedbackParameters"] = Relationship( - back_populates="murfey_ledger", sa_relationship_kwargs={"cascade": "delete"} + classification_feedback_parameters: Optional["ClassificationFeedbackParameters"] = ( + Relationship( + back_populates="murfey_ledger", sa_relationship_kwargs={"cascade": "delete"} + ) ) movies: Optional["Movie"] = Relationship( back_populates="murfey_ledger", sa_relationship_kwargs={"cascade": "delete"} @@ -671,6 +680,17 @@ class CtfParameters(SQLModel, table=True): # type: ignore ) +class TomogramPicks(SQLModel, table=True): # type: ignore + tomogram: str = Field(primary_key=True) + pj_id: int = Field(foreign_key="processingjob.id") + cbox_3d: str + particle_count: int + tomogram_pixel_size: float + processing_job: Optional[ProcessingJob] = Relationship( + back_populates="tomogram_picks" + ) + + class ParticleSizes(SQLModel, table=True): # type: ignore id: Optional[int] = Field(default=None, primary_key=True) pj_id: int = Field(foreign_key="processingjob.id") @@ -700,7 +720,7 @@ class SPARelionParameters(SQLModel, table=True): # type: ignore ) -class SPAFeedbackParameters(SQLModel, table=True): # type: ignore +class ClassificationFeedbackParameters(SQLModel, table=True): # type: ignore pj_id: int = Field(primary_key=True, foreign_key="processingjob.id") estimate_particle_diameter: bool = True hold_class2d: bool = False @@ -714,10 +734,10 @@ class SPAFeedbackParameters(SQLModel, table=True): # type: ignore picker_murfey_id: Optional[int] = Field(default=None, foreign_key="murfeyledger.id") picker_ispyb_id: Optional[int] = None processing_job: Optional[ProcessingJob] = Relationship( - back_populates="spa_feedback_parameters" + back_populates="classification_feedback_parameters" ) murfey_ledger: Optional[MurfeyLedger] = Relationship( - back_populates="spa_feedback_parameters" + back_populates="classification_feedback_parameters" ) diff --git a/src/murfey/util/processing_params.py b/src/murfey/util/processing_params.py index 65a51c20..07c5a12c 100644 --- a/src/murfey/util/processing_params.py +++ b/src/murfey/util/processing_params.py @@ -79,3 +79,11 @@ class SPAParameters(BaseModel): default_spa_parameters = SPAParameters() + + +class TomographyParameters(BaseModel): + batch_size_2d: int = 10000 + nr_classes_2d: int = 5 + + +default_tomo_parameters = TomographyParameters() diff --git a/src/murfey/workflows/spa/flush_spa_preprocess.py b/src/murfey/workflows/spa/flush_spa_preprocess.py index 98dc8309..606511bb 100644 --- a/src/murfey/workflows/spa/flush_spa_preprocess.py +++ b/src/murfey/workflows/spa/flush_spa_preprocess.py @@ -13,6 +13,7 @@ from murfey.util.config import get_machine_config, get_microscope from murfey.util.db import ( AutoProcProgram, + ClassificationFeedbackParameters, DataCollection, DataCollectionGroup, FoilHole, @@ -21,7 +22,6 @@ PreprocessStash, ProcessingJob, Session as MurfeySession, - SPAFeedbackParameters, SPARelionParameters, ) from murfey.util.models import FoilHoleParameters, GridSquareParameters @@ -338,9 +338,9 @@ def flush_spa_preprocess(message: dict, murfey_db: Session, demo: bool = False) .where(ProcessingJob.recipe == recipe_name) ).one() params = murfey_db.exec( - select(SPARelionParameters, SPAFeedbackParameters) + select(SPARelionParameters, ClassificationFeedbackParameters) .where(SPARelionParameters.pj_id == collected_ids[2].id) - .where(SPAFeedbackParameters.pj_id == SPARelionParameters.pj_id) + .where(ClassificationFeedbackParameters.pj_id == SPARelionParameters.pj_id) ).one() proc_params = params[0] feedback_params = params[1] diff --git a/src/murfey/workflows/spa/picking.py b/src/murfey/workflows/spa/picking.py index 72a1e3ba..894a5f48 100644 --- a/src/murfey/workflows/spa/picking.py +++ b/src/murfey/workflows/spa/picking.py @@ -17,6 +17,7 @@ from murfey.util.config import get_machine_config from murfey.util.db import ( AutoProcProgram, + ClassificationFeedbackParameters, CtfParameters, DataCollection, Movie, @@ -26,7 +27,6 @@ ProcessingJob, SelectionStash, Session as MurfeySession, - SPAFeedbackParameters, SPARelionParameters, ) from murfey.util.processing_params import default_spa_parameters @@ -78,7 +78,9 @@ def _register_picked_particles_use_diameter( ).one() relion_options = dict(relion_params) feedback_params = _db.exec( - select(SPAFeedbackParameters).where(SPAFeedbackParameters.pj_id == pj_id) + select(ClassificationFeedbackParameters).where( + ClassificationFeedbackParameters.pj_id == pj_id + ) ).one() particle_diameter = relion_params.particle_diameter @@ -263,7 +265,9 @@ def _register_picked_particles_use_boxsize(message: dict, _db: Session): select(SPARelionParameters).where(SPARelionParameters.pj_id == pj_id) ).one() feedback_params = _db.exec( - select(SPAFeedbackParameters).where(SPAFeedbackParameters.pj_id == pj_id) + select(ClassificationFeedbackParameters).where( + ClassificationFeedbackParameters.pj_id == pj_id + ) ).one() if feedback_params.picker_ispyb_id is None and _transport_object: @@ -448,8 +452,9 @@ def particles_picked(message: dict, murfey_db: Session) -> bool: murfey_db.add(movie) murfey_db.commit() feedback_params = murfey_db.exec( - select(SPAFeedbackParameters).where( - SPAFeedbackParameters.pj_id == _pj_id(message["program_id"], murfey_db) + select(ClassificationFeedbackParameters).where( + ClassificationFeedbackParameters.pj_id + == _pj_id(message["program_id"], murfey_db) ) ).one() if feedback_params.estimate_particle_diameter: diff --git a/src/murfey/workflows/tomo/picking.py b/src/murfey/workflows/tomo/picking.py new file mode 100644 index 00000000..414283fc --- /dev/null +++ b/src/murfey/workflows/tomo/picking.py @@ -0,0 +1,208 @@ +from logging import getLogger +from typing import Tuple + +import numpy as np +from sqlalchemy import func +from sqlmodel import Session, select + +from murfey.server import _transport_object +from murfey.server.feedback import _app_id, _murfey_id +from murfey.util.config import get_machine_config +from murfey.util.db import ( + AutoProcProgram, + ClassificationFeedbackParameters, + DataCollection, + ParticleSizes, + ProcessingJob, + Session as MurfeySession, + TomogramPicks, + TomographyProcessingParameters, +) +from murfey.util.processing_params import default_tomo_parameters + +logger = getLogger("murfey.workflows.tomo.feedback") + + +def _ids_tomo_classification(app_id: int, recipe: str, _db) -> Tuple[int, int]: + dcg_id = ( + _db.exec( + select(AutoProcProgram, ProcessingJob, DataCollection) + .where(AutoProcProgram.id == app_id) + .where(AutoProcProgram.pj_id == ProcessingJob.id) + .where(ProcessingJob.dc_id == DataCollection.id) + ) + .one()[2] + .dcg_id + ) + pj_id = ( + _db.exec( + select(ProcessingJob, DataCollection) + .where(DataCollection.dcg_id == dcg_id) + .where(ProcessingJob.dc_id == DataCollection.id) + .where(ProcessingJob.recipe == recipe) + ) + .one()[0] + .id + ) + return dcg_id, pj_id + + +def _register_picked_tomogram_use_diameter(message: dict, _db: Session): + """Received picked particles from the tomogram autopick service""" + # Add this message to the table of seen messages + dcg_id, pj_id = _ids_tomo_classification( + message["program_id"], "em-tomo-class2d", _db + ) + + pick_params = TomogramPicks( + pj_id=pj_id, + tomogram=message["tomogram"], + cbox_3d=message["cbox_3d"], + particle_count=message["particle_count"], + tomogram_pixel_size=message["pixel_size"], + ) + _db.add(pick_params) + _db.commit() + + picking_db_len = _db.exec( + select(func.count(ParticleSizes.id)).where(ParticleSizes.pj_id == pj_id) + ).one() + if picking_db_len > default_tomo_parameters.batch_size_2d: + # If there are enough particles to get a diameter + instrument_name = ( + _db.exec( + select(MurfeySession).where(MurfeySession.id == message["session_id"]) + ) + .one() + .instrument_name + ) + machine_config = get_machine_config(instrument_name=instrument_name)[ + instrument_name + ] + tomo_params = _db.exec( + select(TomographyProcessingParameters).where( + TomographyProcessingParameters.dcg_id == dcg_id + ) + ).one() + + particle_diameter = tomo_params.particle_diameter + + feedback_params = _db.exec( + select(ClassificationFeedbackParameters).where( + ClassificationFeedbackParameters.pj_id == pj_id + ) + ).one() + if not feedback_params.next_job: + feedback_params.next_job = 9 + + if not particle_diameter: + # If the diameter has not been calculated then find it + picking_db = _db.exec( + select(ParticleSizes.particle_size).where(ParticleSizes.pj_id == pj_id) + ).all() + particle_diameter = np.quantile(list(picking_db), 0.75) + tomo_params.particle_diameter = particle_diameter + _db.add(tomo_params) + _db.commit() + + tomo_pick_db = _db.exec( + select(TomogramPicks).where(TomogramPicks.pj_id == pj_id) + ).all() + for saved_message in tomo_pick_db: + # Send on all saved messages to extraction + class_uuids = { + str(i + 1): m + for i, m in enumerate( + _murfey_id( + _app_id(pj_id, _db), + _db, + number=default_tomo_parameters.nr_classes_2d, + ) + ) + } + class2d_grp_uuid = _murfey_id(_app_id(pj_id, _db), _db)[0] + zocalo_message: dict = { + "parameters": { + "tomogram": saved_message.tomogram, + "cbox_3d": saved_message.cbox_3d, + "pixel_size": saved_message.tomogram_pixel_size, + "particle_diameter": particle_diameter, + "kv": tomo_params.voltage, + "node_creator_queue": machine_config.node_creator_queue, + "session_id": message["session_id"], + "autoproc_program_id": _app_id(pj_id, _db), + "batch_size": default_tomo_parameters.batch_size_2d, + "nr_classes": default_tomo_parameters.nr_classes_2d, + "picker_id": None, + "class2d_grp_uuid": class2d_grp_uuid, + "class_uuids": class_uuids, + "next_job": feedback_params.next_job, + }, + "recipes": ["em-tomo-class2d"], + } + if _transport_object: + zocalo_message["parameters"]["feedback_queue"] = ( + _transport_object.feedback_queue + ) + _transport_object.send( + "processing_recipe", zocalo_message, new_connection=True + ) + feedback_params.next_job += 2 + _db.delete(saved_message) + else: + # If the diameter is known then just send the new message + particle_diameter = tomo_params.particle_diameter + class_uuids = { + str(i + 1): m + for i, m in enumerate( + _murfey_id( + _app_id(pj_id, _db), + _db, + number=default_tomo_parameters.nr_classes_2d, + ) + ) + } + class2d_grp_uuid = _murfey_id(_app_id(pj_id, _db), _db)[0] + zocalo_message = { + "parameters": { + "tomogram": message["tomogram"], + "cbox_3d": message["cbox_3d"], + "pixel_size": message["pixel_size"], + "particle_diameter": particle_diameter, + "kv": tomo_params.voltage, + "node_creator_queue": machine_config.node_creator_queue, + "session_id": message["session_id"], + "autoproc_program_id": _app_id(pj_id, _db), + "batch_size": default_tomo_parameters.batch_size_2d, + "nr_classes": default_tomo_parameters.nr_classes_2d, + "picker_id": None, + "class2d_grp_uuid": class2d_grp_uuid, + "class_uuids": class_uuids, + "next_job": feedback_params.next_job, + }, + "recipes": ["em-tomo-class2d"], + } + if _transport_object: + zocalo_message["parameters"]["feedback_queue"] = ( + _transport_object.feedback_queue + ) + _transport_object.send( + "processing_recipe", zocalo_message, new_connection=True + ) + feedback_params.next_job += 2 + _db.add(feedback_params) + _db.commit() + else: + # If not enough particles then save the new sizes + particle_list = message.get("particle_diameters") + assert isinstance(particle_list, list) + for particle in particle_list: + new_particle = ParticleSizes(pj_id=pj_id, particle_size=particle) + _db.add(new_particle) + _db.commit() + _db.close() + + +def picked_tomogram(message: dict, murfey_db: Session) -> bool: + _register_picked_tomogram_use_diameter(message, murfey_db) + return True diff --git a/tests/workflows/tomo/test_tomo_picking.py b/tests/workflows/tomo/test_tomo_picking.py new file mode 100644 index 00000000..a2748012 --- /dev/null +++ b/tests/workflows/tomo/test_tomo_picking.py @@ -0,0 +1,348 @@ +from unittest import mock + +from sqlmodel import Session, select + +from murfey.util.db import ( + AutoProcProgram, + ClassificationFeedbackParameters, + DataCollection, + DataCollectionGroup, + ParticleSizes, + ProcessingJob, + TomogramPicks, + TomographyProcessingParameters, +) +from murfey.workflows.tomo import picking +from tests.conftest import ExampleVisit, get_or_create_db_entry + + +def set_up_picking_db(murfey_db_session: Session): + # Insert common elements needed in all picking tests + dcg_entry: DataCollectionGroup = get_or_create_db_entry( + murfey_db_session, + DataCollectionGroup, + lookup_kwargs={ + "id": 0, + "session_id": ExampleVisit.murfey_session_id, + "tag": "test_dcg", + }, + ) + dc_entry: DataCollection = get_or_create_db_entry( + murfey_db_session, + DataCollection, + lookup_kwargs={ + "id": 0, + "tag": "test_dc", + "dcg_id": dcg_entry.id, + }, + ) + processing_job_entry: ProcessingJob = get_or_create_db_entry( + murfey_db_session, + ProcessingJob, + lookup_kwargs={ + "id": 1, + "recipe": "test_recipe", + "dc_id": dc_entry.id, + }, + ) + get_or_create_db_entry( + murfey_db_session, + AutoProcProgram, + lookup_kwargs={ + "id": 0, + "pj_id": processing_job_entry.id, + }, + ) + get_or_create_db_entry( + murfey_db_session, + ClassificationFeedbackParameters, + lookup_kwargs={ + "pj_id": processing_job_entry.id, + "estimate_particle_diameter": True, + "hold_class2d": False, + "hold_class3d": False, + "class_selection_score": 0, + "star_combination_job": 0, + "initial_model": "", + "next_job": 0, + }, + ) + return dcg_entry.id, dc_entry.id, processing_job_entry.id + + +def test_ids_tomo_classification(murfey_db_session: Session): + dcg_id, first_dc_id, first_pj_id = set_up_picking_db(murfey_db_session) + + # Insert a second data collection, processing job and autoproc program + second_dc: DataCollection = get_or_create_db_entry( + murfey_db_session, + DataCollection, + lookup_kwargs={ + "id": 1, + "tag": "second_dc", + "dcg_id": dcg_id, + }, + ) + second_pj: ProcessingJob = get_or_create_db_entry( + murfey_db_session, + ProcessingJob, + lookup_kwargs={ + "id": 10, + "recipe": "second_recipe", + "dc_id": second_dc.id, + }, + ) + get_or_create_db_entry( + murfey_db_session, + AutoProcProgram, + lookup_kwargs={ + "id": 11, + "pj_id": second_pj.id, + }, + ) + + returned_ids = picking._ids_tomo_classification( + 11, "test_recipe", murfey_db_session + ) + assert returned_ids[0] == dcg_id + assert returned_ids[1] == first_pj_id + + +@mock.patch("murfey.workflows.tomo.picking._transport_object") +@mock.patch("murfey.workflows.tomo.picking._ids_tomo_classification") +def test_picked_tomogram_not_run_class2d( + mock_ids, mock_transport, murfey_db_session: Session, tmp_path +): + """Run the picker feedback with less particles than needed for classification""" + mock_ids.return_value = [2, 1] + + # Insert table dependencies + set_up_picking_db(murfey_db_session) + + message = { + "program_id": 0, + "cbox_3d": f"{tmp_path}/AutoPick/job007/CBOX_3d/sample.cbox", + "particle_count": 2, + "particle_diameters": [10.1, 20.2], + "pixel_size": 5.3, + "register": "picked_tomogram", + "tomogram": f"{tmp_path}/Tomograms/job006/tomograms/sample.mrc", + } + picking._register_picked_tomogram_use_diameter(message, murfey_db_session) + + mock_ids.assert_called_once_with(0, "em-tomo-class2d", murfey_db_session) + + tomograms_db = murfey_db_session.exec( + select(TomogramPicks).where(TomogramPicks.pj_id == 1) + ).one() + assert tomograms_db.tomogram == message["tomogram"] + assert tomograms_db.cbox_3d == message["cbox_3d"] + assert tomograms_db.particle_count == 2 + assert tomograms_db.tomogram_pixel_size == 5.3 + + added_picks = murfey_db_session.exec( + select(ParticleSizes).where(ParticleSizes.pj_id == 1) + ).all() + assert len(added_picks) == 2 + assert added_picks[0].particle_size == 10.1 + assert added_picks[1].particle_size == 20.2 + + mock_transport.send.assert_not_called() + + +@mock.patch("murfey.workflows.tomo.picking._transport_object") +@mock.patch("murfey.workflows.tomo.picking._ids_tomo_classification") +def test_picked_tomogram_run_class2d_with_diameter( + mock_ids, mock_transport, murfey_db_session: Session, tmp_path +): + """Run the picker feedback with a pre-determined particle diameter""" + mock_transport.feedback_queue = "murfey_feedback" + + # Insert table dependencies + dcg_id, dc_id, pj_id = set_up_picking_db(murfey_db_session) + get_or_create_db_entry( + murfey_db_session, + TomographyProcessingParameters, + lookup_kwargs={ + "dcg_id": dcg_id, + "pixel_size": 1.34, + "dose_per_frame": 1, + "frame_count": 5, + "tilt_axis": 0, + "voltage": 300, + "particle_diameter": 200, + }, + ) + for particle in range(10001): + get_or_create_db_entry( + murfey_db_session, + ParticleSizes, + lookup_kwargs={ + "id": particle, + "pj_id": pj_id, + "particle_size": 100, + }, + ) + + mock_ids.return_value = [dcg_id, 1] + + message = { + "session_id": 1, + "program_id": 0, + "cbox_3d": f"{tmp_path}/AutoPick/job007/CBOX_3d/sample.cbox", + "particle_count": 2, + "particle_diameters": [10.1, 20.2], + "pixel_size": 5.3, + "register": "picked_tomogram", + "tomogram": f"{tmp_path}/Tomograms/job006/tomograms/sample.mrc", + } + picking._register_picked_tomogram_use_diameter(message, murfey_db_session) + + mock_ids.assert_called_once_with(0, "em-tomo-class2d", murfey_db_session) + + tomograms_db = murfey_db_session.exec( + select(TomogramPicks).where(TomogramPicks.pj_id == 1) + ).one() + assert tomograms_db.tomogram == message["tomogram"] + assert tomograms_db.cbox_3d == message["cbox_3d"] + assert tomograms_db.particle_count == 2 + assert tomograms_db.tomogram_pixel_size == 5.3 + + mock_transport.send.assert_called_once_with( + "processing_recipe", + { + "parameters": { + "tomogram": message["tomogram"], + "cbox_3d": message["cbox_3d"], + "pixel_size": message["pixel_size"], + "particle_diameter": 200.0, + "kv": 300, + "node_creator_queue": "node_creator", + "session_id": message["session_id"], + "autoproc_program_id": 0, + "batch_size": 10000, + "nr_classes": 5, + "picker_id": None, + "class2d_grp_uuid": 6, + "class_uuids": {str(i): i for i in range(1, 6)}, + "next_job": 9, + "feedback_queue": "murfey_feedback", + }, + "recipes": ["em-tomo-class2d"], + }, + new_connection=True, + ) + + +@mock.patch("murfey.workflows.tomo.picking._transport_object") +@mock.patch("murfey.workflows.tomo.picking._ids_tomo_classification") +def test_picked_tomogram_run_class2d_estimate_diameter( + mock_ids, mock_transport, murfey_db_session: Session, tmp_path +): + """Run the picker feedback for Class2D, including diameter estimation""" + mock_transport.feedback_queue = "murfey_feedback" + + # Insert table dependencies + dcg_id, dc_id, pj_id = set_up_picking_db(murfey_db_session) + get_or_create_db_entry( + murfey_db_session, + TomographyProcessingParameters, + lookup_kwargs={ + "dcg_id": dcg_id, + "pixel_size": 1.34, + "dose_per_frame": 1, + "frame_count": 5, + "tilt_axis": 0, + "voltage": 300, + "particle_diameter": None, + }, + ) + for particle in range(10001): + get_or_create_db_entry( + murfey_db_session, + ParticleSizes, + lookup_kwargs={ + "id": particle, + "pj_id": pj_id, + "particle_size": 100, + }, + ) + # Insert one existing tomogram which should get flushed out + get_or_create_db_entry( + murfey_db_session, + TomogramPicks, + lookup_kwargs={ + "pj_id": pj_id, + "tomogram": f"{tmp_path}/Tomograms/job006/tomograms/tomogram1.mrc", + "cbox_3d": f"{tmp_path}/AutoPick/job007/CBOX_3d/tomogram1_picks.cbox", + "particle_count": 10, + "tomogram_pixel_size": 5.3, + }, + ) + + mock_ids.return_value = [dcg_id, 1] + + message = { + "session_id": 1, + "program_id": 0, + "cbox_3d": f"{tmp_path}/AutoPick/job007/CBOX_3d/sample.cbox", + "particle_count": 2, + "particle_diameters": [10.1, 20.2], + "pixel_size": 5.3, + "register": "picked_tomogram", + "tomogram": f"{tmp_path}/Tomograms/job006/tomograms/sample.mrc", + } + picking._register_picked_tomogram_use_diameter(message, murfey_db_session) + + mock_ids.assert_called_once_with(0, "em-tomo-class2d", murfey_db_session) + + # Two mock calls - one flushed tomogram and one new + assert mock_transport.send.call_count == 2 + mock_transport.send.assert_any_call( + "processing_recipe", + { + "parameters": { + "tomogram": f"{tmp_path}/Tomograms/job006/tomograms/tomogram1.mrc", + "cbox_3d": f"{tmp_path}/AutoPick/job007/CBOX_3d/tomogram1_picks.cbox", + "pixel_size": 5.3, + "particle_diameter": 100.0, + "kv": 300, + "node_creator_queue": "node_creator", + "session_id": message["session_id"], + "autoproc_program_id": 0, + "batch_size": 10000, + "nr_classes": 5, + "picker_id": None, + "class2d_grp_uuid": 12, + "class_uuids": {str(i): i + 6 for i in range(1, 6)}, + "next_job": 9, + "feedback_queue": "murfey_feedback", + }, + "recipes": ["em-tomo-class2d"], + }, + new_connection=True, + ) + mock_transport.send.assert_any_call( + "processing_recipe", + { + "parameters": { + "tomogram": message["tomogram"], + "cbox_3d": message["cbox_3d"], + "pixel_size": message["pixel_size"], + "particle_diameter": 100.0, + "kv": 300, + "node_creator_queue": "node_creator", + "session_id": message["session_id"], + "autoproc_program_id": 0, + "batch_size": 10000, + "nr_classes": 5, + "picker_id": None, + "class2d_grp_uuid": 18, + "class_uuids": {str(i): i + 12 for i in range(1, 6)}, + "next_job": 11, + "feedback_queue": "murfey_feedback", + }, + "recipes": ["em-tomo-class2d"], + }, + new_connection=True, + )