diff --git a/InnerEye/ML/run_ml.py b/InnerEye/ML/run_ml.py index 94c5f7bd8..743c9256b 100644 --- a/InnerEye/ML/run_ml.py +++ b/InnerEye/ML/run_ml.py @@ -44,7 +44,7 @@ from InnerEye.ML.model_config_base import ModelConfigBase from InnerEye.ML.model_inference_config import ModelInferenceConfig from InnerEye.ML.model_testing import model_test -from InnerEye.ML.model_training import create_lightning_trainer, model_train +from InnerEye.ML.model_training import create_lightning_trainer, is_global_rank_zero, model_train from InnerEye.ML.reports.notebook_report import generate_classification_crossval_notebook, \ generate_classification_multilabel_notebook, generate_classification_notebook, generate_segmentation_notebook, \ get_ipynb_report_name, reports_folder @@ -222,7 +222,7 @@ def setup(self, use_mount_or_download_dataset: bool = True) -> None: azure_config=self.azure_config, project_root=self.project_root, run_context=RUN_CONTEXT) - self.checkpoint_handler.download_recovery_checkpoints_or_weights() + self.checkpoint_handler.download_recovery_checkpoints_or_weights(only_return_path=not is_global_rank_zero()) # A lot of the code for the built-in InnerEye models expects the output paths directly in the config files. if isinstance(self.container, InnerEyeContainer): diff --git a/InnerEye/ML/utils/checkpoint_handling.py b/InnerEye/ML/utils/checkpoint_handling.py index fe6b50686..7ddda99c9 100644 --- a/InnerEye/ML/utils/checkpoint_handling.py +++ b/InnerEye/ML/utils/checkpoint_handling.py @@ -57,15 +57,19 @@ def download_checkpoints_from_hyperdrive_child_runs(self, hyperdrive_parent_run: if not path.is_dir(): raise NotADirectoryError(f"Does not exist or is not a directory: {path}") - def download_recovery_checkpoints_or_weights(self) -> None: + def download_recovery_checkpoints_or_weights(self, only_return_path: bool = False) -> None: """ Download checkpoints from a run recovery object or from a weights url. 
Set the checkpoints path based on the run_recovery_object, weights_url or local_weights_path. This is called at the start of training. + :param only_return_path: if True, return a RunRecovery object with the path to the checkpoint without actually + downloading the checkpoints. This is useful to avoid duplicating checkpoint download when running on multiple + nodes. If False, return the RunRecovery object and download the checkpoint to disk. """ if self.azure_config.run_recovery_id: run_to_recover = self.azure_config.fetch_run(self.azure_config.run_recovery_id.strip()) - self.run_recovery = RunRecovery.download_all_checkpoints_from_run(self.output_params, run_to_recover) + self.run_recovery = RunRecovery.download_all_checkpoints_from_run(self.output_params, run_to_recover, + only_return_path=only_return_path) else: self.run_recovery = None @@ -73,7 +77,8 @@ def download_recovery_checkpoints_or_weights(self) -> None: run_to_recover = self.azure_config.fetch_run(self.azure_config.pretraining_run_recovery_id.strip()) run_recovery_object = RunRecovery.download_all_checkpoints_from_run(self.output_params, run_to_recover, - EXTRA_RUN_SUBFOLDER) + EXTRA_RUN_SUBFOLDER, + only_return_path=only_return_path) self.container.extra_downloaded_run_id = run_recovery_object else: self.container.extra_downloaded_run_id = None diff --git a/InnerEye/ML/utils/run_recovery.py b/InnerEye/ML/utils/run_recovery.py index f7acea153..bb92cdabe 100644 --- a/InnerEye/ML/utils/run_recovery.py +++ b/InnerEye/ML/utils/run_recovery.py @@ -63,13 +63,17 @@ def download_best_checkpoints_from_child_runs(config: OutputParams, run: Run) -> @staticmethod def download_all_checkpoints_from_run(config: OutputParams, run: Run, - subfolder: Optional[str] = None) -> RunRecovery: + subfolder: Optional[str] = None, + only_return_path: bool = False) -> RunRecovery: """ Downloads all checkpoints of the provided run inside the checkpoints folder. :param config: Model related configs. 
:param run: Run whose checkpoints should be recovered :param subfolder: optional subfolder name, if provided the checkpoints will be downloaded to CHECKPOINT_FOLDER / subfolder. If None, the checkpoint are downloaded to CHECKPOINT_FOLDER of the current run. + :param only_return_path: if True, return a RunRecovery object with the path to the checkpoint without actually + downloading the checkpoints. This is useful to avoid duplicating checkpoint download when running on multiple + nodes. If False, return the RunRecovery object and download the checkpoint to disk. :return: run recovery information """ if fetch_child_runs(run): destination_folder = config.checkpoint_folder / subfolder if subfolder else config.checkpoint_folder - download_outputs_from_run( - blobs_path=Path(CHECKPOINT_FOLDER), - destination=destination_folder, - run=run - ) + if not only_return_path: + download_outputs_from_run( + blobs_path=Path(CHECKPOINT_FOLDER), + destination=destination_folder, + run=run + ) time.sleep(60) # Needed because AML is not fast enough to download return RunRecovery(checkpoints_roots=[destination_folder]) diff --git a/Tests/AfterTraining/test_after_training.py b/Tests/AfterTraining/test_after_training.py index 3975647db..98271b5d0 100644 --- a/Tests/AfterTraining/test_after_training.py +++ b/Tests/AfterTraining/test_after_training.py @@ -15,6 +15,7 @@ import sys from pathlib import Path from typing import List +from unittest import mock import numpy as np import pytest @@ -29,7 +30,7 @@ from InnerEye.Common.common_util import CROSSVAL_RESULTS_FOLDER, ENSEMBLE_SPLIT_NAME, get_best_epoch_results_path from InnerEye.Common.fixed_paths import DEFAULT_AML_LOGS_DIR, DEFAULT_RESULT_IMAGE_NAME, \ DEFAULT_RESULT_ZIP_DICOM_NAME, \ - PYTHON_ENVIRONMENT_NAME + PYTHON_ENVIRONMENT_NAME, repository_root_directory from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path from 
InnerEye.Common.output_directories import OutputFolderForTests from InnerEye.Common.spawn_subprocess import spawn_and_monitor_subprocess @@ -38,6 +39,7 @@ from InnerEye.ML.deep_learning_config import CHECKPOINT_FOLDER, ModelCategory from InnerEye.ML.model_inference_config import read_model_inference_config from InnerEye.ML.reports.notebook_report import get_html_report_name +from InnerEye.ML.runner import main from InnerEye.ML.utils.config_loader import ModelConfigLoader from InnerEye.ML.utils.image_util import get_unit_image_header from InnerEye.ML.utils.io_util import zip_random_dicom_series @@ -351,3 +353,37 @@ def test_training_2nodes(test_output_dirs: OutputFolderForTests) -> None: assert "initializing ddp: GLOBAL_RANK: 1, MEMBER: 2/4" in log0_txt assert "initializing ddp: GLOBAL_RANK: 2, MEMBER: 3/4" in log1_txt assert "initializing ddp: GLOBAL_RANK: 3, MEMBER: 4/4" in log1_txt + + +@pytest.mark.after_training_2node +def test_recovery_on_2_nodes(test_output_dirs: OutputFolderForTests) -> None: + args_list = ["--model", "BasicModel2EpochsMoreData", + "--azureml", "True", + "--num_nodes", "2", + "--run_recovery_id", + str(get_most_recent_run_id(fallback_run_id_for_local_execution=FALLBACK_2NODE_RUN)), + "--num_epochs", "4", + "--wait_for_completion", "True" + ] + script = str(repository_root_directory() / "InnerEye" / "ML" / "runner.py") + with mock.patch("sys.argv", [script] + args_list): + main() + run = get_most_recent_run(fallback_run_id_for_local_execution=FALLBACK_2NODE_RUN) + assert run.status == RunStatus.COMPLETED + files = run.get_file_names() + # There are two nodes, so there should be one log file per node. 
+ log0_path = "azureml-logs/70_driver_log_0.txt" + log1_path = "azureml-logs/70_driver_log_1.txt" + assert log0_path in files, "Node rank 0 log file is missing" + assert log1_path in files, "Node rank 1 log file is missing" + # Download both log files and check their contents + log0 = test_output_dirs.root_dir / log0_path + log1 = test_output_dirs.root_dir / log1_path + run.download_file(log0_path, output_file_path=str(log0)) + run.download_file(log1_path, output_file_path=str(log1)) + log0_txt = log0.read_text() + log1_txt = log1.read_text() + assert "Downloading multiple files from run" in log0_txt + assert "Downloading multiple files from run" not in log1_txt + assert "Loading checkpoint that was created at (epoch = 2, global_step = 2)" in log0_txt + assert "Loading checkpoint that was created at (epoch = 2, global_step = 2)" in log1_txt