Skip to content
This repository was archived by the owner on Mar 21, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion InnerEye/ML/configs/segmentation/BasicModel2Epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def __init__(self, **kwargs: Any) -> None:
use_mixed_precision=True,
azure_dataset_id=AZURE_DATASET_ID,
comparison_blob_storage_paths=comparison_blob_storage_paths,
inference_on_test_set=True,
inference_on_val_set=True,
inference_on_test_set=True,
dataset_mountpoint="/tmp/innereye",
# Use an LR scheduler with a pronounced and clearly visible decay, to be able to easily see if that
# is applied correctly in run recovery.
Expand All @@ -65,3 +65,14 @@ def get_model_train_test_dataset_splits(self, dataset_df: pd.DataFrame) -> Datas
test_ids=['5'],
val_ids=['2']
)


class BasicModelForEnsembleTest(BasicModel2Epochs):
    """
    A copy of the basic model for PR builds, to use for running in a cross validation job.
    """

    def __init__(self) -> None:
        super().__init__()
        # Deliberately turn off inference on the validation set: the PR build uses this
        # model to check that missing per-split result files are handled correctly.
        self.inference_on_val_set = None
27 changes: 18 additions & 9 deletions InnerEye/ML/visualizers/plot_cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,10 @@ def get_short_name(self, run_or_id: Union[Run, str]) -> str:
return self.short_names[run_id]

def execution_modes_to_download(self) -> List[ModelExecutionMode]:
"""
Returns the dataset splits (Train/Val/Test) for which results should be downloaded from the
cross validation child runs.
"""
if self.model_category.is_scalar:
return [ModelExecutionMode.TRAIN, ModelExecutionMode.VAL, ModelExecutionMode.TEST]
else:
Expand Down Expand Up @@ -190,7 +194,8 @@ def download_or_get_local_file(self,
local_src_subdir: Optional[Path] = None) -> Optional[Path]:
"""
Downloads a file from the results folder of an AzureML run, or copies it from a local results folder.
Returns the path to the downloaded file if it exists, or None if the file was not found.
Returns the path to the downloaded file if it exists, or None if the file was not found, or could for other
reasons not be downloaded.
If the blobs_path contains folders, the same folder structure will be created inside the destination folder.
For example, downloading "foo.txt" to "/c/temp" will create "/c/temp/foo.txt". Downloading "foo/bar.txt"
to "/c/temp" will create "/c/temp/foo/bar.txt"
Expand Down Expand Up @@ -231,11 +236,14 @@ def download_or_get_local_file(self,
return Path(shutil.copy(local_src, destination))
return None
else:
return download_run_output_file(
blob_path=blob_path,
destination=destination,
run=run
)
try:
return download_run_output_file(
blob_path=blob_path,
destination=destination,
run=run
)
except Exception:
return None


@dataclass(frozen=True)
Expand Down Expand Up @@ -441,8 +449,8 @@ def crossval_config_from_model_config(train_config: DeepLearningConfig) -> PlotC

def get_config_and_results_for_offline_runs(train_config: DeepLearningConfig) -> OfflineCrossvalConfigAndFiles:
"""
Creates a configuration for crossvalidation analysis for the given model training configuration, and gets
the input files required for crossvalidation analysis.
Creates a configuration for cross validation analysis for the given model training configuration, and gets
the input files required for cross validation analysis.
:param train_config: The model configuration to work with.
"""
plot_crossval_config = crossval_config_from_model_config(train_config)
Expand Down Expand Up @@ -674,7 +682,8 @@ def save_outliers(config: PlotCrossValidationConfig,

f.write(f"\n\n=== METRIC: {metric_type} ===\n\n")
if len(outliers) > 0:
# If running inside institution there may be no CSV_SERIES_HEADER or CSV_INSTITUTION_HEADER columns
# If running inside institution there may be no CSV_SERIES_HEADER or CSV_INSTITUTION_HEADER
# columns
groupby_columns = [MetricsFileColumns.Patient.value, MetricsFileColumns.Structure.value]
if CSV_SERIES_HEADER in outliers.columns:
groupby_columns.append(CSV_SERIES_HEADER)
Expand Down

This file was deleted.

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"model_name": "BasicModel2Epochs", "checkpoint_paths": ["checkpoints/OTHER_RUNS/1/best_checkpoint.ckpt", "checkpoints/best_checkpoint.ckpt"], "model_configs_namespace": "InnerEye.ML.configs.segmentation.BasicModel2Epochs"}
{"model_name": "BasicModelForEnsembleTest", "checkpoint_paths": ["checkpoints/OTHER_RUNS/1/best_checkpoint.ckpt", "checkpoints/best_checkpoint.ckpt"], "model_configs_namespace": "InnerEye.ML.configs.segmentation.BasicModel2Epochs"}
42 changes: 20 additions & 22 deletions Tests/AfterTraining/test_after_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,26 +35,26 @@
from InnerEye.Common.output_directories import OutputFolderForTests
from InnerEye.Common.spawn_subprocess import spawn_and_monitor_subprocess
from InnerEye.ML.common import DATASET_CSV_FILE_NAME, ModelExecutionMode
from InnerEye.ML.configs.segmentation.BasicModel2Epochs import BasicModel2Epochs
from InnerEye.ML.configs.other.HelloContainer import HelloContainer
from InnerEye.ML.configs.segmentation.BasicModel2Epochs import BasicModel2Epochs
from InnerEye.ML.deep_learning_config import CHECKPOINT_FOLDER, ModelCategory
from InnerEye.ML.model_inference_config import read_model_inference_config
from InnerEye.ML.model_testing import THUMBNAILS_FOLDER
from InnerEye.ML.reports.notebook_report import get_html_report_name
from InnerEye.ML.runner import main
from InnerEye.ML.run_ml import MLRunner
from InnerEye.ML.runner import main
from InnerEye.ML.utils.config_loader import ModelConfigLoader
from InnerEye.ML.utils.image_util import get_unit_image_header
from InnerEye.ML.utils.io_util import zip_random_dicom_series
from InnerEye.ML.visualizers.plot_cross_validation import PlotCrossValidationConfig
from InnerEye.Scripts import submit_for_inference
from Tests.ML.util import assert_nifti_content, get_default_azure_config, get_nifti_shape, get_default_workspace
from Tests.ML.util import assert_nifti_content, get_default_azure_config, get_default_workspace, get_nifti_shape

FALLBACK_SINGLE_RUN = "refs_pull_498_merge:refs_pull_498_merge_1624292750_743430ab"
FALLBACK_ENSEMBLE_RUN = "refs_pull_498_merge:HD_4bf4efc3-182a-4596-8f93-76f128418142"
FALLBACK_2NODE_RUN = "refs_pull_498_merge:refs_pull_498_merge_1624292776_52b2f7e1"
FALLBACK_CV_GLAUCOMA = "refs_pull_498_merge:HD_cefb6e59-3929-43aa-8fc8-821b9a062219"
FALLBACK_HELLO_CONTAINER_RUN = "refs_pull_498_merge:refs_pull_498_merge_1624292748_45756bf8"
FALLBACK_SINGLE_RUN = "refs_pull_545_merge:refs_pull_545_merge_1626538212_d2b07afd"
FALLBACK_ENSEMBLE_RUN = "refs_pull_545_merge:HD_caea82ae-9603-48ba-8280-7d2bc6272411"
FALLBACK_2NODE_RUN = "refs_pull_545_merge:refs_pull_545_merge_1626538178_9f3023b2"
FALLBACK_CV_GLAUCOMA = "refs_pull_545_merge:HD_72ecc647-07c3-4353-a538-620346114ebd"
FALLBACK_HELLO_CONTAINER_RUN = "refs_pull_545_merge:refs_pull_545_merge_1626538216_3eb92f09"


def get_most_recent_run_id(fallback_run_id_for_local_execution: str = FALLBACK_SINGLE_RUN) -> str:
Expand Down Expand Up @@ -172,7 +172,6 @@ def test_registered_model_file_structure_and_instantiate(test_output_dirs: Outpu
model_name = tags["model_name"]
assert model_inference_config.model_name == model_name
assert model_inference_config.model_configs_namespace.startswith("InnerEye.ML.configs.")
assert model_inference_config.model_configs_namespace.endswith(model_name)
loader = ModelConfigLoader(model_configs_namespace=model_inference_config.model_configs_namespace)
model_config = loader.create_model_config_from_name(model_name=model_inference_config.model_name)
assert type(model_config).__name__ == model_inference_config.model_name
Expand Down Expand Up @@ -274,8 +273,8 @@ def test_expected_cv_files_segmentation() -> None:
assert run is not None
available_files = run.get_file_names()
for split in ["0", "1"]:
for mode in [ModelExecutionMode.TEST, ModelExecutionMode.VAL]:
assert _check_presence_cross_val_metrics_file(split, mode, available_files)
assert _check_presence_cross_val_metrics_file(split, ModelExecutionMode.TEST, available_files)
assert not _check_presence_cross_val_metrics_file(split, ModelExecutionMode.VAL, available_files)
# For ensemble we should have the test metrics only
assert _check_presence_cross_val_metrics_file(ENSEMBLE_SPLIT_NAME, ModelExecutionMode.TEST, available_files)
assert not _check_presence_cross_val_metrics_file(ENSEMBLE_SPLIT_NAME, ModelExecutionMode.VAL, available_files)
Expand Down Expand Up @@ -463,9 +462,10 @@ def test_download_non_existing_file(test_output_dirs: OutputFolderForTests) -> N


@pytest.mark.after_training_single_run
def test_download_non_existent_file(test_output_dirs: OutputFolderForTests) -> None:
def test_download_non_existing_file_in_crossval(test_output_dirs: OutputFolderForTests) -> None:
"""
Trying to download a non-existing file when doing cross validation should raise an exception.
Downloading a non-existing file when trying to load cross validation results
should not raise an exception.
"""
run = get_most_recent_run(fallback_run_id_for_local_execution=FALLBACK_SINGLE_RUN)
config = PlotCrossValidationConfig(run_recovery_id=None,
Expand All @@ -474,21 +474,19 @@ def test_download_non_existent_file(test_output_dirs: OutputFolderForTests) -> N
should_validate=False)
config.outputs_directory = test_output_dirs.root_dir
does_not_exist = "does_not_exist.txt"
with pytest.raises(ValueError) as ex:
config.download_or_get_local_file(run,
blob_to_download=does_not_exist,
destination=test_output_dirs.root_dir)
assert does_not_exist in str(ex)
assert "Unable to download file" in str(ex)
result = config.download_or_get_local_file(run,
blob_to_download=does_not_exist,
destination=test_output_dirs.root_dir)
assert result is None


@pytest.mark.after_training_hello_container
def test_model_inference_on_single_run(test_output_dirs: OutputFolderForTests) -> None:
fallback_run_id_for_local_execution = FALLBACK_HELLO_CONTAINER_RUN
falllback_run_id = FALLBACK_HELLO_CONTAINER_RUN

files_to_check = ["test_mse.txt", "test_mae.txt"]

training_run = get_most_recent_run(fallback_run_id_for_local_execution=fallback_run_id_for_local_execution)
training_run = get_most_recent_run(fallback_run_id_for_local_execution=falllback_run_id)
all_training_files = training_run.get_file_names()
for file in files_to_check:
assert f"outputs/{file}" in all_training_files, f"{file} is missing"
Expand All @@ -500,7 +498,7 @@ def test_model_inference_on_single_run(test_output_dirs: OutputFolderForTests) -

container = HelloContainer()
container.set_output_to(test_output_dirs.root_dir)
container.model_id = get_most_recent_model_id(fallback_run_id_for_local_execution=fallback_run_id_for_local_execution)
container.model_id = get_most_recent_model_id(fallback_run_id_for_local_execution=falllback_run_id)
azure_config = get_default_azure_config()
azure_config.train = False
ml_runner = MLRunner(container=container, azure_config=azure_config, project_root=test_output_dirs.root_dir)
Expand Down
2 changes: 1 addition & 1 deletion azure-pipelines/build-pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ jobs:
- job: TrainEnsemble
variables:
- name: model
value: 'BasicModel2Epochs'
value: 'BasicModelForEnsembleTest'
- name: number_of_cross_validation_splits
value: 2
- name: tag
Expand Down