diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 342c8dfc7..b3778fa5b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,9 +27,3 @@ repos: rev: v1.5.7 hooks: - id: autopep8 - -- repo: https://github.com/ambv/black - rev: 21.9b0 - hooks: - - id: black - language_version: python3.7 diff --git a/CHANGELOG.md b/CHANGELOG.md index 062cddca2..8a709138c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,7 @@ jobs that run in AzureML. - ([#554](https://github.com/microsoft/InnerEye-DeepLearning/pull/554)) Added a parameter `pretraining_dataset_id` to `NIH_COVID_BYOL` to specify the name of the SSL training dataset. - ([#560](https://github.com/microsoft/InnerEye-DeepLearning/pull/560)) Added pre-commit hooks. +- ([#619](https://github.com/microsoft/InnerEye-DeepLearning/pull/619)) Add DeepMIL PANDA. - ([#559](https://github.com/microsoft/InnerEye-DeepLearning/pull/559)) Adding the accompanying code for the ["Active label cleaning: Improving dataset quality under resource constraints"](https://arxiv.org/abs/2109.00574) paper. The code can be found in the [InnerEye-DataQuality](InnerEye-DataQuality/README.md) subfolder. It provides tools for training noise robust models, running label cleaning simulation and loading our label cleaning benchmark datasets. - ([#589](https://github.com/microsoft/InnerEye-DeepLearning/pull/589)) Add `LightningContainer.update_azure_config()` hook to enable overriding `AzureConfig` parameters from a container (e.g. `experiment_name`, `cluster`, `num_nodes`). 
diff --git a/InnerEye/Azure/azure_util.py b/InnerEye/Azure/azure_util.py index 7dc20d534..856c9f747 100644 --- a/InnerEye/Azure/azure_util.py +++ b/InnerEye/Azure/azure_util.py @@ -52,14 +52,20 @@ def split_recovery_id(id: str) -> Tuple[str, str]: """ components = id.strip().split(EXPERIMENT_RUN_SEPARATOR) if len(components) > 2: - raise ValueError("recovery_id must be in the format: 'experiment_name:run_id', but got: {}".format(id)) + raise ValueError( + "recovery_id must be in the format: 'experiment_name:run_id', but got: {}".format( + id + ) + ) elif len(components) == 2: return components[0], components[1] else: recovery_id_regex = r"^(\w+)_\d+_[0-9a-f]+$|^(\w+)_\d+$" match = re.match(recovery_id_regex, id) if not match: - raise ValueError("The recovery ID was not in the expected format: {}".format(id)) + raise ValueError( + "The recovery ID was not in the expected format: {}".format(id) + ) return (match.group(1) or match.group(2)), id @@ -77,9 +83,15 @@ def fetch_run(workspace: Workspace, run_recovery_id: str) -> Run: try: experiment_to_recover = Experiment(workspace, experiment) except Exception as ex: - raise Exception(f"Unable to retrieve run {run} in experiment {experiment}: {str(ex)}") + raise Exception( + f"Unable to retrieve run {run} in experiment {experiment}: {str(ex)}" + ) run_to_recover = fetch_run_for_experiment(experiment_to_recover, run) - logging.info("Fetched run #{} {} from experiment {}.".format(run, run_to_recover.number, experiment)) + logging.info( + "Fetched run #{} {} from experiment {}.".format( + run, run_to_recover.number, experiment + ) + ) return run_to_recover @@ -94,9 +106,13 @@ def fetch_run_for_experiment(experiment_to_recover: Experiment, run_id: str) -> except Exception: available_runs = experiment_to_recover.get_runs() available_ids = ", ".join([run.id for run in available_runs]) - raise (Exception( - "Run {} not found for experiment: {}. 
Available runs are: {}".format( - run_id, experiment_to_recover.name, available_ids))) + raise ( + Exception( + "Run {} not found for experiment: {}. Available runs are: {}".format( + run_id, experiment_to_recover.name, available_ids + ) + ) + ) def fetch_runs(experiment: Experiment, filters: List[str]) -> List[Run]: @@ -116,8 +132,11 @@ def fetch_runs(experiment: Experiment, filters: List[str]) -> List[Run]: return exp_runs -def fetch_child_runs(run: Run, status: Optional[str] = None, - expected_number_cross_validation_splits: int = 0) -> List[Run]: +def fetch_child_runs( + run: Run, + status: Optional[str] = None, + expected_number_cross_validation_splits: int = 0, +) -> List[Run]: """ Fetch child runs for the provided runs that have the provided AML status (or fetch all by default) and have a run_recovery_id tag value set (this is to ignore superfluous AML infrastructure platform runs). @@ -138,18 +157,25 @@ def fetch_child_runs(run: Run, status: Optional[str] = None, if 0 < expected_number_cross_validation_splits != len(children_runs): logging.warning( f"The expected number of child runs was {expected_number_cross_validation_splits}." - f"Fetched only: {len(children_runs)} runs. Now trying to fetch them manually.") - run_ids_to_evaluate = [f"{create_run_recovery_id(run)}_{i}" - for i in range(expected_number_cross_validation_splits)] - children_runs = [fetch_run(run.experiment.workspace, id) for id in run_ids_to_evaluate] + f"Fetched only: {len(children_runs)} runs. Now trying to fetch them manually." 
+ ) + run_ids_to_evaluate = [ + f"{create_run_recovery_id(run)}_{i}" + for i in range(expected_number_cross_validation_splits) + ] + children_runs = [ + fetch_run(run.experiment.workspace, id) for id in run_ids_to_evaluate + ] if status is not None: - children_runs = [child_run for child_run in children_runs if child_run.get_status() == status] + children_runs = [ + child_run for child_run in children_runs if child_run.get_status() == status + ] return children_runs def is_ensemble_run(run: Run) -> bool: """Checks if the run was an ensemble of multiple models""" - return run.get_tags().get(IS_ENSEMBLE_KEY_NAME) == 'True' + return run.get_tags().get(IS_ENSEMBLE_KEY_NAME) == "True" def to_azure_friendly_string(x: Optional[str]) -> Optional[str]: @@ -160,7 +186,7 @@ def to_azure_friendly_string(x: Optional[str]) -> Optional[str]: if x is None: return x else: - return re.sub('_+', '_', re.sub(r'\W+', '_', x)) + return re.sub("_+", "_", re.sub(r"\W+", "_", x)) def to_azure_friendly_container_path(path: Path) -> str: @@ -178,7 +204,7 @@ def is_offline_run_context(run_context: Run) -> bool: :param run_context: Context of the run to check :return: """ - return not hasattr(run_context, 'experiment') + return not hasattr(run_context, "experiment") def get_run_context_or_default(run: Optional[Run] = None) -> Run: @@ -199,7 +225,12 @@ def get_cross_validation_split_index(run: Run) -> int: if is_offline_run_context(run): return DEFAULT_CROSS_VALIDATION_SPLIT_INDEX else: - return int(run.get_tags().get(CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY, DEFAULT_CROSS_VALIDATION_SPLIT_INDEX)) + return int( + run.get_tags().get( + CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY, + DEFAULT_CROSS_VALIDATION_SPLIT_INDEX, + ) + ) def is_cross_validation_child_run(run: Run) -> bool: @@ -256,9 +287,7 @@ def is_parent_run(run: Run) -> bool: return PARENT_RUN_CONTEXT and run.id == PARENT_RUN_CONTEXT.id -def download_run_output_file(blob_path: Path, - destination: Path, - run: Run) -> Path: +def 
download_run_output_file(blob_path: Path, destination: Path, run: Run) -> Path: """ Downloads a single file from the run's default output directory: DEFAULT_AML_UPLOAD_DIR ("outputs"). For example, if blobs_path = "foo/bar.csv", then the run result file "outputs/foo/bar.csv" will be downloaded @@ -270,17 +299,21 @@ def download_run_output_file(blob_path: Path, """ blobs_prefix = str((fixed_paths.DEFAULT_AML_UPLOAD_DIR / blob_path).as_posix()) destination = destination / blob_path.name - logging.info(f"Downloading single file from run {run.id}: {blobs_prefix} -> {str(destination)}") + logging.info( + f"Downloading single file from run {run.id}: {blobs_prefix} -> {str(destination)}" + ) try: run.download_file(blobs_prefix, str(destination), _validate_checksum=True) except Exception as ex: - raise ValueError(f"Unable to download file '{blobs_prefix}' from run {run.id}") from ex + raise ValueError( + f"Unable to download file '{blobs_prefix}' from run {run.id}" + ) from ex return destination -def download_run_outputs_by_prefix(blobs_prefix: Path, - destination: Path, - run: Run) -> None: +def download_run_outputs_by_prefix( + blobs_prefix: Path, destination: Path, run: Run +) -> None: """ Download all the blobs from the run's default output directory: DEFAULT_AML_UPLOAD_DIR ("outputs") that have a given prefix (folder structure). When saving, the prefix string will be stripped off. For example, @@ -291,7 +324,9 @@ def download_run_outputs_by_prefix(blobs_prefix: Path, :param destination: Local path to save the downloaded blobs to. """ prefix_str = str((fixed_paths.DEFAULT_AML_UPLOAD_DIR / blobs_prefix).as_posix()) - logging.info(f"Downloading multiple files from run {run.id}: {prefix_str} -> {str(destination)}") + logging.info( + f"Downloading multiple files from run {run.id}: {prefix_str} -> {str(destination)}" + ) # There is a download_files function, but that can time out when downloading several large checkpoints file # (120sec timeout for all files). 
for file in run.get_file_names(): @@ -300,10 +335,14 @@ def download_run_outputs_by_prefix(blobs_prefix: Path, if target_path.startswith("/"): target_path = target_path[1:] logging.info(f"Downloading {file}") - run.download_file(file, str(destination / target_path), _validate_checksum=True) + run.download_file( + file, str(destination / target_path), _validate_checksum=True + ) else: - logging.warning(f"Skipping file {file}, because the desired prefix {prefix_str} is not aligned with " - f"the folder structure") + logging.warning( + f"Skipping file {file}, because the desired prefix {prefix_str} is not aligned with " + f"the folder structure" + ) def is_running_on_azure_agent() -> bool: @@ -314,10 +353,9 @@ def is_running_on_azure_agent() -> bool: return bool(os.environ.get("AGENT_OS", None)) -def get_comparison_baseline_paths(outputs_folder: Path, - blob_path: Path, run: Run, - dataset_csv_file_name: str) -> \ - Tuple[Optional[Path], Optional[Path]]: +def get_comparison_baseline_paths( + outputs_folder: Path, blob_path: Path, run: Run, dataset_csv_file_name: str +) -> Tuple[Optional[Path], Optional[Path]]: run_rec_id = run.id # We usually find dataset.csv in the same directory as metrics.csv, but we sometimes # have to look higher up. 
@@ -328,21 +366,29 @@ def get_comparison_baseline_paths(outputs_folder: Path, for blob_path_parent in step_up_directories(blob_path): try: comparison_dataset_path = download_run_output_file( - blob_path_parent / dataset_csv_file_name, destination_folder, run) + blob_path_parent / dataset_csv_file_name, destination_folder, run + ) break except (ValueError, UserErrorException): - logging.warning(f"cannot find {dataset_csv_file_name} at {blob_path_parent} in {run_rec_id}") + logging.warning( + f"cannot find {dataset_csv_file_name} at {blob_path_parent} in {run_rec_id}" + ) except NotADirectoryError: logging.warning(f"{blob_path_parent} is not a directory") break if comparison_dataset_path is None: - logging.warning(f"cannot find {dataset_csv_file_name} at or above {blob_path} in {run_rec_id}") + logging.warning( + f"cannot find {dataset_csv_file_name} at or above {blob_path} in {run_rec_id}" + ) # Look for epoch_NNN/Test/metrics.csv try: comparison_metrics_path = download_run_output_file( - blob_path / SUBJECT_METRICS_FILE_NAME, destination_folder, run) + blob_path / SUBJECT_METRICS_FILE_NAME, destination_folder, run + ) except (ValueError, UserErrorException): - logging.warning(f"cannot find {SUBJECT_METRICS_FILE_NAME} at {blob_path} in {run_rec_id}") + logging.warning( + f"cannot find {SUBJECT_METRICS_FILE_NAME} at {blob_path} in {run_rec_id}" + ) return (comparison_dataset_path, comparison_metrics_path) diff --git a/InnerEye/ML/Histopathology/datamodules/panda_module.py b/InnerEye/ML/Histopathology/datamodules/panda_module.py index c28073a25..64b974ca6 100644 --- a/InnerEye/ML/Histopathology/datamodules/panda_module.py +++ b/InnerEye/ML/Histopathology/datamodules/panda_module.py @@ -3,7 +3,7 @@ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 
# ------------------------------------------------------------------------------------------ -from typing import Tuple +from typing import Tuple, Any from InnerEye.ML.Histopathology.datamodules.base_module import TilesDataModule from InnerEye.ML.Histopathology.datasets.panda_tiles_dataset import PandaTilesDataset @@ -15,6 +15,9 @@ class PandaTilesDataModule(TilesDataModule): Method get_splits() returns the train, val, test splits from the PANDA dataset """ + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + def get_splits(self) -> Tuple[PandaTilesDataset, PandaTilesDataset, PandaTilesDataset]: dataset = PandaTilesDataset(self.root_path) splits = DatasetSplits.from_proportions(dataset.dataset_df.reset_index(), diff --git a/InnerEye/ML/Histopathology/datasets/panda_tiles_dataset.py b/InnerEye/ML/Histopathology/datasets/panda_tiles_dataset.py index 43520a8c9..30097896b 100644 --- a/InnerEye/ML/Histopathology/datasets/panda_tiles_dataset.py +++ b/InnerEye/ML/Histopathology/datasets/panda_tiles_dataset.py @@ -30,10 +30,10 @@ class PandaTilesDataset(TilesDataset): SPLIT_COLUMN = None # PANDA does not have an official train/test split N_CLASSES = 6 - _RELATIVE_ROOT_FOLDER = "PANDA_tiles_20210926-135446/panda_tiles_level1_224" + _RELATIVE_ROOT_FOLDER = Path("PANDA_tiles_20210926-135446/panda_tiles_level1_224") def __init__(self, - root: Union[str, Path], + root: Path, dataset_csv: Optional[Union[str, Path]] = None, dataset_df: Optional[pd.DataFrame] = None) -> None: super().__init__(root=Path(root) / self._RELATIVE_ROOT_FOLDER, @@ -48,7 +48,7 @@ class PandaTilesDatasetReturnImageLabel(VisionDataset): class label. 
""" def __init__(self, - root: Union[str, Path], + root: Path, dataset_csv: Optional[Union[str, Path]] = None, dataset_df: Optional[pd.DataFrame] = None, transform: Optional[Callable] = None, diff --git a/InnerEye/ML/Histopathology/models/deepmil.py b/InnerEye/ML/Histopathology/models/deepmil.py index c9af8c08d..e42eacb84 100644 --- a/InnerEye/ML/Histopathology/models/deepmil.py +++ b/InnerEye/ML/Histopathology/models/deepmil.py @@ -3,6 +3,7 @@ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------ +import logging from pathlib import Path import pandas as pd import numpy as np @@ -166,7 +167,7 @@ def _shared_step(self, batch: Dict, batch_idx: int, stage: str) -> Dict[ResultsK bag_labels_list = [] bag_logits_list = [] bag_attn_list = [] - for bag_idx in range(len(batch[TilesDataset.LABEL_COLUMN])): + for bag_idx in range(len(batch[self.label_column])): images = batch[TilesDataset.IMAGE_COLUMN][bag_idx] labels = batch[self.label_column][bag_idx] bag_labels_list.append(self.get_bag_label(labels)) @@ -177,7 +178,7 @@ def _shared_step(self, batch: Dict, batch_idx: int, stage: str) -> Dict[ResultsK bag_labels = torch.stack(bag_labels_list).view(-1) if self.n_classes > 1: - loss = self.loss_fn(bag_logits, bag_labels) + loss = self.loss_fn(bag_logits, bag_labels.long()) else: loss = self.loss_fn(bag_logits.squeeze(1), bag_labels.float()) @@ -201,6 +202,14 @@ def _shared_step(self, batch: Dict, batch_idx: int, stage: str) -> Dict[ResultsK ResultsKey.PROB: probs, ResultsKey.PRED_LABEL: preds, ResultsKey.TRUE_LABEL: bag_labels, ResultsKey.BAG_ATTN: bag_attn_list, ResultsKey.IMAGE: batch[TilesDataset.IMAGE_COLUMN]}) + + if (TilesDataset.TILE_X_COLUMN in batch.keys()) and (TilesDataset.TILE_Y_COLUMN in batch.keys()): + results.update({ResultsKey.TILE_X: batch[TilesDataset.TILE_X_COLUMN], + ResultsKey.TILE_Y: batch[TilesDataset.TILE_Y_COLUMN]} + ) + 
else: + logging.warning("Coordinates not found in batch. If this is not expected check your input tiles dataset.") + return results def training_step(self, batch: Dict, batch_idx: int) -> Tensor: # type: ignore diff --git a/InnerEye/ML/Histopathology/utils/naming.py b/InnerEye/ML/Histopathology/utils/naming.py index 32d46d54d..1af40015b 100644 --- a/InnerEye/ML/Histopathology/utils/naming.py +++ b/InnerEye/ML/Histopathology/utils/naming.py @@ -15,3 +15,6 @@ class ResultsKey(str, Enum): PRED_LABEL = 'pred_label' TRUE_LABEL = 'true_label' BAG_ATTN = 'bag_attn' + TILE_X = "x" + TILE_Y = "y" + diff --git a/InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py b/InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py index a2c1900f0..df2394f68 100644 --- a/InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py +++ b/InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py @@ -23,11 +23,22 @@ from InnerEye.Common import fixed_paths from InnerEye.ML.configs.histo_configs.classification.BaseMIL import BaseMIL from InnerEye.ML.Histopathology.datamodules.base_module import TilesDataModule -from InnerEye.ML.Histopathology.datamodules.tcga_crck_module import TcgaCrckTilesDataModule -from InnerEye.ML.Histopathology.models.encoders import (HistoSSLEncoder, ImageNetEncoder, - ImageNetSimCLREncoder, InnerEyeSSLEncoder) -from InnerEye.ML.Histopathology.models.transforms import EncodeTilesBatchd, LoadTilesBatchd -from InnerEye.ML.Histopathology.datasets.tcga_crck_tiles_dataset import TcgaCrck_TilesDataset +from InnerEye.ML.Histopathology.datamodules.tcga_crck_module import ( + TcgaCrckTilesDataModule, +) +from InnerEye.ML.Histopathology.models.encoders import ( + HistoSSLEncoder, + ImageNetEncoder, + ImageNetSimCLREncoder, + InnerEyeSSLEncoder, +) +from InnerEye.ML.Histopathology.models.transforms import ( + EncodeTilesBatchd, + LoadTilesBatchd, +) +from InnerEye.ML.Histopathology.datasets.tcga_crck_tiles_dataset import ( + 
TcgaCrck_TilesDataset, +) class DeepSMILECrck(BaseMIL): @@ -36,21 +47,17 @@ def __init__(self, **kwargs: Any) -> None: default_kwargs = dict( # declared in BaseMIL: pooling_type=GatedAttentionLayer.__name__, - # declared in DatasetParams: local_dataset=Path("/tmp/datasets/TCGA-CRCk"), azure_dataset_id="TCGA-CRCk", # To mount the dataset instead of downloading in AML, pass --use_dataset_mount in the CLI - # declared in TrainerParams: num_epochs=16, recovery_checkpoint_save_interval=16, recovery_checkpoints_save_last_k=-1, - # declared in WorkflowParams: number_of_cross_validation_splits=5, cross_validation_split_index=0, - # declared in OptimizerParams: l_rate=5e-4, weight_decay=1e-4, @@ -60,33 +67,45 @@ def __init__(self, **kwargs: Any) -> None: super().__init__(**default_kwargs) self.best_checkpoint_filename = "checkpoint_max_val_auroc" - self.best_checkpoint_filename_with_suffix = self.best_checkpoint_filename + ".ckpt" + self.best_checkpoint_filename_with_suffix = ( + self.best_checkpoint_filename + ".ckpt" + ) self.checkpoint_folder_path = "outputs/checkpoints/" - best_checkpoint_callback = ModelCheckpoint(dirpath=self.checkpoint_folder_path, - monitor='val/auroc', - filename=self.best_checkpoint_filename, - auto_insert_metric_name=False, - mode='max') + best_checkpoint_callback = ModelCheckpoint( + dirpath=self.checkpoint_folder_path, + monitor="val/auroc", + filename=self.best_checkpoint_filename, + auto_insert_metric_name=False, + mode="max", + ) self.callbacks = best_checkpoint_callback @property def cache_dir(self) -> Path: - return Path(f"/tmp/innereye_cache/{self.__class__.__name__}-{self.encoder_type}/") + return Path( + f"/tmp/innereye_cache1/{self.__class__.__name__}-{self.encoder_type}/" + ) def get_data_module(self) -> TilesDataModule: image_key = TcgaCrck_TilesDataset.IMAGE_COLUMN - transform = Compose([LoadTilesBatchd(image_key, progress=True), - EncodeTilesBatchd(image_key, self.encoder)]) - return 
TcgaCrckTilesDataModule(root_path=self.local_dataset, - max_bag_size=self.max_bag_size, - batch_size=self.batch_size, - transform=transform, - cache_mode=self.cache_mode, - save_precache=self.save_precache, - cache_dir=self.cache_dir, - number_of_cross_validation_splits=self.number_of_cross_validation_splits, - cross_validation_split_index=self.cross_validation_split_index) + transform = Compose( + [ + LoadTilesBatchd(image_key, progress=True), + EncodeTilesBatchd(image_key, self.encoder), + ] + ) + return TcgaCrckTilesDataModule( + root_path=self.local_dataset, + max_bag_size=self.max_bag_size, + batch_size=self.batch_size, + transform=transform, + cache_mode=self.cache_mode, + save_precache=self.save_precache, + cache_dir=self.cache_dir, + number_of_cross_validation_splits=self.number_of_cross_validation_splits, + cross_validation_split_index=self.cross_validation_split_index, + ) def get_trainer_arguments(self) -> Dict[str, Any]: # These arguments will be passed through to the Lightning trainer. @@ -98,7 +117,11 @@ def get_path_to_best_checkpoint(self) -> Path: was applied there. """ # absolute path is required for registering the model. - return fixed_paths.repository_root_directory() / self.checkpoint_folder_path / self.best_checkpoint_filename_with_suffix + return ( + fixed_paths.repository_root_directory() + / self.checkpoint_folder_path + / self.best_checkpoint_filename_with_suffix + ) class TcgaCrckImageNetMIL(DeepSMILECrck): diff --git a/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py b/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py new file mode 100644 index 000000000..17f61921f --- /dev/null +++ b/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py @@ -0,0 +1,145 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). 
See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + +from typing import Any, Dict +from pathlib import Path +import os +from monai.transforms import Compose +from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint + +from health_azure.utils import CheckpointDownloader +from health_azure.utils import get_workspace +from health_ml.networks.layers.attention_layers import GatedAttentionLayer +from InnerEye.Common import fixed_paths +from InnerEye.ML.Histopathology.datamodules.panda_module import PandaTilesDataModule +from InnerEye.ML.Histopathology.datasets.panda_tiles_dataset import PandaTilesDataset + +from InnerEye.ML.Histopathology.models.transforms import ( + EncodeTilesBatchd, + LoadTilesBatchd, +) +from InnerEye.ML.Histopathology.models.encoders import ( + HistoSSLEncoder, + ImageNetEncoder, + ImageNetSimCLREncoder, + InnerEyeSSLEncoder, +) +from InnerEye.ML.configs.histo_configs.classification.BaseMIL import BaseMIL +from InnerEye.ML.configs.histo_configs.run_ids import innereye_ssl_checkpoint + + +class DeepSMILEPanda(BaseMIL): + def __init__(self, **kwargs: Any) -> None: + default_kwargs = dict( + # declared in BaseMIL: + pooling_type=GatedAttentionLayer.__name__, + # declared in DatasetParams: + local_dataset=Path("/tmp/datasets/PANDA_tiles"), + azure_dataset_id="PANDA_tiles", + # To mount the dataset instead of downloading in AML, pass --use_dataset_mount in the CLI + # declared in TrainerParams: + num_epochs=200, + recovery_checkpoint_save_interval=10, + recovery_checkpoints_save_last_k=-1, + # declared in WorkflowParams: + number_of_cross_validation_splits=5, + cross_validation_split_index=0, + # declared in OptimizerParams: + l_rate=5e-4, + weight_decay=1e-4, + adam_betas=(0.9, 0.99), + ) + default_kwargs.update(kwargs) + super().__init__(**default_kwargs) + super().__init__(**default_kwargs) + self.best_checkpoint_filename = 
"checkpoint_max_val_auroc" + self.best_checkpoint_filename_with_suffix = ( + self.best_checkpoint_filename + ".ckpt" + ) + self.checkpoint_folder_path = "outputs/checkpoints/" + best_checkpoint_callback = ModelCheckpoint( + dirpath=self.checkpoint_folder_path, + monitor="val/accuracy", + filename=self.best_checkpoint_filename, + auto_insert_metric_name=False, + mode="max", + ) + self.callbacks = best_checkpoint_callback + + @property + def cache_dir(self) -> Path: + return Path( + f"/tmp/innereye_cache1/{self.__class__.__name__}-{self.encoder_type}/" + ) + + def setup(self) -> None: + if self.encoder_type == InnerEyeSSLEncoder.__name__: + self.downloader = CheckpointDownloader( + azure_config_json_path=get_workspace(), + run_recovery_id=innereye_ssl_checkpoint, + checkpoint_filename="last.ckpt", + download_dir="outputs/", + ) + os.chdir(fixed_paths.repository_root_directory()) + self.downloader.download_checkpoint_if_necessary() + self.encoder = self.get_encoder() + self.encoder.cuda() + self.encoder.eval() + + def get_data_module(self) -> PandaTilesDataModule: + image_key = PandaTilesDataset.IMAGE_COLUMN + transform = Compose( + [ + LoadTilesBatchd(image_key, progress=True), + EncodeTilesBatchd(image_key, self.encoder), + ] + ) + return PandaTilesDataModule( + root_path=self.local_dataset, + max_bag_size=self.max_bag_size, + batch_size=self.batch_size, + transform=transform, + cache_mode=self.cache_mode, + save_precache=self.save_precache, + cache_dir=self.cache_dir, + number_of_cross_validation_splits=self.number_of_cross_validation_splits, + cross_validation_split_index=self.cross_validation_split_index, + ) + + def get_trainer_arguments(self) -> Dict[str, Any]: + # These arguments will be passed through to the Lightning trainer. + return {"callbacks": self.callbacks} + + def get_path_to_best_checkpoint(self) -> Path: + """ + Returns the full path to a checkpoint file that was found to be best during training, whatever criterion + was applied there. 
+ """ + # absolute path is required for registering the model. + return ( + fixed_paths.repository_root_directory() + / self.checkpoint_folder_path + / self.best_checkpoint_filename_with_suffix + ) + + +class PandaImageNetMIL(DeepSMILEPanda): + def __init__(self, **kwargs: Any) -> None: + super().__init__(encoder_type=ImageNetEncoder.__name__, **kwargs) + + +class PandaImageNetSimCLRMIL(DeepSMILEPanda): + def __init__(self, **kwargs: Any) -> None: + super().__init__(encoder_type=ImageNetSimCLREncoder.__name__, **kwargs) + + +class PandaInnerEyeSSLMIL(DeepSMILEPanda): + def __init__(self, **kwargs: Any) -> None: + super().__init__(encoder_type=InnerEyeSSLEncoder.__name__, **kwargs) + + +class PandaHistoSSLMIL(DeepSMILEPanda): + def __init__(self, **kwargs: Any) -> None: + super().__init__(encoder_type=HistoSSLEncoder.__name__, **kwargs) diff --git a/InnerEye/ML/configs/histo_configs/run_ids.py b/InnerEye/ML/configs/histo_configs/run_ids.py new file mode 100644 index 000000000..0c9ca2163 --- /dev/null +++ b/InnerEye/ML/configs/histo_configs/run_ids.py @@ -0,0 +1 @@ +innereye_ssl_checkpoint = "hsharma_panda_explore:hsharma_panda_explore_1638437076_357167ae" diff --git a/Tests/ML/histopathology/models/test_deepmil.py b/Tests/ML/histopathology/models/test_deepmil.py index 60df2f759..797b3080c 100644 --- a/Tests/ML/histopathology/models/test_deepmil.py +++ b/Tests/ML/histopathology/models/test_deepmil.py @@ -4,17 +4,29 @@ # ------------------------------------------------------------------------------------------ import os -from typing import Callable, Dict, List +from typing import Callable, Dict, List, Type # noqa import pytest from torch import Tensor, argmax, nn, rand, randint, randn, round, stack, allclose from torchvision.models import resnet18 -from health_ml.networks.layers.attention_layers import AttentionLayer, GatedAttentionLayer - -from InnerEye.ML.configs.histo_configs.classification.DeepSMILECrck import DeepSMILECrck +from 
health_ml.networks.layers.attention_layers import ( + AttentionLayer, + GatedAttentionLayer, +) + +from InnerEye.ML.lightning_container import LightningContainer +from InnerEye.ML.configs.histo_configs.classification.DeepSMILECrck import ( + DeepSMILECrck, +) +from InnerEye.ML.configs.histo_configs.classification.DeepSMILEPanda import ( + DeepSMILEPanda, +) from InnerEye.ML.Histopathology.datamodules.base_module import TilesDataModule -from InnerEye.ML.Histopathology.datasets.default_paths import TCGA_CRCK_DATASET_DIR +from InnerEye.ML.Histopathology.datasets.default_paths import ( + TCGA_CRCK_DATASET_DIR, + PANDA_TILES_DATASET_DIR, +) from InnerEye.ML.Histopathology.models.deepmil import DeepMILModule from InnerEye.ML.Histopathology.models.encoders import ImageNetEncoder, TileEncoder from InnerEye.ML.Histopathology.utils.naming import ResultsKey @@ -30,23 +42,27 @@ def get_supervised_imagenet_encoder() -> TileEncoder: @pytest.mark.parametrize("max_bag_size", [1, 7]) @pytest.mark.parametrize("pool_hidden_dim", [1, 5]) @pytest.mark.parametrize("pool_out_dim", [1, 6]) -def test_lightningmodule(n_classes: int, - pooling_layer: Callable[[int, int, int], nn.Module], - batch_size: int, - max_bag_size: int, - pool_hidden_dim: int, - pool_out_dim: int) -> None: +def test_lightningmodule( + n_classes: int, + pooling_layer: Callable[[int, int, int], nn.Module], + batch_size: int, + max_bag_size: int, + pool_hidden_dim: int, + pool_out_dim: int, +) -> None: assert n_classes > 0 # hard-coded here to avoid test explosion; correctness of other encoders is tested elsewhere encoder = get_supervised_imagenet_encoder() - module = DeepMILModule(encoder=encoder, - label_column='label', - n_classes=n_classes, - pooling_layer=pooling_layer, - pool_hidden_dim=pool_hidden_dim, - pool_out_dim=pool_out_dim) + module = DeepMILModule( + encoder=encoder, + label_column="label", + n_classes=n_classes, + pooling_layer=pooling_layer, + pool_hidden_dim=pool_hidden_dim, + pool_out_dim=pool_out_dim, 
+ ) bag_images = rand([batch_size, max_bag_size, *module.encoder.input_dim]) bag_labels_list = [] @@ -56,7 +72,7 @@ def test_lightningmodule(n_classes: int, if n_classes > 1: labels = randint(n_classes, size=(max_bag_size,)) else: - labels = randint(n_classes+1, size=(max_bag_size,)) + labels = randint(n_classes + 1, size=(max_bag_size,)) bag_labels_list.append(module.get_bag_label(labels)) logit, attn = module(bag) assert logit.shape == (1, n_classes) @@ -92,29 +108,47 @@ def test_lightningmodule(n_classes: int, assert preds.shape[0] == batch_size for metric_name, metric_object in module.train_metrics.items(): - if (batch_size > 1) or (not metric_name == 'auroc'): + if (batch_size > 1) or (not metric_name == "auroc"): score = metric_object(preds.view(-1, 1), bag_labels.view(-1, 1)) assert score >= 0 and score <= 1 def move_batch_to_expected_device(batch: Dict[str, List], use_gpu: bool) -> Dict: - device = 'cuda' if use_gpu else 'cpu' - return {key: [value.to(device) if isinstance(value, Tensor) else value for value in values] - for key, values in batch.items()} + device = "cuda" if use_gpu else "cpu" + return { + key: [ + value.to(device) if isinstance(value, Tensor) else value for value in values + ] + for key, values in batch.items() + } + + +CONTAINER_DATASET_DIR = { + DeepSMILEPanda: PANDA_TILES_DATASET_DIR, + DeepSMILECrck: TCGA_CRCK_DATASET_DIR, +} +@pytest.mark.parametrize("container_type", [DeepSMILEPanda, + DeepSMILECrck]) @pytest.mark.parametrize("use_gpu", [True, False]) -def test_container(use_gpu: bool) -> None: - container_type = DeepSMILECrck - dataset_dir = TCGA_CRCK_DATASET_DIR +def test_container(container_type: Type[LightningContainer], use_gpu: bool) -> None: + dataset_dir = CONTAINER_DATASET_DIR[container_type] if not os.path.isdir(dataset_dir): - pytest.skip(f"Dataset for container {container_type.__name__} " - f"is unavailable: {dataset_dir}") + pytest.skip( + f"Dataset for container {container_type.__name__} " + f"is unavailable: 
{dataset_dir}" + ) + if container_type is DeepSMILECrck: + container = DeepSMILECrck(encoder_type=ImageNetEncoder.__name__) + elif container_type is DeepSMILEPanda: + container = DeepSMILEPanda(encoder_type=ImageNetEncoder.__name__) + else: + container = container_type() - container = DeepSMILECrck(encoder_type=ImageNetEncoder.__name__) container.setup() - data_module: TilesDataModule = container.get_data_module() + data_module: TilesDataModule = container.get_data_module() # type: ignore data_module.max_bag_size = 10 module = container.create_model() if use_gpu: @@ -135,7 +169,7 @@ def test_container(use_gpu: bool) -> None: for batch_idx, batch in enumerate(val_data_loader): batch = move_batch_to_expected_device(batch, use_gpu) loss = module.validation_step(batch, batch_idx) - assert loss.shape == () + assert loss.shape == () # noqa assert isinstance(loss, Tensor) break @@ -143,7 +177,7 @@ def test_container(use_gpu: bool) -> None: for batch_idx, batch in enumerate(test_data_loader): batch = move_batch_to_expected_device(batch, use_gpu) outputs_dict = module.test_step(batch, batch_idx) - loss = outputs_dict[ResultsKey.LOSS] + loss = outputs_dict[ResultsKey.LOSS] # noqa assert loss.shape == () assert isinstance(loss, Tensor) break @@ -152,22 +186,24 @@ def test_container(use_gpu: bool) -> None: def test_class_weights_binary() -> None: class_weights = Tensor([0.5, 3.5]) n_classes = 1 - module = DeepMILModule(encoder=get_supervised_imagenet_encoder(), - label_column='label', - n_classes=n_classes, - pooling_layer=AttentionLayer, - pool_hidden_dim=5, - pool_out_dim=1, - class_weights=class_weights) + module = DeepMILModule( + encoder=get_supervised_imagenet_encoder(), + label_column="label", + n_classes=n_classes, + pooling_layer=AttentionLayer, + pool_hidden_dim=5, + pool_out_dim=1, + class_weights=class_weights, + ) logits = Tensor(randn(1, n_classes)) - bag_label = randint(n_classes+1, size=(1,)) + bag_label = randint(n_classes + 1, size=(1,)) - pos_weight = 
Tensor([class_weights[1]/(class_weights[0]+1e-5)]) + pos_weight = Tensor([class_weights[1] / (class_weights[0] + 1e-5)]) loss_weighted = module.loss_fn(logits.squeeze(1), bag_label.float()) criterion_unweighted = nn.BCEWithLogitsLoss() loss_unweighted = criterion_unweighted(logits.squeeze(1), bag_label.float()) if bag_label.item() == 1: - assert allclose(loss_weighted, pos_weight*loss_unweighted) + assert allclose(loss_weighted, pos_weight * loss_unweighted) else: assert allclose(loss_weighted, loss_unweighted) @@ -175,13 +211,15 @@ def test_class_weights_binary() -> None: def test_class_weights_multiclass() -> None: class_weights = Tensor([0.33, 0.33, 0.33]) n_classes = 3 - module = DeepMILModule(encoder=get_supervised_imagenet_encoder(), - label_column='label', - n_classes=n_classes, - pooling_layer=AttentionLayer, - pool_hidden_dim=5, - pool_out_dim=1, - class_weights=class_weights) + module = DeepMILModule( + encoder=get_supervised_imagenet_encoder(), + label_column="label", + n_classes=n_classes, + pooling_layer=AttentionLayer, + pool_hidden_dim=5, + pool_out_dim=1, + class_weights=class_weights, + ) logits = Tensor(randn(1, n_classes)) bag_label = randint(n_classes, size=(1,))