diff --git a/.amlignore b/.amlignore index d39026640..d7097c74e 100644 --- a/.amlignore +++ b/.amlignore @@ -21,7 +21,7 @@ pull_request_template.md SECURITY.md __pycache__ azure-pipelines -datasets +/datasets docs sphinx-docs modelweights @@ -35,4 +35,5 @@ tensorboard_runs InnerEyeTestVariables.txt InnerEyePrivateSettings.yml cifar-10-batches-py -cifar-100-python \ No newline at end of file +cifar-100-python +!**/InnerEye/ML/Histopathology/datasets diff --git a/.gitignore b/.gitignore index 36a4df7a2..834e1d2c7 100644 --- a/.gitignore +++ b/.gitignore @@ -46,7 +46,7 @@ packages-microsoft-prod.deb # PyInstaller # Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. +# before PyInstaller builds the exe, so as to inject date/other infos into it *.manifest *.spec @@ -166,3 +166,5 @@ InnerEye-DataQuality/name_stats_scoring.png InnerEye-DataQuality/cifar-10-batches-py InnerEye-DataQuality/logs InnerEye-DataQuality/data + +!**/InnerEye/ML/Histopathology/datasets \ No newline at end of file diff --git a/.gitmodules b/.gitmodules index a2a6b1f53..623bd23c7 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ [submodule "fastMRI"] path = fastMRI url = https://github.com/facebookresearch/fastMRI +[submodule "hi-ml"] + path = hi-ml + url = https://github.com/microsoft/hi-ml diff --git a/CHANGELOG.md b/CHANGELOG.md index 5c53bef7c..07558c9a6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,7 +14,7 @@ created. ### Added - ([#594](https://github.com/microsoft/InnerEye-DeepLearning/pull/594)) When supplying a "--tag" argument, the AzureML jobs use that value as the display name, to more easily distinguish run. -- ([#577](https://github.com/microsoft/InnerEye-DeepLearning/pull/577)) Commandline switch `monitor_gpu` to monitor +- ([#577](https://github.com/microsoft/InnerEye-DeepLearning/pull/577)) Commandline switch `monitor_gpu` to monitor GPU utilization via Lightning's `GpuStatsMonitor`, switch `monitor_loading` to check batch loading times via `BatchTimeCallback`, and `pl_profiler` to turn on the Lightning profiler (`simple`, `advanced`, or `pytorch`) - ([#544](https://github.com/microsoft/InnerEye-DeepLearning/pull/544)) Add documentation for segmentation model evaluation. @@ -31,6 +31,8 @@ jobs that run in AzureML. - ([#559](https://github.com/microsoft/InnerEye-DeepLearning/pull/559)) Adding the accompanying code for the ["Active label cleaning: Improving dataset quality under resource constraints"](https://arxiv.org/abs/2109.00574) paper. The code can be found in the [InnerEye-DataQuality](InnerEye-DataQuality/README.md) subfolder. It provides tools for training noise robust models, running label cleaning simulation and loading our label cleaning benchmark datasets. - ([#589](https://github.com/microsoft/InnerEye-DeepLearning/pull/589)) Add `LightningContainer.update_azure_config()` hook to enable overriding `AzureConfig` parameters from a container (e.g. `experiment_name`, `cluster`, `num_nodes`). +-([#603](https://github.com/microsoft/InnerEye-DeepLearning/pull/603)) Add histopathology module + ### Changed - ([#588](https://github.com/microsoft/InnerEye-DeepLearning/pull/588)) Replace SciPy with PIL.PngImagePlugin.PngImageFile to load png files. diff --git a/InnerEye/ML/Histopathology/datamodules/base_module.py b/InnerEye/ML/Histopathology/datamodules/base_module.py new file mode 100644 index 000000000..4bf4557e2 --- /dev/null +++ b/InnerEye/ML/Histopathology/datamodules/base_module.py @@ -0,0 +1,149 @@ +import pickle +from enum import Enum +from pathlib import Path +from typing import Any, Callable, Optional, Sequence, Tuple, Union + +from monai.data.dataset import CacheDataset, Dataset, PersistentDataset +from pytorch_lightning import LightningDataModule +from torch.utils.data import DataLoader + +from health_ml.utils.bag_utils import BagDataset, multibag_collate +from health_ml.utils.common_utils import _create_generator +from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset +from InnerEye.ML.Histopathology.models.transforms import LoadTilesBatchd + + +class CacheMode(Enum): + NONE = 'none' + MEMORY = 'memory' + DISK = 'disk' + + +class TilesDataModule(LightningDataModule): + """Base class to load the tiles of a dataset as train, val, test sets""" + + def __init__(self, root_path: Path, max_bag_size: int = 0, batch_size: int = 1, + seed: Optional[int] = None, transform: Optional[Callable] = None, + cache_mode: CacheMode = CacheMode.NONE, save_precache: bool = False, + cache_dir: Optional[Path] = None, + number_of_cross_validation_splits: int = 0, + cross_validation_split_index: int = 0) -> None: + """ + :param root_path: Root directory of the source dataset. + :param max_bag_size: Upper bound on number of tiles in each loaded bag. If 0 (default), + will return all samples in each bag. If > 0 , bags larger than `max_bag_size` will yield + random subsets of instances. + :param batch_size: Number of slides to load per batch. + :param seed: pseudorandom number generator seed to use for shuffling instances and bags. Note that randomness in + train/val/test splits is handled independently in `get_splits()`. (default: `None`) + :param transform: A transform to apply to the source tiles dataset, or a composition of + transforms using `monai.transforms.Compose`. By default (`None`), applies `LoadTilesBatchd`. + :param cache_mode: The type of caching to perform, i.e. whether the results of all + transforms up to the first randomised one should be computed only once and reused in + subsequent iterations: + - `MEMORY`: the entire transformed dataset is kept in memory for fastest access; + - `DISK`: each transformed sample is saved to disk and loaded on-demand; + - `NONE` (default): no caching is performed. + :param save_precache: Whether to pre-cache the entire transformed dataset upfront and save + it to disk. This is done once in `prepare_data()` only on the local rank-0 process, so + multiple processes can afterwards access the same cache without contention in DDP settings. + :param cache_dir: The directory onto which to cache data if caching is enabled. + :param number_of_cross_validation_splits: Number of folds to perform. + :param cross_validation_split_index: Index of the cross validation split to be performed. + """ + if save_precache and cache_mode is CacheMode.NONE: + raise ValueError("Can only pre-cache if caching is enabled") + if save_precache and cache_dir is None: + raise ValueError("A cache directory is required for pre-caching") + if cache_mode is CacheMode.DISK and cache_dir is None: + raise ValueError("A cache directory is required for on-disk caching") + super().__init__() + + self.root_path = root_path + self.max_bag_size = max_bag_size + self.transform = transform + self.cache_mode = cache_mode + self.save_precache = save_precache + self.cache_dir = cache_dir + self.batch_size = batch_size + self.number_of_cross_validation_splits = number_of_cross_validation_splits + self.cross_validation_split_index = cross_validation_split_index + self.train_dataset, self.val_dataset, self.test_dataset = self.get_splits() + self.class_weights = self.train_dataset.get_class_weights() + self.seed = seed + + def get_splits(self) -> Tuple[TilesDataset, TilesDataset, TilesDataset]: + """Create the training, validation, and test datasets""" + raise NotImplementedError + + def prepare_data(self) -> None: + if self.save_precache: + self._load_dataset(self.train_dataset, stage='train', shuffle=True) + self._load_dataset(self.val_dataset, stage='val', shuffle=True) + self._load_dataset(self.test_dataset, stage='test', shuffle=True) + + def _dataset_pickle_path(self, stage: str) -> Optional[Path]: + if self.cache_dir is None: + return None + return self.cache_dir / f"{stage}_dataset.pkl" + + def _load_dataset(self, tiles_dataset: TilesDataset, stage: str, shuffle: bool) -> Dataset: + dataset_pickle_path = self._dataset_pickle_path(stage) + + if dataset_pickle_path and dataset_pickle_path.exists(): + with dataset_pickle_path.open('rb') as f: + return pickle.load(f) + + generator = _create_generator(self.seed) + bag_dataset = BagDataset(tiles_dataset, # type: ignore + bag_ids=tiles_dataset.slide_ids, + max_bag_size=self.max_bag_size, + shuffle_samples=shuffle, + generator=generator) + transform = self.transform or LoadTilesBatchd(tiles_dataset.IMAGE_COLUMN) + + # Save and restore PRNG state for consistency across (pre-)caching options + generator_state = generator.get_state() + transformed_bag_dataset = self._get_transformed_dataset(bag_dataset, transform) # type: ignore + generator.set_state(generator_state) + + if dataset_pickle_path: + dataset_pickle_path.parent.mkdir(parents=True, exist_ok=True) + with dataset_pickle_path.open('wb') as f: + pickle.dump(transformed_bag_dataset, f) + + return transformed_bag_dataset + + def _get_transformed_dataset(self, base_dataset: BagDataset, + transform: Union[Sequence[Callable], Callable]) -> Dataset: + if self.cache_mode is CacheMode.MEMORY: + dataset = CacheDataset(base_dataset, transform, num_workers=1) # type: ignore + elif self.cache_mode is CacheMode.DISK: + dataset = PersistentDataset(base_dataset, transform, cache_dir=self.cache_dir) # type: ignore + if self.save_precache: + import tqdm # TODO: Make optional + + for i in tqdm.trange(len(dataset), desc="Loading dataset"): + dataset[i] # empty loop to pre-compute all transformed samples + else: + dataset = Dataset(base_dataset, transform) # type: ignore + return dataset + + def _get_dataloader(self, tiles_dataset: TilesDataset, stage: str, shuffle: bool, + **dataloader_kwargs: Any) -> DataLoader: + transformed_bag_dataset = self._load_dataset(tiles_dataset, stage=stage, shuffle=shuffle) + bag_dataset: BagDataset = transformed_bag_dataset.data # type: ignore + generator = bag_dataset.bag_sampler.generator + return DataLoader(transformed_bag_dataset, batch_size=self.batch_size, + collate_fn=multibag_collate, shuffle=shuffle, generator=generator, + pin_memory=False, # disable pinning as loaded data may already be on GPU + **dataloader_kwargs) + + def train_dataloader(self) -> DataLoader: + return self._get_dataloader(self.train_dataset, 'train', shuffle=True) + + def val_dataloader(self) -> DataLoader: + return self._get_dataloader(self.val_dataset, 'val', shuffle=True) + + def test_dataloader(self) -> DataLoader: + return self._get_dataloader(self.test_dataset, 'test', shuffle=True) diff --git a/InnerEye/ML/Histopathology/datamodules/panda_module.py b/InnerEye/ML/Histopathology/datamodules/panda_module.py new file mode 100644 index 000000000..0ce597377 --- /dev/null +++ b/InnerEye/ML/Histopathology/datamodules/panda_module.py @@ -0,0 +1,23 @@ +from typing import Tuple + +from InnerEye.ML.Histopathology.datamodules.base_module import TilesDataModule +from InnerEye.ML.Histopathology.datasets.panda_tiles_dataset import PandaTilesDataset +from InnerEye.ML.utils.split_dataset import DatasetSplits + + +class PandaTilesDataModule(TilesDataModule): + """ PandaTilesDataModule is the child class of TilesDataModule specific to PANDA dataset + Method get_splits() returns the train, val, test splits from the PANDA dataset + """ + + def get_splits(self) -> Tuple[PandaTilesDataset, PandaTilesDataset, PandaTilesDataset]: + dataset = PandaTilesDataset(self.root_path) + splits = DatasetSplits.from_proportions(dataset.dataset_df.reset_index(), + proportion_train=.8, + proportion_test=.1, + proportion_val=.1, + subject_column=dataset.TILE_ID_COLUMN, + group_column=dataset.SLIDE_ID_COLUMN) + return (PandaTilesDataset(self.root_path, dataset_df=splits.train), + PandaTilesDataset(self.root_path, dataset_df=splits.val), + PandaTilesDataset(self.root_path, dataset_df=splits.test)) diff --git a/InnerEye/ML/Histopathology/datamodules/tcga_crck_module.py b/InnerEye/ML/Histopathology/datamodules/tcga_crck_module.py new file mode 100644 index 000000000..35a966b66 --- /dev/null +++ b/InnerEye/ML/Histopathology/datamodules/tcga_crck_module.py @@ -0,0 +1,33 @@ +from typing import Tuple, Any + +from InnerEye.ML.Histopathology.datamodules.base_module import TilesDataModule +from InnerEye.ML.Histopathology.datasets.tcga_crck_tiles_dataset import TcgaCrck_TilesDataset +from InnerEye.ML.utils.split_dataset import DatasetSplits + + +class TcgaCrckTilesDataModule(TilesDataModule): + """ TcgaCrckTilesDataModule is the child class of TilesDataModule specific to TCGA-Crck dataset + Method get_splits() returns the train, val, test splits from the TCGA-Crck dataset + Methods train_dataloader(), val_dataloader() and test_dataloader() override the base class methods for bag loading + """ + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + + def get_splits(self) -> Tuple[TcgaCrck_TilesDataset, TcgaCrck_TilesDataset, TcgaCrck_TilesDataset]: + trainval_dataset = TcgaCrck_TilesDataset(self.root_path, train=True) + splits = DatasetSplits.from_proportions(trainval_dataset.dataset_df.reset_index(), + proportion_train=0.8, + proportion_test=0.0, + proportion_val=0.2, + subject_column=trainval_dataset.TILE_ID_COLUMN, + group_column=trainval_dataset.SLIDE_ID_COLUMN, + random_seed=5) + + if self.number_of_cross_validation_splits > 1: + # Function get_k_fold_cross_validation_splits() will concatenate train and val splits + splits = splits.get_k_fold_cross_validation_splits(self.number_of_cross_validation_splits)[self.cross_validation_split_index] + + return (TcgaCrck_TilesDataset(self.root_path, dataset_df=splits.train), + TcgaCrck_TilesDataset(self.root_path, dataset_df=splits.val), + TcgaCrck_TilesDataset(self.root_path, train=False)) diff --git a/InnerEye/ML/Histopathology/datasets/base_dataset.py b/InnerEye/ML/Histopathology/datasets/base_dataset.py new file mode 100644 index 000000000..a7a31afe0 --- /dev/null +++ b/InnerEye/ML/Histopathology/datasets/base_dataset.py @@ -0,0 +1,107 @@ +from pathlib import Path +from typing import Any, Dict, Optional, Union + +import numpy as np +import pandas as pd +import torch +from sklearn.utils.class_weight import compute_class_weight +from torch.utils.data import Dataset + + +class TilesDataset(Dataset): + """Base class for datasets of WSI tiles, iterating dictionaries of image paths and metadata. + + :param TILE_ID_COLUMN: CSV column name for tile ID. + :param SLIDE_ID_COLUMN: CSV column name for slide ID. + :param IMAGE_COLUMN: CSV column name for relative path to image file. + :param PATH_COLUMN: CSV column name for relative path to image file. Replicated to propagate the path to the batch. + :param LABEL_COLUMN: CSV column name for tile label. + :param SPLIT_COLUMN: CSV column name for train/test split (optional). + :param TILE_X_COLUMN: CSV column name for horizontal tile coordinate (optional). + :param TILE_Y_COLUMN: CSV column name for vertical tile coordinate (optional). + :param TRAIN_SPLIT_LABEL: Value used to indicate the training split in `SPLIT_COLUMN`. + :param TEST_SPLIT_LABEL: Value used to indicate the test split in `SPLIT_COLUMN`. + :param DEFAULT_CSV_FILENAME: Default name of the dataset CSV at the dataset rood directory. + :param N_CLASSES: Number of classes indexed in `LABEL_COLUMN`. + """ + TILE_ID_COLUMN: str = 'tile_id' + SLIDE_ID_COLUMN: str = 'slide_id' + IMAGE_COLUMN: str = 'image' + PATH_COLUMN: str = 'image_path' + LABEL_COLUMN: str = 'label' + SPLIT_COLUMN: Optional[str] = 'split' + TILE_X_COLUMN: Optional[str] = 'tile_x' + TILE_Y_COLUMN: Optional[str] = 'tile_y' + + TRAIN_SPLIT_LABEL: str = 'train' + TEST_SPLIT_LABEL: str = 'test' + + DEFAULT_CSV_FILENAME: str = "dataset.csv" + + N_CLASSES: int = 1 # binary classification by default + + def __init__(self, + root: Union[str, Path], + dataset_csv: Optional[Union[str, Path]] = None, + dataset_df: Optional[pd.DataFrame] = None, + train: Optional[bool] = None) -> None: + """ + :param root: Root directory of the dataset. + :param dataset_csv: Full path to a dataset CSV file, containing at least + `TILE_ID_COLUMN`, `SLIDE_ID_COLUMN`, and `IMAGE_COLUMN`. If omitted, the CSV will be read + from `"{root}/{DEFAULT_CSV_FILENAME}"`. + :param dataset_df: A potentially pre-processed dataframe in the same format as would be read + from the dataset CSV file, e.g. after some filtering. If given, overrides `dataset_csv`. + :param train: If `True`, loads only the training split (resp. `False` for test split). By + default (`None`), loads the entire dataset as-is. + """ + if self.SPLIT_COLUMN is None and train is not None: + raise ValueError("Train/test split was specified but dataset has no split column") + + self.root_dir = Path(root) + + if dataset_df is not None: + self.dataset_csv = None + else: + self.dataset_csv = dataset_csv or self.root_dir / self.DEFAULT_CSV_FILENAME + dataset_df = pd.read_csv(self.dataset_csv) + + columns = [self.SLIDE_ID_COLUMN, self.IMAGE_COLUMN, self.LABEL_COLUMN, self.LABEL_COLUMN, + self.SPLIT_COLUMN, self.TILE_X_COLUMN, self.TILE_Y_COLUMN] + for column in columns: + if column is not None and column not in dataset_df.columns: + raise ValueError(f"Expected column '{column}' not found in the dataframe") + + dataset_df = dataset_df.set_index(self.TILE_ID_COLUMN) + if train is None: + self.dataset_df = dataset_df + else: + split = self.TRAIN_SPLIT_LABEL if train else self.TEST_SPLIT_LABEL + self.dataset_df = dataset_df[dataset_df[self.SPLIT_COLUMN] == split] + + def __len__(self) -> int: + return self.dataset_df.shape[0] + + def __getitem__(self, index: int) -> Dict[str, Any]: + tile_id = self.dataset_df.index[index] + sample = { + self.TILE_ID_COLUMN: tile_id, + **self.dataset_df.loc[tile_id].to_dict() + } + sample[self.IMAGE_COLUMN] = str(self.root_dir / sample.pop(self.IMAGE_COLUMN)) + # we're replicating this column because we want to propagate the path to the batch + sample[self.PATH_COLUMN] = sample[self.IMAGE_COLUMN] + return sample + + @property + def slide_ids(self) -> pd.Series: + return self.dataset_df[self.SLIDE_ID_COLUMN] + + def get_slide_labels(self) -> pd.Series: + return self.dataset_df.groupby(self.SLIDE_ID_COLUMN)[self.LABEL_COLUMN].agg(pd.Series.mode) + + def get_class_weights(self) -> torch.Tensor: + slide_labels = self.get_slide_labels() + classes = np.unique(slide_labels) + class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=slide_labels) + return torch.as_tensor(class_weights) diff --git a/InnerEye/ML/Histopathology/datasets/default_paths.py b/InnerEye/ML/Histopathology/datasets/default_paths.py new file mode 100644 index 000000000..0497a7ecd --- /dev/null +++ b/InnerEye/ML/Histopathology/datasets/default_paths.py @@ -0,0 +1,8 @@ +PANDA_TILES_DATASET_ID = "PANDA_tiles" +TCGA_CRCK_DATASET_ID = "TCGA-CRCk" +TCGA_PRAD_DATASET_ID = "TCGA-PRAD" + +DEFAULT_DATASET_LOCATION = "/tmp/datasets/" +PANDA_TILES_DATASET_DIR = DEFAULT_DATASET_LOCATION + PANDA_TILES_DATASET_ID +TCGA_CRCK_DATASET_DIR = DEFAULT_DATASET_LOCATION + TCGA_CRCK_DATASET_ID +TCGA_PRAD_DATASET_DIR = DEFAULT_DATASET_LOCATION + TCGA_PRAD_DATASET_ID diff --git a/InnerEye/ML/Histopathology/datasets/panda_dataset.py b/InnerEye/ML/Histopathology/datasets/panda_dataset.py new file mode 100644 index 000000000..77ad2f58d --- /dev/null +++ b/InnerEye/ML/Histopathology/datasets/panda_dataset.py @@ -0,0 +1,125 @@ +from pathlib import Path +from typing import Any, Dict, Union, Optional + +import pandas as pd +from monai.config import KeysCollection +from monai.data.image_reader import ImageReader, WSIReader +from monai.transforms import MapTransform +from openslide import OpenSlide +from torch.utils.data import Dataset + +from health_ml.utils import box_utils + + +class PandaDataset(Dataset): + """Dataset class for loading files from the PANDA challenge dataset. + + Iterating over this dataset returns a dictionary containing the `'image_id'`, paths to the `'image'` + and `'mask'` files, and the remaining meta-data from the original dataset (`'data_provider'`, + `'isup_grade'`, and `'gleason_score'`). + + Ref.: https://www.kaggle.com/c/prostate-cancer-grade-assessment/overview + """ + def __init__(self, root_dir: Union[str, Path], n_slides: Optional[int] = None, + frac_slides: Optional[float] = None) -> None: + super().__init__() + self.root_dir = Path(root_dir) + self.train_df = pd.read_csv(self.root_dir / "train.csv", index_col='image_id') + if n_slides or frac_slides: + self.train_df = self.train_df.sample(n=n_slides, frac=frac_slides, replace=False, + random_state=1234) + + def __len__(self) -> int: + return self.train_df.shape[0] + + def _get_image_path(self, image_id: str) -> Path: + return self.root_dir / "train_images" / f"{image_id}.tiff" + + def _get_mask_path(self, image_id: str) -> Path: + return self.root_dir / "train_label_masks" / f"{image_id}_mask.tiff" + + def __getitem__(self, index: int) -> Dict: + image_id = self.train_df.index[index] + return { + 'image_id': image_id, + 'image': str(self._get_image_path(image_id).absolute()), + 'mask': str(self._get_mask_path(image_id).absolute()), + **self.train_df.loc[image_id].to_dict() + } + + +# MONAI's convention is that dictionary transforms have a 'd' suffix in the class name +class ReadImaged(MapTransform): + """Basic transform to read image files.""" + def __init__(self, reader: ImageReader, keys: KeysCollection, + allow_missing_keys: bool = False, **kwargs: Any) -> None: + super().__init__(keys, allow_missing_keys=allow_missing_keys) + self.reader = reader + self.kwargs = kwargs + + def __call__(self, data: Dict) -> Dict: + for key in self.keys: + if key in data or not self.allow_missing_keys: + data[key] = self.reader.read(data[key], **self.kwargs) + return data + + +class LoadPandaROId(MapTransform): + """Transform that loads a pathology slide and mask, cropped to the mask bounding box (ROI). + + Operates on dictionaries, replacing the file paths in `image_key` and `mask_key` with the + respective loaded arrays, in (C, H, W) format. Also adds the following meta-data entries: + - `'location'` (tuple): top-right coordinates of the bounding box + - `'size'` (tuple): width and height of the bounding box + - `'level'` (int): chosen magnification level + - `'scale'` (float): corresponding scale, loaded from the file + """ + def __init__(self, reader: WSIReader, image_key: str = 'image', mask_key: str = 'mask', + level: int = 0, margin: int = 0, **kwargs: Any) -> None: + """ + :param reader: And instance of MONAI's `WSIReader`. + :param image_key: Image key in the input and output dictionaries. + :param mask_key: Mask key in the input and output dictionaries. + :param level: Magnification level to load from the raw multi-scale files. + :param margin: Amount in pixels by which to enlarge the estimated bounding box for cropping. + """ + super().__init__([image_key, mask_key], allow_missing_keys=False) + self.reader = reader + self.image_key = image_key + self.mask_key = mask_key + self.level = level + self.margin = margin + self.kwargs = kwargs + + def _get_bounding_box(self, mask_obj: OpenSlide) -> box_utils.Box: + # Estimate bounding box at the lowest resolution (i.e. highest level) + highest_level = mask_obj.level_count - 1 + scale = mask_obj.level_downsamples[highest_level] + mask, _ = self.reader.get_data(mask_obj, level=highest_level) # loaded as RGB PIL image + + foreground_mask = mask[0] > 0 # PANDA segmentation mask is in 'R' channel + bbox = scale * box_utils.get_bounding_box(foreground_mask).add_margin(self.margin) + return bbox + + def __call__(self, data: Dict) -> Dict: + mask_obj: OpenSlide = self.reader.read(data[self.mask_key]) + image_obj: OpenSlide = self.reader.read(data[self.image_key]) + + level0_bbox = self._get_bounding_box(mask_obj) + + # OpenSlide takes absolute location coordinates in the level 0 reference frame, + # but relative region size in pixels at the chosen level + scale = mask_obj.level_downsamples[self.level] + scaled_bbox = level0_bbox / scale + get_data_kwargs = dict(location=(level0_bbox.x, level0_bbox.y), + size=(scaled_bbox.w, scaled_bbox.h), + level=self.level) + mask, _ = self.reader.get_data(mask_obj, **get_data_kwargs) # type: ignore + data[self.mask_key] = mask[:1] # PANDA segmentation mask is in 'R' channel + data[self.image_key], _ = self.reader.get_data(image_obj, **get_data_kwargs) # type: ignore + data.update(get_data_kwargs) + data['scale'] = scale + + mask_obj.close() + image_obj.close() + return data diff --git a/InnerEye/ML/Histopathology/datasets/panda_tiles_dataset.py b/InnerEye/ML/Histopathology/datasets/panda_tiles_dataset.py new file mode 100644 index 000000000..e382aae5e --- /dev/null +++ b/InnerEye/ML/Histopathology/datasets/panda_tiles_dataset.py @@ -0,0 +1,76 @@ +from pathlib import Path +from typing import Any, Callable, Optional, Tuple, Union + +import pandas as pd +from torchvision.datasets.vision import VisionDataset + +from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset +from InnerEye.ML.Histopathology.models.transforms import load_pil_image +from InnerEye.ML.SSL.datamodules_and_datasets.dataset_cls_utils import InnerEyeDataClassBaseWithReturnIndex + + +class PandaTilesDataset(TilesDataset): + """ + Dataset class for loading PANDA tiles. + + Iterating over this dataset returns a dictionary containing: + - `'slide_id'` (str): parent slide ID (`'image_id'` in the PANDA dataset) + - `'tile_id'` (str) + - `'image'` (`PIL.Image`): RGB tile + - `'mask'` (str): path to mask PNG file + - `'tile_x'`, `'tile_y'` (int): top-right tile coordinates + - `'data_provider'`, `'slide_isup_grade'`, `'slide_gleason_score'` (str): parent slide metadata + """ + LABEL_COLUMN = "slide_isup_grade" + SPLIT_COLUMN = None # PANDA does not have an official train/test split + N_CLASSES = 6 + + _RELATIVE_ROOT_FOLDER = "PANDA_tiles_20210926-135446/panda_tiles_level1_224" + + def __init__(self, + root: Union[str, Path], + dataset_csv: Optional[Union[str, Path]] = None, + dataset_df: Optional[pd.DataFrame] = None) -> None: + super().__init__(root=Path(root) / self._RELATIVE_ROOT_FOLDER, + dataset_csv=dataset_csv, + dataset_df=dataset_df, + train=None) + + +class PandaTilesDatasetReturnImageLabel(VisionDataset): + """ + Any dataset used in SSL needs to return a tuple where the first element is the image and the second is a + class label. + """ + def __init__(self, + root: Union[str, Path], + dataset_csv: Optional[Union[str, Path]] = None, + dataset_df: Optional[pd.DataFrame] = None, + transform: Optional[Callable] = None, + **kwargs: Any) -> None: + super().__init__(root=root, transform=transform) + self.base_dataset = PandaTilesDataset(root=root, + dataset_csv=dataset_csv, + dataset_df=dataset_df) + + def __getitem__(self, index: int) -> Tuple: # type: ignore + sample = self.base_dataset[index] + # TODO change to a meaningful evaluation + image = load_pil_image(sample[self.base_dataset.IMAGE_COLUMN]) + if self.transform: + image = self.transform(image) + return image, 1 + + def __len__(self) -> int: + return len(self.base_dataset) + + +class PandaTilesDatasetWithReturnIndex(InnerEyeDataClassBaseWithReturnIndex, PandaTilesDatasetReturnImageLabel): + """ + Any dataset used in SSL needs to inherit from InnerEyeDataClassBaseWithReturnIndex as well as VisionData. + This class is just a shorthand notation for this double inheritance. Please note that this class needs + to override __getitem__(), this is why we need a separate PandaTilesDatasetReturnImageLabel. + """ + @property + def num_classes(self) -> int: + return 2 diff --git a/InnerEye/ML/Histopathology/datasets/tcga_crck_tiles_dataset.py b/InnerEye/ML/Histopathology/datasets/tcga_crck_tiles_dataset.py new file mode 100644 index 000000000..783a0278a --- /dev/null +++ b/InnerEye/ML/Histopathology/datasets/tcga_crck_tiles_dataset.py @@ -0,0 +1,64 @@ +from pathlib import Path +from typing import Any, Callable, Optional, Tuple, Union + +import pandas as pd +from torchvision.datasets.vision import VisionDataset + +from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset +from InnerEye.ML.Histopathology.models.transforms import load_pil_image +from InnerEye.ML.SSL.datamodules_and_datasets.dataset_cls_utils import InnerEyeDataClassBaseWithReturnIndex + + +class TcgaCrck_TilesDataset(TilesDataset): + """Dataset class for loading TCGA-CRCk tiles. + + Iterating over this dataset returns a dictionary containing: + - `'slide_id'` (str): parent slide ID + - `'tile_id'` (str) + - `'image'` (`PIL.Image`): RGB tile + - `'label'` (str): MSS (0) vs MSIMUT (1) + """ + TILE_X_COLUMN = TILE_Y_COLUMN = None # no tile coordinates available + # This dataset conforms to all other defaults in TilesDataset + + +class TcgaCrck_TilesDatasetReturnImageLabel(VisionDataset): + """ + Any dataset used in SSL needs to return a tuple where the first element is the image and the second is a + class label. + """ + def __init__(self, + root: Union[str, Path], + dataset_csv: Optional[Union[str, Path]] = None, + dataset_df: Optional[pd.DataFrame] = None, + train: Optional[bool] = None, + transform: Optional[Callable] = None, + **kwargs: Any) -> None: + super().__init__(root=root, transform=transform) + self.base_dataset = TcgaCrck_TilesDataset(root=root, + dataset_csv=dataset_csv, + dataset_df=dataset_df, + train=train) + + def __getitem__(self, index: int) -> Tuple: # type: ignore + sample = self.base_dataset[index] + # TODO change to a meaningful evaluation + image = load_pil_image(sample[self.base_dataset.IMAGE_COLUMN]) + if self.transform: + image = self.transform(image) + return image, sample[self.base_dataset.LABEL_COLUMN] + + def __len__(self) -> int: + return len(self.base_dataset) + + +class TcgaCrck_TilesDatasetWithReturnIndex(InnerEyeDataClassBaseWithReturnIndex, + TcgaCrck_TilesDatasetReturnImageLabel): + """ + Any dataset used in SSL needs to inherit from InnerEyeDataClassBaseWithReturnIndex as well as VisionData. + This class is just a shorthand notation for this double inheritance. Please note that this class needs + to override __getitem__(), this is why we need a separate PandaTilesDatasetReturnImageLabel. + """ + @property + def num_classes(self) -> int: + return 2 diff --git a/InnerEye/ML/Histopathology/datasets/tcga_prad_dataset.py b/InnerEye/ML/Histopathology/datasets/tcga_prad_dataset.py new file mode 100644 index 000000000..edb47d644 --- /dev/null +++ b/InnerEye/ML/Histopathology/datasets/tcga_prad_dataset.py @@ -0,0 +1,57 @@ +from pathlib import Path +from typing import Any, Dict, Optional, Union + +import pandas as pd +from torch.utils.data import Dataset + + +class TcgaPradDataset(Dataset): + """Dataset class for loading TCGA-PRAD slides. + + Iterating over this dataset returns a dictionary containing: + - `'slide_id'` (str) + - `'case_id'` (str) + - `'image_path'` (str): absolute slide image path + - `'label'` (int, 0 or 1): label for predicting positive or negative + """ + SLIDE_ID_COLUMN: str = 'slide_id' + CASE_ID_COLUMN: str = 'case_id' + IMAGE_COLUMN: str = 'image_path' + LABEL_COLUMN: str = 'label' + + DEFAULT_CSV_FILENAME: str = "dataset.csv" + + def __init__(self, root_dir: Union[str, Path], + dataset_csv: Optional[Union[str, Path]] = None, + dataset_df: Optional[pd.DataFrame] = None,) -> None: + """ + :param root: Root directory of the dataset. + :param dataset_csv: Full path to a dataset CSV file. If omitted, the CSV will be read from + `"{root}/{DEFAULT_CSV_FILENAME}"`. + :param dataset_df: A potentially pre-processed dataframe in the same format as would be read + from the dataset CSV file, e.g. after some filtering. If given, overrides `dataset_csv`. + """ + self.root_dir = Path(root_dir) + + if dataset_df is not None: + self.dataset_csv = None + else: + self.dataset_csv = dataset_csv or self.root_dir / self.DEFAULT_CSV_FILENAME + dataset_df = pd.read_csv(self.dataset_csv) + + dataset_df = dataset_df.set_index(self.SLIDE_ID_COLUMN) + dataset_df[self.LABEL_COLUMN] = (dataset_df['label1_mutation'] + | dataset_df['label2_mutation']).astype(int) + self.dataset_df = dataset_df + + def __len__(self) -> int: + return self.dataset_df.shape[0] + + def __getitem__(self, index: int) -> Dict[str, Any]: + slide_id = self.dataset_df.index[index] + sample = { + self.SLIDE_ID_COLUMN: slide_id, + **self.dataset_df.loc[slide_id].to_dict() + } + sample[self.IMAGE_COLUMN] = str(self.root_dir / sample.pop(self.IMAGE_COLUMN)) + return sample diff --git a/InnerEye/ML/Histopathology/models/deepmil.py b/InnerEye/ML/Histopathology/models/deepmil.py new file mode 100644 index 000000000..55663d074 --- /dev/null +++ b/InnerEye/ML/Histopathology/models/deepmil.py @@ -0,0 +1,319 @@ +from pathlib import Path +import pandas as pd +import numpy as np +from typing import Any, Callable, Dict, Optional, Tuple, List +import torch + +from pytorch_lightning import LightningModule +from torch import Tensor, argmax, mode, nn, no_grad, optim, round +from torchmetrics import AUROC, F1, Accuracy, Precision, Recall + +from InnerEye.Common import fixed_paths +from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset +from InnerEye.ML.Histopathology.models.encoders import TileEncoder +from InnerEye.ML.Histopathology.utils.metrics_utils import select_k_tiles, plot_slide_noxy, plot_scores_hist +from InnerEye.ML.Histopathology.utils.naming import ResultsKey + + +RESULTS_COLS = [ResultsKey.SLIDE_ID, ResultsKey.TILE_ID, ResultsKey.IMAGE_PATH, ResultsKey.PROB, + ResultsKey.PRED_LABEL, ResultsKey.TRUE_LABEL, ResultsKey.BAG_ATTN] + + +def _format_cuda_memory_stats() -> str: + return (f"GPU {torch.cuda.current_device()} memory: " + f"{torch.cuda.memory_allocated() / 1024 ** 3:.2f} GB allocated, " + f"{torch.cuda.memory_reserved() / 1024 ** 3:.2f} GB reserved") + + +class DeepMILModule(LightningModule): + """Base class for deep multiple-instance learning""" + + def __init__(self, + label_column: str, + n_classes: int, + encoder: TileEncoder, + pooling_layer: Callable[[int, int, int], nn.Module], + pool_hidden_dim: int = 128, + pool_out_dim: int = 1, + class_weights: Optional[Tensor] = None, + l_rate: float = 5e-4, + weight_decay: float = 1e-4, + adam_betas: Tuple[float, float] = (0.9, 0.99), + verbose: bool = False, + ) -> None: + """ + :param label_column: Label key for input batch dictionary. + :param n_classes: Number of output classes for MIL prediction. + :param encoder: The tile encoder to use for feature extraction. If no encoding is needed, + you should use `IdentityEncoder`. + :param pooling_layer: Type of pooling to use in multi-instance aggregation. Should be a + `torch.nn.Module` constructor accepting input, hidden, and output pooling `int` dimensions. + :param pool_hidden_dim: Hidden dimension of pooling layer (default=128). + :param pool_out_dim: Output dimension of pooling layer (default=1). + :param class_weights: Tensor containing class weights (default=None). + :param l_rate: Optimiser learning rate. + :param weight_decay: Weight decay parameter for L2 regularisation. + :param adam_betas: Beta parameters for Adam optimiser. + :param verbose: if True statements about memory usage are output at each step + """ + super().__init__() + + # Dataset specific attributes + self.label_column = label_column + self.n_classes = n_classes + self.pool_hidden_dim = pool_hidden_dim + self.pool_out_dim = pool_out_dim + self.pooling_layer = pooling_layer + self.class_weights = class_weights + self.encoder = encoder + self.num_encoding = self.encoder.num_encoding + + # Optimiser hyperparameters + self.l_rate = l_rate + self.weight_decay = weight_decay + self.adam_betas = adam_betas + + self.save_hyperparameters() + self.verbose = verbose + + self.aggregation_fn, self.num_pooling = self.get_pooling() + self.classifier_fn = self.get_classifier() + self.loss_fn = self.get_loss() + self.activation_fn = self.get_activation() + + # Metrics Objects + self.train_metrics = self.get_metrics() + self.val_metrics = self.get_metrics() + self.test_metrics = self.get_metrics() + + def get_pooling(self) -> Tuple[Callable, int]: + pooling_layer = self.pooling_layer(self.num_encoding, + self.pool_hidden_dim, + self.pool_out_dim) + num_features = self.num_encoding*self.pool_out_dim + return pooling_layer, num_features + + def get_classifier(self) -> Callable: + return nn.Linear(in_features=self.num_pooling, + out_features=self.n_classes) + + def get_loss(self) -> Callable: + if self.n_classes > 1: + return nn.CrossEntropyLoss(weight=self.class_weights) + else: + pos_weight = None + if self.class_weights is not None: + pos_weight = Tensor([self.class_weights[1]/(self.class_weights[0]+1e-5)]) + return nn.BCEWithLogitsLoss(pos_weight=pos_weight) + + def get_activation(self) -> Callable: + if self.n_classes > 1: + return nn.Softmax() + else: + return nn.Sigmoid() + + @staticmethod + def get_bag_label(labels: Tensor) -> Tensor: + # Get bag (batch) labels as majority vote + bag_label = mode(labels).values + return bag_label.view(1) + + def get_metrics(self) -> nn.ModuleDict: + if self.n_classes > 1: + return nn.ModuleDict({'accuracy': Accuracy(num_classes=self.n_classes, average='micro'), + 'macro_accuracy': Accuracy(num_classes=self.n_classes, average='macro'), + 'weighted_accuracy': Accuracy(num_classes=self.n_classes, average='weighted')}) + else: + return nn.ModuleDict({'accuracy': Accuracy(), + 'auroc': AUROC(num_classes=self.n_classes), + 'precision': Precision(), + 'recall': Recall(), + 'f1score': F1()}) + + def log_metrics(self, + stage: str) -> None: + valid_stages = ['train', 'test', 'val'] + if stage not in valid_stages: + raise Exception(f"Invalid stage. Chose one of {valid_stages}") + for metric_name, metric_object in self.get_metrics_dict(stage).items(): + self.log(f'{stage}/{metric_name}', metric_object, on_epoch=True, on_step=False, logger=True, sync_dist=True) + + def forward(self, images: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore + with no_grad(): + H = self.encoder(images) # N X L x 1 x 1 + A, M = self.aggregation_fn(H) # A: K x N | M: K x L + M = M.view(-1, self.num_encoding * self.pool_out_dim) + Y_prob = self.classifier_fn(M) + return Y_prob, A + + def configure_optimizers(self) -> optim.Optimizer: + return optim.Adam(self.parameters(), lr=self.l_rate, weight_decay=self.weight_decay, + betas=self.adam_betas) + + def get_metrics_dict(self, stage: str) -> nn.ModuleDict: + return getattr(self, f'{stage}_metrics') + + def _shared_step(self, batch: Dict, batch_idx: int, stage: str) -> Dict[ResultsKey, Tensor]: + # The batch dict contains lists of tensors of different sizes, for all bags in the batch. + # This means we can't stack them along a new axis without padding to the same length. + # We could alternatively concatenate them, but this would require other changes (e.g. in + # the attention layers) to correctly split the tensors by bag/slide ID. + bag_labels_list = [] + bag_logits_list = [] + bag_attn_list = [] + for bag_idx in range(len(batch[TilesDataset.LABEL_COLUMN])): + images = batch[TilesDataset.IMAGE_COLUMN][bag_idx] + labels = batch[self.label_column][bag_idx] + bag_labels_list.append(self.get_bag_label(labels)) + logit, attn = self(images) + bag_logits_list.append(logit.view(-1)) + bag_attn_list.append(attn) + bag_logits = torch.stack(bag_logits_list) + bag_labels = torch.stack(bag_labels_list).view(-1) + + if self.n_classes > 1: + loss = self.loss_fn(bag_logits, bag_labels) + else: + loss = self.loss_fn(bag_logits.squeeze(1), bag_labels.float()) + + probs = self.activation_fn(bag_logits) + if self.n_classes > 1: + preds = argmax(probs, dim=1) + else: + preds = round(probs) + + loss = loss.view(-1, 1) + preds = preds.view(-1, 1) + probs = probs.view(-1, 1) + bag_labels = bag_labels.view(-1, 1) + + results = dict() + for metric_object in self.get_metrics_dict(stage).values(): + metric_object.update(preds, bag_labels) + results.update({ResultsKey.SLIDE_ID: batch[TilesDataset.SLIDE_ID_COLUMN], + ResultsKey.TILE_ID: batch[TilesDataset.TILE_ID_COLUMN], + ResultsKey.IMAGE_PATH: batch[TilesDataset.PATH_COLUMN], ResultsKey.LOSS: loss, + ResultsKey.PROB: probs, ResultsKey.PRED_LABEL: preds, + ResultsKey.TRUE_LABEL: bag_labels, ResultsKey.BAG_ATTN: bag_attn_list, + ResultsKey.IMAGE: batch[TilesDataset.IMAGE_COLUMN]}) + return results + + def training_step(self, batch: Dict, batch_idx: int) -> Tensor: # type: ignore + train_result = self._shared_step(batch, batch_idx, 'train') + self.log('train/loss', train_result[ResultsKey.LOSS], on_epoch=True, on_step=True, logger=True, + sync_dist=True) + if self.verbose: + print(f"After loading images batch {batch_idx} -", _format_cuda_memory_stats()) + self.log_metrics('train') + return train_result[ResultsKey.LOSS] + + def validation_step(self, batch: Dict, batch_idx: int) -> Tensor: # type: ignore + val_result = self._shared_step(batch, batch_idx, 'val') + self.log('val/loss', val_result[ResultsKey.LOSS], on_epoch=True, on_step=True, logger=True, + sync_dist=True) + self.log_metrics('val') + return val_result[ResultsKey.LOSS] + + def test_step(self, batch: Dict, batch_idx: int) -> Dict[ResultsKey, Any]: # type: ignore + test_result = self._shared_step(batch, batch_idx, 'test') + self.log('test/loss', test_result[ResultsKey.LOSS], on_epoch=True, on_step=True, logger=True, + sync_dist=True) + self.log_metrics('test') + return test_result + + def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: # type: ignore + # outputs object consists of a list of dictionaries (of metadata and results, including encoded features) + # It can be indexed as outputs[batch_idx][batch_key][bag_idx][tile_idx] + # example of batch_key ResultsKey.SLIDE_ID_COL + # for batch keys that contains multiple values for slides e.g. ResultsKey.BAG_ATTN_COL + # outputs[batch_idx][batch_key][bag_idx][tile_idx] + # contains the tile value + + # collate the batches + results: Dict[str, List[Any]] = {} + [results.update({col: []}) for col in outputs[0].keys()] + for key in results.keys(): + for batch_id in range(len(outputs)): + results[key] += outputs[batch_id][key] + + print("Saving outputs ...") + # collate at slide level + list_slide_dicts = [] + list_encoded_features = [] + # any column can be used here, the assumption is that the first dimension is the N of slides + for slide_idx in range(len(results[ResultsKey.SLIDE_ID])): + slide_dict = dict() + for key in results.keys(): + if key not in [ResultsKey.IMAGE, ResultsKey.LOSS]: + slide_dict[key] = results[key][slide_idx] + list_slide_dicts.append(slide_dict) + list_encoded_features.append(results[ResultsKey.IMAGE][slide_idx]) + + print(f"Metrics results will be output to {fixed_paths.repository_root_directory()}/outputs") + csv_filename = fixed_paths.repository_root_directory() / Path('outputs/test_output.csv') + encoded_features_filename = fixed_paths.repository_root_directory() / Path('outputs/test_encoded_features.pickle') + + # Collect the list of dictionaries in a list of pandas dataframe and save + df_list = [] + for slide_dict in list_slide_dicts: + slide_dict = self.normalize_dict_for_df(slide_dict, use_gpu=False) + df_list.append(pd.DataFrame.from_dict(slide_dict)) + df = pd.concat(df_list, ignore_index=True) + df.to_csv(csv_filename, mode='w', header=True) + + # Collect all features in a list and save + features_list = self.move_list_to_device(list_encoded_features, use_gpu=False) + torch.save(features_list, encoded_features_filename) + + print("Selecting tiles ...") + fn_top_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('lowest_pred', 'highest_att')) + fn_bottom_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('lowest_pred', 'lowest_att')) + tp_top_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('highes_pred', 'highest_att')) + tp_bottom_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('highest_pred', 'lowest_att')) + report_cases = {'TP': [tp_top_tiles, tp_bottom_tiles], 'FN': [fn_top_tiles, fn_bottom_tiles]} + + for key in report_cases.keys(): + print(f"Plotting {key} ...") + output_path = Path(fixed_paths.repository_root_directory(), f'outputs/fig/{key}/') + Path(output_path).mkdir(parents=True, exist_ok=True) + nslides = len(report_cases[key][0]) + for i in range(nslides): + slide, score, paths, top_attn = report_cases[key][0][i] + fig = plot_slide_noxy(slide, score, paths, top_attn, key + '_top', ncols=4) + figpath = Path(output_path, f'{slide}_top.png') + fig.savefig(figpath, bbox_inches='tight') + + slide, score, paths, bottom_attn = report_cases[key][1][i] + fig = plot_slide_noxy(slide, score, paths, bottom_attn, key + '_bottom', ncols=4) + figpath = Path(output_path, f'{slide}_bottom.png') + fig.savefig(figpath, bbox_inches='tight') + + print("Plotting histogram ...") + fig = plot_scores_hist(results) + output_path = Path(fixed_paths.repository_root_directory(), 'outputs/fig/hist_scores.png') + fig.savefig(output_path, bbox_inches='tight') + + @staticmethod + def normalize_dict_for_df(dict_old: Dict[str, Any], use_gpu: bool) -> Dict: + # slide-level dictionaries are processed by making value dimensions uniform and converting to numpy arrays. + # these steps are required to convert the dictionary to pandas dataframe. + device = 'cuda' if use_gpu else 'cpu' + dict_new = dict() + for key, value in dict_old.items(): + if isinstance(value, Tensor): + value = value.squeeze(0).to(device).numpy() + if value.ndim == 0: + bag_size = len(dict_old[ResultsKey.SLIDE_ID]) + value = np.full(bag_size, fill_value=value) + dict_new[key] = value + return dict_new + + @staticmethod + def move_list_to_device(list_encoded_features: List, use_gpu: bool) -> List: + # a list of features on cpu obtained from original list on gpu + features_list = [] + device = 'cuda' if use_gpu else 'cpu' + for feature in list_encoded_features: + feature = feature.squeeze(0).to(device) + features_list.append(feature) + return features_list diff --git a/InnerEye/ML/Histopathology/models/encoders.py b/InnerEye/ML/Histopathology/models/encoders.py new file mode 100644 index 000000000..43f85772d --- /dev/null +++ b/InnerEye/ML/Histopathology/models/encoders.py @@ -0,0 +1,138 @@ +from pathlib import Path +from typing import Callable, Optional, Sequence, Tuple + +import numpy as np +import torch +from pl_bolts.models.self_supervised import SimCLR +from torch import nn +from torchvision.models import resnet18 +from torchvision.transforms import Compose + +from InnerEye.ML.Histopathology.utils.layer_utils import (get_imagenet_preprocessing, + load_weights_to_model, + setup_feature_extractor) +from InnerEye.ML.SSL.lightning_modules.ssl_classifier_module import SSLClassifier +from InnerEye.ML.SSL.utils import create_ssl_image_classifier + + +class TileEncoder(nn.Module): + """Base tile encoder class for use in dataset transforms or as part of a bigger model""" + + def __init__(self, tile_size: int = 0, n_channels: int = 3, + input_dim: Optional[Sequence[int]] = None) -> None: + """The `TileEncoder` constructor should be called after setting any attributes needed in + `_get_preprocessing()` or `_get_encoder()`. + + :param tile_size: Tile width/height, in pixels. + :param n_channels: Number of channels in the tile (default=3). + :param input_dim: Input shape, to override default of `(n_channels, tile_size, tile_size)`. + """ + super().__init__() + if input_dim is None: + if tile_size == 0: + raise ValueError("Either input_dim or tile_size must be specified") + input_dim = (n_channels, tile_size, tile_size) + self.input_dim = tuple(input_dim) + + self.preprocessing_fn = self._get_preprocessing() + self.feature_extractor_fn, self.num_encoding = self._get_encoder() + + def _get_preprocessing(self) -> Callable: + return Compose([]) + + def _get_encoder(self) -> Tuple[Callable, int]: + raise NotImplementedError + + def forward(self, images: torch.Tensor) -> torch.Tensor: + prep_images = self.preprocessing_fn(images) + return self.feature_extractor_fn(prep_images) + + +class IdentityEncoder(TileEncoder): + """Dummy encoder that just flattens the input""" + + def _get_encoder(self) -> Tuple[Callable, int]: + return nn.Flatten(), np.prod(self.input_dim) + + +class ImageNetEncoder(TileEncoder): + """Feature extractor pretrained for classification on ImageNet""" + + def __init__(self, feature_extraction_model: Callable[..., nn.Module], + tile_size: int, n_channels: int = 3) -> None: + """ + :param feature_extraction_model: A function accepting a `pretrained` keyword argument that + returns a classifier pretrained on ImageNet, such as the ones from `torchvision.models.*`. + :param tile_size: Tile width/height, in pixels. + :param n_channels: Number of channels in the tile (default=3). + """ + self.create_feature_extractor_fn = feature_extraction_model + super().__init__(tile_size=tile_size, n_channels=n_channels) + + def _get_preprocessing(self) -> Callable: + return get_imagenet_preprocessing() + + def _get_encoder(self) -> Tuple[Callable, int]: + pretrained_model = self.create_feature_extractor_fn(pretrained=True) + return setup_feature_extractor(pretrained_model, self.input_dim) # type: ignore + + +class ImageNetSimCLREncoder(TileEncoder): + """SimCLR encoder pretrained on ImageNet""" + + WEIGHTS_URL = ("https://pl-bolts-weights.s3.us-east-2.amazonaws.com/" + "simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt") + EMBEDDING_DIM = 2048 + + def _get_preprocessing(self) -> Callable: + return get_imagenet_preprocessing() + + def _get_encoder(self) -> Tuple[SimCLR, int]: + simclr = SimCLR.load_from_checkpoint(self.WEIGHTS_URL, strict=False) + simclr.freeze() + return simclr, self.EMBEDDING_DIM + + +class InnerEyeSSLEncoder(TileEncoder): + """SSL encoder trained on Azure ML using InnerEye""" + + def __init__(self, pl_checkpoint_path: Path, tile_size: int, n_channels: int = 3) -> None: + """ + :param pl_checkpoint_path: The path of the downloaded checkpoint file. + :param tile_size: Tile width/height, in pixels. + :param n_channels: Number of channels in the tile (default=3). + """ + self.pl_checkpoint_path = pl_checkpoint_path + super().__init__(tile_size=tile_size, n_channels=n_channels) + + def _get_encoder(self) -> Tuple[torch.nn.Module, int]: + model: SSLClassifier = create_ssl_image_classifier( # type: ignore + num_classes=1, # dummy value + freeze_encoder=True, + pl_checkpoint_path=str(self.pl_checkpoint_path) + ) + encoder = model.encoder # type: ignore + for param in encoder.parameters(): + param.requires_grad = False # freeze_encoder does not disable gradients + + classifier_head = model.classifier_head + embedding_dim = classifier_head.n_input # type: ignore + + return encoder, embedding_dim + + +class HistoSSLEncoder(TileEncoder): + """HistoSSL encoder pretrained on multiple histological datasets + + Reference: + - Ciga, Xu, Martel (2021). Self supervised contrastive learning for digital histopathology. + arXiv:2011.13971 + """ + + WEIGHTS_URL = ("https://github.com/ozanciga/self-supervised-histopathology/releases/" + "download/tenpercent/tenpercent_resnet18.ckpt") + + def _get_encoder(self) -> Tuple[Callable, int]: + resnet18_model = resnet18(pretrained=False) + histossl_encoder = load_weights_to_model(self.WEIGHTS_URL, resnet18_model) + return setup_feature_extractor(histossl_encoder, self.input_dim) # type: ignore diff --git a/InnerEye/ML/Histopathology/models/transforms.py b/InnerEye/ML/Histopathology/models/transforms.py new file mode 100644 index 000000000..51cda8c4e --- /dev/null +++ b/InnerEye/ML/Histopathology/models/transforms.py @@ -0,0 +1,107 @@ +from pathlib import Path +from typing import Mapping, Sequence, Union + +import PIL.Image +import torch +from monai.config.type_definitions import KeysCollection +from monai.transforms.transform import MapTransform +from torchvision.transforms.functional import to_tensor + +from InnerEye.ML.Histopathology.models.encoders import TileEncoder + +PathOrString = Union[Path, str] + + +def load_pil_image(image_path: PathOrString) -> PIL.Image.Image: + """Load a PIL image in RGB format from the given path""" + return PIL.Image.open(image_path).convert('RGB') + + +def load_image_as_tensor(image_path: PathOrString) -> torch.Tensor: + """Load an image as a tensor from the given path""" + pil_image = load_pil_image(image_path) + return to_tensor(pil_image) + + +def load_image_stack_as_tensor(image_paths: Sequence[PathOrString], + progress: bool = False) -> torch.Tensor: + """Load a batch of images of the same size as a tensor from the given paths""" + loading_generator = (load_image_as_tensor(path) for path in image_paths) + if progress: + from tqdm import tqdm + loading_generator = tqdm(loading_generator, desc="Loading image stack", + total=len(image_paths), leave=False) + image_tensors = list(loading_generator) + return torch.stack(image_tensors, dim=0) + + +class LoadTiled(MapTransform): + """Dictionary transform to load an individual image tile as a tensor from an input path""" + + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: + """ + :param keys: Key(s) for the image path(s) in the input dictionary. + :param allow_missing_keys: If `False` (default), raises an exception when an input + dictionary is missing any of the specified keys. + """ + super().__init__(keys, allow_missing_keys) + + def __call__(self, data: Mapping) -> Mapping: + out_data = dict(data) # create shallow copy + for key in self.key_iterator(out_data): + out_data[key] = load_image_as_tensor(data[key]) + return out_data + + +class LoadTilesBatchd(MapTransform): + """Dictionary transform to load a batch of image tiles as a tensor from a list of input paths""" + + # Cannot reuse MONAI readers because they support stacking only images with no channels + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False, + progress: bool = False) -> None: + """ + :param keys: Key(s) for the image path(s) in the input dictionary. + :param allow_missing_keys: If `False` (default), raises an exception when an input + dictionary is missing any of the specified keys. + :param progress: Whether to display a tqdm progress bar. + """ + super().__init__(keys, allow_missing_keys) + self.progress = progress + + def __call__(self, data: Mapping) -> Mapping: + out_data = dict(data) # create shallow copy + for key in self.key_iterator(out_data): + out_data[key] = load_image_stack_as_tensor(data[key], progress=self.progress) + return out_data + + +class EncodeTilesBatchd(MapTransform): + """Dictionary transform to extract features from a batch tensor of image tiles""" + + def __init__(self, + keys: KeysCollection, + encoder: TileEncoder, + allow_missing_keys: bool = False) -> None: + """ + :param keys: Key(s) for the image path(s) in the input dictionary. + :param encoder: The tile encoder to use for feature extraction. + :param allow_missing_keys: If `False` (default), raises an exception when an input + dictionary is missing any of the specified keys. + """ + super().__init__(keys, allow_missing_keys) + self.encoder = encoder + + @torch.no_grad() + def _encode_tiles(self, images: torch.Tensor) -> torch.Tensor: + device = next(self.encoder.parameters()).device + images = images.to(device) + embeddings = self.encoder(images) + del images + torch.cuda.empty_cache() + return embeddings + + def __call__(self, data: Mapping) -> Mapping: + out_data = dict(data) # create shallow copy + for key in self.key_iterator(out_data): + out_data[key] = self._encode_tiles(data[key]) + return out_data diff --git a/InnerEye/ML/Histopathology/preprocessing/__init__.py b/InnerEye/ML/Histopathology/preprocessing/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/InnerEye/ML/Histopathology/preprocessing/create_tiles_dataset.py b/InnerEye/ML/Histopathology/preprocessing/create_tiles_dataset.py new file mode 100644 index 000000000..e1060860a --- /dev/null +++ b/InnerEye/ML/Histopathology/preprocessing/create_tiles_dataset.py @@ -0,0 +1,222 @@ +import functools +import os +import logging +import shutil +import traceback +import warnings +from pathlib import Path +from typing import Sequence, Tuple, Union + +import numpy as np +import PIL +from monai.data import Dataset +from monai.data.image_reader import WSIReader +from tqdm import tqdm + +from InnerEye.ML.Histopathology.preprocessing import tiling +from InnerEye.ML.Histopathology.datasets.panda_dataset import PandaDataset, LoadPandaROId + + +CSV_COLUMNS = ['slide_id', 'tile_id', 'image', 'mask', 'tile_x', 'tile_y', 'occupancy', + 'data_provider', 'slide_isup_grade', 'slide_gleason_score'] +TMP_SUFFIX = "_tmp" + +logging.basicConfig(format='%(asctime)s %(message)s', filemode='w') +logger = logging.getLogger() +logger.setLevel(logging.DEBUG) + + +def select_tile(mask_tile: np.ndarray, occupancy_threshold: float) \ + -> Union[Tuple[bool, float], Tuple[np.ndarray, np.ndarray]]: + if occupancy_threshold < 0. or occupancy_threshold > 1.: + raise ValueError("Tile occupancy threshold must be between 0 and 1") + foreground_mask = mask_tile > 0 + occupancy = foreground_mask.mean(axis=(-2, -1)) + return (occupancy > occupancy_threshold).squeeze(), occupancy.squeeze() + + +def get_tile_descriptor(tile_location: Sequence[int]) -> str: + return f"{tile_location[0]:05d}x_{tile_location[1]:05d}y" + + +def get_tile_id(slide_id: str, tile_location: Sequence[int]) -> str: + return f"{slide_id}.{get_tile_descriptor(tile_location)}" + + +def save_image(array_chw: np.ndarray, path: Path) -> PIL.Image: + path.parent.mkdir(parents=True, exist_ok=True) + array_hwc = np.moveaxis(array_chw, 0, -1).astype(np.uint8).squeeze() + pil_image = PIL.Image.fromarray(array_hwc) + pil_image.convert('RGB').save(path) + return pil_image + + +def generate_tiles(sample: dict, tile_size: int, occupancy_threshold: float) \ + -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int]: + image_tiles, tile_locations = tiling.tile_array_2d(sample['image'], tile_size=tile_size, + constant_values=255) + mask_tiles, _ = tiling.tile_array_2d(sample['mask'], tile_size=tile_size, constant_values=0) + + selected: np.ndarray + occupancies: np.ndarray + selected, occupancies = select_tile(mask_tiles, occupancy_threshold) + n_discarded = (~selected).sum() + logging.info(f"Percentage tiles discarded: {round(selected.sum() / n_discarded * 100, 2)}") + + image_tiles = image_tiles[selected] + mask_tiles = mask_tiles[selected] + tile_locations = tile_locations[selected] + occupancies = occupancies[selected] + + abs_tile_locations = (sample['scale'] * tile_locations + sample['location']).astype(int) + + return image_tiles, mask_tiles, abs_tile_locations, occupancies, n_discarded + + +# TODO refactor this to separate metadata identification from saving. We might want the metadata +# even if the saving fails +def save_tile(sample: dict, image_tile: np.ndarray, mask_tile: np.ndarray, + tile_location: Sequence[int], output_dir: Path) -> dict: + slide_id = sample['image_id'] + descriptor = get_tile_descriptor(tile_location) + image_tile_filename = f"train_images/{descriptor}.png" + mask_tile_filename = f"train_label_masks/{descriptor}_mask.png" + + save_image(image_tile, output_dir / image_tile_filename) + save_image(mask_tile, output_dir / mask_tile_filename) + + tile_metadata = { + 'slide_id': slide_id, + 'tile_id': get_tile_id(slide_id, tile_location), + 'image': image_tile_filename, + 'mask': mask_tile_filename, + 'tile_x': tile_location[0], + 'tile_y': tile_location[1], + 'data_provider': sample['data_provider'], + 'slide_isup_grade': sample['isup_grade'], + 'slide_gleason_score': sample['gleason_score'], + } + + return tile_metadata + + +def process_slide(sample: dict, level: int, margin: int, tile_size: int, occupancy_threshold: int, + output_dir: Path, tile_progress: bool = False) -> None: + slide_id = sample['image_id'] + slide_dir: Path = output_dir / (slide_id + "/") + logging.info(f">>> Slide dir {slide_dir}") + if slide_dir.exists(): # already processed slide - skip + logging.info(f">>> Skipping {slide_dir} - already processed") + return + else: + try: + slide_dir.mkdir(parents=True) + + dataset_csv_path = slide_dir / "dataset.csv" + dataset_csv_file = dataset_csv_path.open('w') + dataset_csv_file.write(','.join(CSV_COLUMNS) + '\n') # write CSV header + + tiles_failure = 0 + failed_tiles_csv_path = slide_dir / "failed_tiles.csv" + failed_tiles_file = failed_tiles_csv_path.open('w') + failed_tiles_file.write('tile_id' + '\n') + + logging.info(f"Loading slide {slide_id} ...") + loader = LoadPandaROId(WSIReader(), level=level, margin=margin) + sample = loader(sample) # load 'image' and 'mask' from disk + + logging.info(f"Tiling slide {slide_id} ...") + image_tiles, mask_tiles, tile_locations, occupancies, _ = \ + generate_tiles(sample, tile_size, occupancy_threshold) + n_tiles = image_tiles.shape[0] + + for i in tqdm(range(n_tiles), f"Tiles ({slide_id[:6]}…)", unit="img", disable=not tile_progress): + try: + tile_metadata = save_tile(sample, image_tiles[i], mask_tiles[i], tile_locations[i], + slide_dir) + tile_metadata['occupancy'] = occupancies[i] + tile_metadata['image'] = os.path.join(slide_dir.name, tile_metadata['image']) + tile_metadata['mask'] = os.path.join(slide_dir.name, tile_metadata['mask']) + dataset_row = ','.join(str(tile_metadata[column]) for column in CSV_COLUMNS) + dataset_csv_file.write(dataset_row + '\n') + except Exception as e: + tiles_failure += 1 + descriptor = get_tile_descriptor(tile_locations[i]) + '\n' + failed_tiles_file.write(descriptor) + traceback.print_exc() + warnings.warn(f"An error occurred while saving tile " + f"{get_tile_id(slide_id, tile_locations[i])}: {e}") + + dataset_csv_file.close() + failed_tiles_file.close() + if tiles_failure > 0: + # TODO what we want to do with slides that have some failed tiles? + logging.warning(f"{slide_id} is incomplete. {tiles_failure} tiles failed.") + except Exception as e: + traceback.print_exc() + warnings.warn(f"An error occurred while processing slide {slide_id}: {e}") + + +def merge_dataset_csv_files(dataset_dir: Path) -> Path: + full_csv = dataset_dir / "dataset.csv" + # TODO change how we retrieve these filenames, probably because mounted, the operation is slow + # and it seems to find many more files + # print("List of files") + # print([str(file) + '\n' for file in dataset_dir.glob("*/dataset.csv")]) + with full_csv.open('w') as full_csv_file: + # full_csv_file.write(','.join(CSV_COLUMNS) + '\n') # write CSV header + first_file = True + for slide_csv in tqdm(dataset_dir.glob("*/dataset.csv"), desc="Merging dataset.csv", unit='file'): + logging.info(f"Merging slide {slide_csv}") + content = slide_csv.read_text() + if not first_file: + content = content[content.index('\n') + 1:] # discard header row for all but the first file + full_csv_file.write(content) + first_file = False + return full_csv + + +def main(panda_dir: Union[str, Path], root_output_dir: Union[str, Path], level: int, tile_size: int, + margin: int, occupancy_threshold: float, parallel: bool = False, overwrite: bool = False) -> None: + + # Ignoring some types here because mypy is getting confused with the MONAI Dataset class + # to select a subsample use keyword n_slides + dataset = Dataset(PandaDataset(panda_dir)) # type: ignore + + output_dir = Path(root_output_dir) / f"panda_tiles_level{level}_{tile_size}" + logging.info(f"Creating dataset of level-{level} {tile_size}x{tile_size} PANDA tiles at: {output_dir}") + + if overwrite and output_dir.exists(): + shutil.rmtree(output_dir) + output_dir.mkdir(parents=True, exist_ok=not overwrite) + + func = functools.partial(process_slide, level=level, margin=margin, tile_size=tile_size, + occupancy_threshold=occupancy_threshold, output_dir=output_dir, + tile_progress=not parallel) + + if parallel: + import multiprocessing + + pool = multiprocessing.Pool() + map_func = pool.imap_unordered # type: ignore + else: + map_func = map # type: ignore + + list(tqdm(map_func(func, dataset), desc="Slides", unit="img", total=len(dataset))) # type: ignore + + if parallel: + pool.close() + + logging.info("Merging slide files in a single file") + merge_dataset_csv_files(output_dir) + + +if __name__ == '__main__': + main(panda_dir="/tmp/datasets/PANDA", + root_output_dir="/datadrive", + level=1, + tile_size=224, + margin=64, + occupancy_threshold=0.05, + parallel=True, + overwrite=False) diff --git a/InnerEye/ML/Histopathology/preprocessing/tiling.py b/InnerEye/ML/Histopathology/preprocessing/tiling.py new file mode 100644 index 000000000..b0f8b6c37 --- /dev/null +++ b/InnerEye/ML/Histopathology/preprocessing/tiling.py @@ -0,0 +1,123 @@ +# These tiling implementations are adapted from PANDA Kaggle solutions, for example: +# https://github.com/kentaroy47/Kaggle-PANDA-1st-place-solution/blob/master/src/data_process/a00_save_tiles.py +from typing import Any, Optional, Tuple + +import numpy as np + + +def get_1d_padding(length: int, tile_size: int) -> Tuple[int, int]: + """Computes symmetric padding for `length` to be divisible by `tile_size`.""" + pad = (tile_size - length % tile_size) % tile_size + return (pad // 2, pad - pad // 2) + + +def pad_for_tiling_2d(array: np.ndarray, tile_size: int, channels_first: Optional[bool] = True, + **pad_kwargs: Any) -> Tuple[np.ndarray, np.ndarray]: + """Symmetrically pads a 2D `array` such that both dimensions are divisible by `tile_size`. + + :param array: 2D image array. + :param tile_size: Width/height of each tile in pixels. + :param channels_first: Whether `array` is in CHW (`True`, default) or HWC (`False`) layout. + :param pad_kwargs: Keyword arguments to be passed to `np.pad()` (e.g. `constant_values=0`). + :return: A tuple containing: + - `padded_array`: Resulting array, in the same CHW/HWC layout as the input. + - `offset`: XY offset introduced by the padding. Add this to coordinates relative to the + original array to obtain indices for the padded array. + """ + height, width = array.shape[1:] if channels_first else array.shape[:-1] + padding_h = get_1d_padding(height, tile_size) + padding_w = get_1d_padding(width, tile_size) + padding = [padding_h, padding_w] + channels_axis = 0 if channels_first else 2 + padding.insert(channels_axis, (0, 0)) # zero padding on channels axis + padded_array = np.pad(array, padding, **pad_kwargs) + offset = (padding_w[0], padding_h[0]) + return padded_array, np.array(offset) + + +def tile_array_2d(array: np.ndarray, tile_size: int, channels_first: Optional[bool] = True, + **pad_kwargs: Any) -> Tuple[np.ndarray, np.ndarray]: + """Split an image array into square non-overlapping tiles. + + The array will be padded symmetrically if its dimensions are not exact multiples of `tile_size`. + + :param array: Image array. + :param tile_size: Width/height of each tile in pixels. + :param pad_kwargs: Keyword arguments to be passed to `np.pad()` (e.g. `constant_values=0`). + :param channels_first: Whether `array` is in CHW (`True`, default) or HWC (`False`) layout. + :return: A tuple containing: + - `tiles`: A batch of tiles in NCHW layout. + - `coords`: XY coordinates of each tile, in the same order. + """ + padded_array, (offset_w, offset_h) = pad_for_tiling_2d(array, tile_size, channels_first, **pad_kwargs) + if channels_first: + channels, height, width = padded_array.shape + else: + height, width, channels = padded_array.shape + n_tiles_h = height // tile_size + n_tiles_w = width // tile_size + + if channels_first: + intermediate_shape = (channels, n_tiles_h, tile_size, n_tiles_w, tile_size) + axis_order = (1, 3, 0, 2, 4) # (n_tiles_h, n_tiles_w, channels, tile_size, tile_size) + output_shape = (n_tiles_h * n_tiles_w, channels, tile_size, tile_size) + else: + intermediate_shape = (n_tiles_h, tile_size, n_tiles_w, tile_size, channels) + axis_order = (0, 2, 1, 3, 4) # (n_tiles_h, n_tiles_w, tile_size, tile_size, channels) + output_shape = (n_tiles_h * n_tiles_w, tile_size, tile_size, channels) + + tiles = padded_array.reshape(intermediate_shape) # Split width and height axes + tiles = tiles.transpose(axis_order) + tiles = tiles.reshape(output_shape) # Flatten tile batch dimension + + # Compute top-left coordinates of every tile, relative to the original array's origin + coords_h = tile_size * np.arange(n_tiles_h) - offset_h + coords_w = tile_size * np.arange(n_tiles_w) - offset_w + # Shape: (n_tiles_h * n_tiles_w, 2) + coords = np.stack(np.meshgrid(coords_w, coords_h), axis=-1).reshape(-1, 2) + + return tiles, coords + + +def assemble_tiles_2d(tiles: np.ndarray, coords: np.ndarray, fill_value: Optional[float] = np.nan, + channels_first: Optional[bool] = True) -> Tuple[np.ndarray, np.ndarray]: + """Assembles a 2D array from sequences of tiles and coordinates. + + :param tiles: Stack of tiles with batch dimension first. + :param coords: XY tile coordinates, assumed to be spaced by multiples of `tile_size` (shape: [N, 2]). + :param tile_size: Size of each tile; must be >0. + :param fill_value: Value to assign to empty elements (default: `NaN`). + :param channels_first: Whether each tile is in CHW (`True`, default) or HWC (`False`) layout. + :return: A tuple containing: + - `array`: The reassembled 2D array with the smallest dimensions to contain all given tiles. + - `offset`: The lowest XY coordinates. + - `offset`: XY offset introduced by the assembly. Add this to tile coordinates to obtain + indices for the assembled array. + """ + if coords.shape[0] != tiles.shape[0]: + raise ValueError(f"Tile coordinates and values must have the same length, " + f"got {coords.shape[0]} and {tiles.shape[0]}") + + if channels_first: + n_tiles, channels, tile_size, _ = tiles.shape + else: + n_tiles, tile_size, _, channels = tiles.shape + tile_xs, tile_ys = coords.T + + x_min, x_max = min(tile_xs), max(tile_xs + tile_size) + y_min, y_max = min(tile_ys), max(tile_ys + tile_size) + width = x_max - x_min + height = y_max - y_min + output_shape = (channels, height, width) if channels_first else (height, width, channels) + array = np.full(output_shape, fill_value) + + offset = np.array([-x_min, -y_min]) + for idx in range(n_tiles): + row = coords[idx, 1] + offset[1] + col = coords[idx, 0] + offset[0] + if channels_first: + array[:, row:row + tile_size, col:col + tile_size] = tiles[idx] + else: + array[row:row + tile_size, col:col + tile_size, :] = tiles[idx] + + return array, offset diff --git a/InnerEye/ML/Histopathology/scripts/aggregate_metrics_crossvalidation.py b/InnerEye/ML/Histopathology/scripts/aggregate_metrics_crossvalidation.py new file mode 100644 index 000000000..6fe4551ba --- /dev/null +++ b/InnerEye/ML/Histopathology/scripts/aggregate_metrics_crossvalidation.py @@ -0,0 +1,35 @@ +""" +Script to find mean and standard deviation of desired metrics from cross validation child runs. +""" +import os +import pandas as pd + +from health_azure import aggregate_hyperdrive_metrics, get_workspace + +from InnerEye.Common import fixed_paths + + +def get_cross_validation_metrics_df(run_id: str) -> pd.DataFrame: + """ + Function to aggregate the metric over cross-validation runs + :param run_id: run id of the hyperdrive run containing child runs + """ + aml_workspace = get_workspace() + os.chdir(fixed_paths.repository_root_directory()) + df = aggregate_hyperdrive_metrics(run_id=run_id, + child_run_arg_name="cross_validation_split_index", + aml_workspace=aml_workspace) + return df + + +if __name__ == "__main__": + metrics_list = ['test/accuracy', 'test/auroc', 'test/f1score', 'test/precision', 'test/recall'] + run_id = "hsharma_features_viz:HD_eff4c009-2f9f-4c2c-94c6-c0c84944a412" + metrics_df = get_cross_validation_metrics_df(run_id=run_id) + for metric in metrics_list: + if metric in metrics_df.index.values: + mean = metrics_df.loc[[metric]].mean(axis=1)[metric] + std = metrics_df.loc[[metric]].std(axis=1)[metric] + print(f"{metric}: {round(mean,4)} ± {round(std,4)}") + else: + print(f"Metric {metric} not found in the Hyperdrive run metrics for run id {run_id}.") diff --git a/InnerEye/ML/Histopathology/scripts/azure/azure_tiles_creation.py b/InnerEye/ML/Histopathology/scripts/azure/azure_tiles_creation.py new file mode 100644 index 000000000..c44f7ae6a --- /dev/null +++ b/InnerEye/ML/Histopathology/scripts/azure/azure_tiles_creation.py @@ -0,0 +1,56 @@ +""" +This script is an example of how to use the submit_to_azure_if_needed function from the hi-ml package to run the +main pre-processing function that creates tiles from slides in the PANDA dataset. The advantage of using this script +is the ability to submit to a cluster on azureml and to have the output files directly saved as a registered dataset. + +To run execute, from inside the pre-processing folder, +python azure_tiles_creation.py --azureml + +A json configuration file containing the credentials to the Azure workspace and an environment.yml file are expected +in input. + +This has been tested on hi-mlv0.1.4. +""" + +from pathlib import Path +import sys +import time + +current_file = Path(__file__) +radiomics_root = current_file.absolute().parent.parent.parent.parent.parent +sys.path.append(str(radiomics_root)) +from health_azure.himl import submit_to_azure_if_needed, DatasetConfig # noqa +from InnerEye.ML.Histopathology.preprocessing.create_tiles_dataset import main # noqa + +# Pre-built environment file that contains all the requirements (RadiomicsNN + histo) +# Assuming ENV_NAME is a complete environment, `conda env export -n ENV_NAME -f ENV_NAME.yml` will create the desired file +ENVIRONMENT_FILE = radiomics_root.joinpath(Path("/envs/innereyeprivatetiles.yml")) +DATASET_NAME = "PANDA_tiles" +timestr = time.strftime("%Y%m%d-%H%M%S") +folder_name = DATASET_NAME + '_' + timestr + +if __name__ == '__main__': + print(f"Running {str(current_file)}") + input_dataset = DatasetConfig(name="PANDA", datastore="innereyedatasets", local_folder=Path("/tmp/datasets/PANDA"), use_mounting=True) + output_dataset = DatasetConfig(name=DATASET_NAME, datastore="innereyedatasets", local_folder=Path("/datadrive/"), use_mounting=True) + run_info = submit_to_azure_if_needed(entry_script=current_file, + snapshot_root_directory=radiomics_root, + workspace_config_file=Path("config.json"), + compute_cluster_name='training-pr-nc12', # training-nd24 + default_datastore="innereyedatasets", + conda_environment_file=Path(ENVIRONMENT_FILE), + input_datasets=[input_dataset], + output_datasets=[output_dataset], + ) + input_folder = run_info.input_datasets[0] + output_folder = Path(run_info.output_datasets[0], folder_name) + print(f'This will be the final ouput folder {str(output_folder)}') + + main(panda_dir=str(input_folder), + root_output_dir=str(output_folder), + level=1, + tile_size=224, + margin=64, + occupancy_threshold=0.05, + parallel=True, + overwrite=False) diff --git a/InnerEye/ML/Histopathology/scripts/mount_azure_dataset.py b/InnerEye/ML/Histopathology/scripts/mount_azure_dataset.py new file mode 100644 index 000000000..4f9afb064 --- /dev/null +++ b/InnerEye/ML/Histopathology/scripts/mount_azure_dataset.py @@ -0,0 +1,25 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ +from health_azure import DatasetConfig +from health_azure.utils import get_workspace + + +def mount_dataset(dataset_id: str) -> str: + ws = get_workspace() + target_folder = "/tmp/datasets/" + dataset = DatasetConfig(name=dataset_id, target_folder=target_folder, use_mounting=True) + dataset_mount_folder, mount_ctx = dataset.to_input_dataset_local(ws) + mount_ctx.start() + assert next(dataset_mount_folder.iterdir()), "Mounted data folder is empty" + return str(dataset_mount_folder) + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--id', type=str, dest='dataset_id', + help='Name of the Azure dataset e.g. PANDA or TCGA-CRCk') + args = parser.parse_args() + mount_dataset(args.dataset_id) diff --git a/InnerEye/ML/Histopathology/utils/analysis_plot_utils.py b/InnerEye/ML/Histopathology/utils/analysis_plot_utils.py new file mode 100644 index 000000000..0e9981357 --- /dev/null +++ b/InnerEye/ML/Histopathology/utils/analysis_plot_utils.py @@ -0,0 +1,93 @@ +import numpy as np +from typing import List, Any + +import umap +from sklearn.manifold import TSNE +from matplotlib import pyplot as plt + + +def get_tsne_projection(features: List[Any], n_components: int = 2, n_jobs: int = -1, **kwargs: Any) -> List[Any]: + """ + Get the t-sne projection of high dimensional data in a lower dimensional space + :param features: list of features in higher dimensional space (n x f for n samples and f features per sample) + :param **kwargs: keyword arguments to be passed to TSNE() + :return: list of features in lower dimensional space (n x c for n samples and c components) + """ + tsne_2d = TSNE(n_components=n_components, n_jobs=n_jobs, **kwargs) + tsne_proj = tsne_2d.fit_transform(features) + return tsne_proj + + +def get_umap_projection(features: List[Any], n_components: int = 2, n_jobs: int = -1, **kwargs: Any) -> List[Any]: + """ + Get the umap projection of high dimensional data in a lower dimensional space + :param features: list of features in higher dimensional space (n x f for n samples and f features per sample) + :param **kwargs: keyword arguments to be passed to UMAP() + :return: list of features in lower dimensional space (n x c for n samples and c components) + """ + umap_2d = umap.UMAP(n_components=n_components, n_jobs=n_jobs, **kwargs) + umap_proj = umap_2d.fit_transform(features) + return umap_proj + + +def normalize_array_minmax(arr: List[float]) -> List[float]: + """ + Normalize an array in range 0 to 1 + :param arr: array to be normalized + :return: normalized array + """ + return (arr - np.min(arr)) / (np.max(arr) - np.min(arr)) + + +def normalize_array_mean(arr: List[float]) -> List[float]: + """ + Normalize an array with zero mean and unit variance + :param arr: array to be normalized + :return: normalized array + """ + return (arr - np.mean(arr)) / np.std(arr) + + +def plot_projected_features_2d(data: Any, labels: List[int], classes: List[str], title: str = "") -> None: + """ + Plot a scatter plot of projected features in two dimensions + :param data: features projected in 2d space (nx2) + :param labels: corresponding labels of the data (nx1) + :param classes: list of classes in the dataset + :param title: plot title string + """ + plt.figure() + scatter = plt.scatter(data[:, 0], data[:, 1], 20, labels) + plt.legend(handles=scatter.legend_elements()[0], labels=classes) + plt.title(title) + + +def plot_box_whisker(data_list: List[Any], column_names: List[str], show_outliers: bool, title: str = "") -> None: + """ + Plot a box whisker plot of column data + :param columns: data to be plotted in columns + :param column_names: names of the columns + :param show_outliers: whether outliers need to be shown + :param title: plot title string + """ + plt.figure() + _, ax = plt.subplots() + ax.boxplot(data_list, showfliers=show_outliers) + positions = range(1, len(column_names)+1) + means = [] + for i in range(len(data_list)): + means.append(np.mean(data_list[i])) + ax.plot(positions, means, 'rs') + plt.xticks(positions, column_names) + plt.title(title) + + +def plot_histogram(data: List[Any], title: str = "") -> None: + """ + Plot a histogram given some data + :param data: data to be plotted + :param title: plot title string + """ + plt.figure() + plt.hist(data, bins=50) + plt.gca().set(title=title, xlabel='Values', ylabel='Frequency') diff --git a/InnerEye/ML/Histopathology/utils/download_utils.py b/InnerEye/ML/Histopathology/utils/download_utils.py new file mode 100644 index 000000000..10b80ebef --- /dev/null +++ b/InnerEye/ML/Histopathology/utils/download_utils.py @@ -0,0 +1,31 @@ +import os +from pathlib import Path + +from health_azure import download_files_from_run_id, get_workspace +from InnerEye.Common import fixed_paths + + +def download_file_if_necessary(run_id: str, remote_dir: Path, download_dir: Path, filename: str) -> None: + """ + Function to download any file from an AML run if it doesn't exist locally + :param run_id: run ID of the AML run + :param remote_dir: remote directory from where the file is downloaded + :param download_dir: local directory where to save the downloaded file + :param filename: name of the file to be downloaded (e.g. `"test_output.csv"`). + """ + aml_workspace = get_workspace() + os.chdir(fixed_paths.repository_root_directory()) + local_path = download_dir / run_id.split(":")[1] / "outputs" / filename + remote_path = remote_dir / filename + if local_path.exists(): + print("File already exists at", local_path) + else: + local_dir = local_path.parent.parent + local_dir.mkdir(exist_ok=True, parents=True) + download_files_from_run_id(run_id=run_id, + output_folder=local_dir, + prefix=str(remote_path), + aml_workspace=aml_workspace, + validate_checksum=True) + assert local_path.exists() + print("File is downloaded at", local_path) diff --git a/InnerEye/ML/Histopathology/utils/layer_utils.py b/InnerEye/ML/Histopathology/utils/layer_utils.py new file mode 100644 index 000000000..a2847617d --- /dev/null +++ b/InnerEye/ML/Histopathology/utils/layer_utils.py @@ -0,0 +1,45 @@ +from typing import Callable, Tuple + +from torch import as_tensor, device, nn, prod, rand +from torch.hub import load_state_dict_from_url +from torchvision.transforms import Normalize + + +def get_imagenet_preprocessing() -> nn.Module: + return Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + +def setup_feature_extractor(pretrained_model: nn.Module, + input_dim: Tuple[int, int, int]) -> Tuple[Callable, int]: + layers = list(pretrained_model.children())[:-1] + layers.append(nn.Flatten()) # flatten non-batch dims in case of spatial feature maps + feature_extractor = nn.Sequential(*layers) + # fix weights, no fine-tuning + for param in feature_extractor.parameters(): + param.requires_grad = False + feature_shape = feature_extractor(rand(1, *input_dim)).shape + num_features = int(prod(as_tensor(feature_shape)).item()) + return feature_extractor, num_features + + +def load_weights_to_model(weights_url: str, model: nn.Module) -> nn.Module: + """ + Load weights to the histoSSL model from the given URL + https://github.com/ozanciga/self-supervised-histopathology + """ + map_location = device('cpu') + state = load_state_dict_from_url(weights_url, map_location=map_location) + state_dict = state['state_dict'] + model_dict = model.state_dict() + + new_weights = {} + for key, value in state_dict.items(): + model_key = key.replace('model.', '').replace('resnet.', '') + if model_key in model_dict: + new_weights[model_key] = value + if len(new_weights) == 0: + raise RuntimeError("Weights could not be loaded.") + model_dict.update(new_weights) # type: ignore + + model.load_state_dict(model_dict) # type: ignore + return model diff --git a/InnerEye/ML/Histopathology/utils/metrics_utils.py b/InnerEye/ML/Histopathology/utils/metrics_utils.py new file mode 100644 index 000000000..cc42e9d42 --- /dev/null +++ b/InnerEye/ML/Histopathology/utils/metrics_utils.py @@ -0,0 +1,94 @@ +from typing import Tuple, List, Any, Dict +import torch +import matplotlib.pyplot as plt +from math import ceil + +from InnerEye.ML.Histopathology.models.transforms import load_pil_image +from InnerEye.ML.Histopathology.utils.naming import ResultsKey + + +def select_k_tiles(results: Dict, n_tiles: int = 5, n_slides: int = 5, label: int = 1, + select: Tuple = ('lowest_pred', 'highest_att'), + slide_col: str = ResultsKey.SLIDE_ID, gt_col: str = ResultsKey.TRUE_LABEL, + attn_col: str = ResultsKey.BAG_ATTN, prob_col: str = ResultsKey.PROB, + return_col: str = ResultsKey.IMAGE_PATH) -> List[Tuple[Any, Any, List[Any], List[Any]]]: + """ + :param results: List that contains slide_level dicts + :param n_tiles: number of tiles to be selected for each slide + :param n_slides: number of slides to be selected + :param label: which label to use to select slides + :param select: criteria to be used to sort the slides (select[0]) and the tiles (select[1]) + :param slide_col: column name that contains slide identifiers + :param gt_col: column name that contains labels + :param attn_col: column name that contains scores used to sort tiles + :param prob_col: column name that contains scores used to sort slides + :param return_col: column name of the values we want to return for each tile + :return: tuple containing the slides id, the slide score, the tile ids, the tiles scores + """ + tmp_s = [(results[prob_col][i], i) for i, gt in enumerate(results[gt_col]) if gt == label] # type ignore + if select[0] == 'lowest_pred': + tmp_s.sort(reverse=False) + elif select[0] == 'highest_pred': + tmp_s.sort(reverse=True) + else: + ValueError('select value not recognised') + _, sorted_idx = zip(*tmp_s) + k_idx = [] + if select[1] == 'highest_att': + descending = True + elif select[1] == 'lowest_att': + descending = False + for _, slide_idx in enumerate(sorted_idx[:n_slides]): + tmp = results[attn_col][slide_idx] + _, t_indices = torch.sort(tmp, descending=descending) + k_tiles = [] + scores = [] + for t_idx in t_indices[0][:n_tiles]: + k_tiles.append(results[return_col][slide_idx][t_idx]) + scores.append(results[attn_col][slide_idx][0][t_idx]) + # slide_ids are duplicated + k_idx.append((results[slide_col][slide_idx][0], + results[prob_col][slide_idx].item(), + k_tiles, scores)) + return k_idx + + +def plot_scores_hist(results: Dict, prob_col: str = ResultsKey.PROB, + gt_col: str = ResultsKey.TRUE_LABEL) -> plt.figure: + """ + :param results: List that contains slide_level dicts + :param prob_col: column name that contains the scores + :param gt_col: column name that contains the true label + :return: matplotlib figure of the scores histogram by class + """ + pos_scores = [results[prob_col][i][0].cpu().item() for i, gt in enumerate(results[gt_col]) if gt == 1] + neg_scores = [results[prob_col][i][0].cpu().item() for i, gt in enumerate(results[gt_col]) if gt == 0] + fig, ax = plt.subplots() + ax.hist([pos_scores, neg_scores], label=['1', '0'], alpha=0.5) + ax.set_xlabel("Predicted Score") + ax.legend() + return fig + + +def plot_slide_noxy(slide: str, score: float, paths: List, attn: List, case: str, ncols: int = 5, + size: Tuple = (10, 10)) -> plt.figure: + """ + :param slide: slide identifier + :param score: predicted score for the slide + :param paths: list of paths to tiles belonging to the slide + :param attn: list of scores belonging to the tiles in paths. paths and attn are expected to have the same shape + :param case: string used to define the title of the plot e.g. TP + :param ncols: number of cols the produced figure should have + :param size: size of the plot + :return: matplotlib figure of each tile in paths with attn score + """ + nrows = int(ceil(len(paths) / ncols)) + fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=size) + fig.suptitle(f"{case}: {slide} P=%.2f" % score) + for i in range(len(paths)): + img = load_pil_image(paths[i]) + axs.ravel()[i].imshow(img, clim=(0, 255), cmap='gray') + axs.ravel()[i].set_title("%.6f" % attn[i].cpu().item()) + for i in range(len(axs.ravel())): + axs.ravel()[i].set_axis_off() + return fig diff --git a/InnerEye/ML/Histopathology/utils/naming.py b/InnerEye/ML/Histopathology/utils/naming.py new file mode 100644 index 000000000..b1731ae73 --- /dev/null +++ b/InnerEye/ML/Histopathology/utils/naming.py @@ -0,0 +1,12 @@ +from enum import Enum + +class ResultsKey(str, Enum): + SLIDE_ID = 'slide_id' + TILE_ID = 'tile_id' + IMAGE = 'image' + IMAGE_PATH = 'image_path' + LOSS = 'loss' + PROB = 'prob' + PRED_LABEL = 'pred_label' + TRUE_LABEL = 'true_label' + BAG_ATTN = 'bag_attn' diff --git a/InnerEye/ML/Histopathology/utils/viz_utils.py b/InnerEye/ML/Histopathology/utils/viz_utils.py new file mode 100644 index 000000000..379559e88 --- /dev/null +++ b/InnerEye/ML/Histopathology/utils/viz_utils.py @@ -0,0 +1,53 @@ +import math +import matplotlib.pyplot as plt + +from monai.data.image_reader import WSIReader +from torch.utils.data import DataLoader + +from InnerEye.ML.Histopathology.datasets.panda_dataset import PandaDataset, LoadPandaROId + + +def load_image_dict(sample: dict, level: int, margin: int) -> dict: + """ + Load image from metadata dictionary + param sample: dict describing image metadata. Example: + {'image_id': ['1ca999adbbc948e69783686e5b5414e4'], + 'image': ['/tmp/datasets/PANDA/train_images/1ca999adbbc948e69783686e5b5414e4.tiff'], + 'mask': ['/tmp/datasets/PANDA/train_label_masks/1ca999adbbc948e69783686e5b5414e4_mask.tiff'], + 'data_provider': ['karolinska'], + 'isup_grade': tensor([0]), + 'gleason_score': ['0+0']} + param level: level of resolution to be loaded + param margin: margin to be included + return: a dict containing the image data and metadata + """ + loader = LoadPandaROId(WSIReader(), level=level, margin=margin) + img = loader(sample) + return img + + +def plot_panda_data_sample(panda_dir: str, nsamples: int, ncols: int, level: int, margin: int, + title_key: str = 'data_provider') -> None: + """ + param panda_dir: path to the dataset, it's expected a file called "train.csv" exists at the path. + Look at the PandaDataset for more detail + param nsamples: number of random samples to be visualized + param ncols: number of columns in the figure grid. Nrows is automatically inferred + param level: level of resolution to be loaded + param margin: margin to be included + param title_key: key in image_dict used to label each subplot + """ + panda_dataset = PandaDataset(root_dir=panda_dir, n_slides=nsamples) + loader = DataLoader(panda_dataset, batch_size=1) + + nrows = math.ceil(nsamples/ncols) + fig, axes = plt.subplots(ncols=ncols, nrows=nrows, figsize=(9, 9)) + + for dict_images, ax in zip(loader, axes.flat): + slide_id = dict_images['image_id'] + title = dict_images[title_key] + print(f">>> Slide {slide_id}") + img = load_image_dict(dict_images, level=level, margin=margin) + ax.imshow(img['image'].transpose(1, 2, 0)) + ax.set_title(title) + fig.tight_layout() diff --git a/InnerEye/ML/baselines_util.py b/InnerEye/ML/baselines_util.py index 1e7cac9a4..8a0e3a724 100755 --- a/InnerEye/ML/baselines_util.py +++ b/InnerEye/ML/baselines_util.py @@ -191,6 +191,7 @@ def compare_files(expected: Path, actual: Path) -> str: If the files are not identical, an error message with details is return. This handles known text file formats, where it ignores differences in line breaks. All other files are treated as binary, and compared on a byte-by-byte basis. + :param expected: A file that contains the expected contents. The type of comparison (text or binary) is chosen based on the extension of this file. :param actual: A file that contains the actual contents. @@ -198,8 +199,9 @@ def compare_files(expected: Path, actual: Path) -> str: """ def print_lines(prefix: str, lines: List[str]) -> None: - count = 5 - logging.debug(f"{prefix} {len(lines)} lines, first {count} of those:") + num_lines = len(lines) + count = min(5, num_lines) + logging.debug(f"{prefix} {num_lines} lines, first {count} of those:") logging.debug(os.linesep.join(lines[:count])) if expected.suffix in TEXT_FILE_SUFFIXES: diff --git a/InnerEye/ML/model_training.py b/InnerEye/ML/model_training.py index 6115c4629..304a6f9da 100644 --- a/InnerEye/ML/model_training.py +++ b/InnerEye/ML/model_training.py @@ -105,7 +105,7 @@ def create_lightning_trainer(container: LightningContainer, plugins = [] logging.info(f"Using {num_gpus} GPUs per node with accelerator '{accelerator}'") tensorboard_logger = TensorBoardLogger(save_dir=str(container.logs_folder), name="Lightning", version="") - loggers = [tensorboard_logger, AzureMLLogger()] + loggers = [tensorboard_logger, AzureMLLogger(False)] storing_logger = StoringLogger() loggers.append(storing_logger) # Use 32bit precision when running on CPU. Otherwise, make it depend on use_mixed_precision flag. diff --git a/Tests/ML/histopathology/__init__.py b/Tests/ML/histopathology/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/Tests/ML/histopathology/datamodules/__init__.py b/Tests/ML/histopathology/datamodules/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/Tests/ML/histopathology/datamodules/test_datamodule_caching.py b/Tests/ML/histopathology/datamodules/test_datamodule_caching.py new file mode 100644 index 000000000..71575bd68 --- /dev/null +++ b/Tests/ML/histopathology/datamodules/test_datamodule_caching.py @@ -0,0 +1,166 @@ +import shutil +from pathlib import Path +from typing import Any, Tuple + +import numpy as np +import pandas as pd +import pytest +import torch +from torch.utils.data import DataLoader + +from InnerEye.ML.Histopathology.datamodules.base_module import CacheMode, TilesDataModule +from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset + + +def noop_transform(x: Any) -> Any: + return x + + +def _check_generator_consistency(dl: DataLoader) -> None: + dataloader_generator = dl.generator + bag_sampler_generator = dl.dataset.data.bag_sampler.generator # type: ignore + assert torch.equal(dataloader_generator.get_state(), + bag_sampler_generator.get_state()) + + +def compare_dataloaders(dl1: DataLoader, dl2: DataLoader) -> None: + for batch1, batch2 in zip(dl1, dl2): + _check_generator_consistency(dl1) + _check_generator_consistency(dl2) + assert batch1.keys() == batch2.keys() + for key in batch1: + assert len(batch1[key]) == len(batch2[key]) + for item1, item2 in zip(batch1[key], batch2[key]): + if isinstance(item1, torch.Tensor): + assert torch.allclose(item1, item2, equal_nan=True) + else: + assert item1 == item2 + + +class MockTilesDataset(TilesDataset): + TILE_X_COLUMN = TILE_Y_COLUMN = None + TRAIN_SPLIT_LABEL = 'train' + VAL_SPLIT_LABEL = 'val' + TEST_SPLIT_LABEL = 'test' + + +def generate_mock_dataset_df(n_slides: int, n_tiles: int, n_classes: int) -> pd.DataFrame: + slide_ids = np.random.randint(n_slides, size=n_tiles) + slide_labels = np.random.randint(n_classes, size=n_slides) + tile_labels = slide_labels[slide_ids] + split_labels = [MockTilesDataset.TRAIN_SPLIT_LABEL, + MockTilesDataset.VAL_SPLIT_LABEL, + MockTilesDataset.TEST_SPLIT_LABEL] + slide_splits = np.random.choice(split_labels, size=n_slides) + tile_splits = slide_splits[slide_ids] + + df = pd.DataFrame() + df[MockTilesDataset.TILE_ID_COLUMN] = np.arange(n_tiles) + df[MockTilesDataset.SLIDE_ID_COLUMN] = slide_ids + df[MockTilesDataset.LABEL_COLUMN] = tile_labels + df[MockTilesDataset.SPLIT_COLUMN] = tile_splits + df[MockTilesDataset.IMAGE_COLUMN] = [f"{tile_splits[i]}/{i:06d}.png" for i in range(n_tiles)] + + return df + + +class MockTilesDataModule(TilesDataModule): + def get_splits(self) -> Tuple[MockTilesDataset, MockTilesDataset, MockTilesDataset]: + df = MockTilesDataset(self.root_path).dataset_df + df = df.reset_index() + split_dfs = (df[df[MockTilesDataset.SPLIT_COLUMN] == MockTilesDataset.TRAIN_SPLIT_LABEL], + df[df[MockTilesDataset.SPLIT_COLUMN] == MockTilesDataset.VAL_SPLIT_LABEL], + df[df[MockTilesDataset.SPLIT_COLUMN] == MockTilesDataset.TEST_SPLIT_LABEL]) + return tuple(MockTilesDataset(self.root_path, dataset_df=split_df) # type: ignore + for split_df in split_dfs) + + +@pytest.fixture +def mock_data_dir(tmp_path: Path) -> Path: + csv_dir = tmp_path / "mock_tiles_dataset" + csv_dir.mkdir(exist_ok=True) + csv_path = csv_dir / MockTilesDataset.DEFAULT_CSV_FILENAME + if not csv_path.exists(): + csv_path.parent.mkdir(parents=True, exist_ok=True) + df = generate_mock_dataset_df(n_slides=8, n_tiles=100, n_classes=2) + df.to_csv(csv_path, index=False) + return csv_dir + +def _get_datamodule(cache_mode: CacheMode, save_precache: bool, + cache_dir_provided: bool, data_dir: Path) -> TilesDataModule: + if (cache_mode is CacheMode.NONE and save_precache) \ + or (cache_mode is CacheMode.DISK and not cache_dir_provided) \ + or (save_precache and not cache_dir_provided): + pytest.skip("Unsupported combination of caching arguments") + + cache_dir = data_dir / f"datamodule_cache_{cache_mode.value}" if cache_dir_provided else None + + if cache_dir is not None and cache_dir.exists(): + shutil.rmtree(cache_dir) + + return MockTilesDataModule(root_path=data_dir, + transform=noop_transform, + seed=0, + batch_size=2, + cache_mode=cache_mode, + save_precache=save_precache, + cache_dir=cache_dir) + + +@pytest.mark.parametrize('cache_mode', [CacheMode.MEMORY, CacheMode.DISK, CacheMode.NONE]) +@pytest.mark.parametrize('save_precache', [True, False]) +@pytest.mark.parametrize('cache_dir_provided', [True, False]) +def test_caching_consistency(mock_data_dir: Path, cache_mode: CacheMode, save_precache: bool, + cache_dir_provided: bool) -> None: + # Compare two dataloaders from the same datamodule + datamodule = _get_datamodule(cache_mode=cache_mode, + save_precache=save_precache, + cache_dir_provided=cache_dir_provided, + data_dir=mock_data_dir) + datamodule.prepare_data() + train_dataloader = datamodule.train_dataloader() + train_dataloader2 = datamodule.train_dataloader() + + compare_dataloaders(train_dataloader, train_dataloader2) + + # Compare datamodules reusing the same cache + datamodule = _get_datamodule(cache_mode=cache_mode, + save_precache=save_precache, + cache_dir_provided=cache_dir_provided, + data_dir=mock_data_dir) + datamodule.prepare_data() + train_dataloader = datamodule.train_dataloader() + + reloaded_datamodule = _get_datamodule(cache_mode=cache_mode, + save_precache=save_precache, + cache_dir_provided=cache_dir_provided, + data_dir=mock_data_dir) + reloaded_datamodule.prepare_data() + reloaded_train_dataloader = reloaded_datamodule.train_dataloader() + + compare_dataloaders(train_dataloader, reloaded_train_dataloader) + + +@pytest.mark.parametrize('cache_mode', [CacheMode.MEMORY, CacheMode.DISK, CacheMode.NONE]) +@pytest.mark.parametrize('save_precache', [True, False]) +@pytest.mark.parametrize('cache_dir_provided', [True, False]) +def test_tile_id_coverage(mock_data_dir: Path, cache_mode: CacheMode, save_precache: bool, + cache_dir_provided: bool) -> None: + datamodule = _get_datamodule(cache_mode=cache_mode, + save_precache=save_precache, + cache_dir_provided=cache_dir_provided, + data_dir=mock_data_dir) + datamodule.prepare_data() + train_dataset = datamodule.train_dataset + train_dataloader = datamodule.train_dataloader() + expected_tile_ids = set(train_dataset.dataset_df.index) + loaded_tile_ids = set() # type: ignore + for batch in train_dataloader: + for stacked_bag_tile_ids in batch[train_dataset.TILE_ID_COLUMN]: + if isinstance(stacked_bag_tile_ids, torch.Tensor): + stacked_bag_tile_ids = stacked_bag_tile_ids.tolist() + bag_tile_ids = set(stacked_bag_tile_ids) + assert bag_tile_ids.isdisjoint(loaded_tile_ids), \ + f"Tile IDs already seen: {bag_tile_ids}" + loaded_tile_ids.update(bag_tile_ids) + assert loaded_tile_ids == expected_tile_ids diff --git a/Tests/ML/histopathology/models/__init__.py b/Tests/ML/histopathology/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/Tests/ML/histopathology/models/test_encoders.py b/Tests/ML/histopathology/models/test_encoders.py new file mode 100644 index 000000000..7c102e7d9 --- /dev/null +++ b/Tests/ML/histopathology/models/test_encoders.py @@ -0,0 +1,40 @@ +from typing import Callable + +import pytest +from torch import Tensor, float32, nn, rand +from torchvision.models import resnet18 + +from InnerEye.ML.Histopathology.models.encoders import (TileEncoder, HistoSSLEncoder, ImageNetEncoder, + ImageNetSimCLREncoder) + + +def get_supervised_imagenet_encoder() -> TileEncoder: + return ImageNetEncoder(feature_extraction_model=resnet18, tile_size=224) + + +def get_simclr_imagenet_encoder() -> TileEncoder: + return ImageNetSimCLREncoder(tile_size=224) + + +def get_histo_ssl_encoder() -> TileEncoder: + return HistoSSLEncoder(tile_size=224) + + +@pytest.mark.parametrize("create_encoder_fn", [get_supervised_imagenet_encoder, + get_simclr_imagenet_encoder, + get_histo_ssl_encoder]) +def test_encoder(create_encoder_fn: Callable[[], TileEncoder]) -> None: + batch_size = 10 + + encoder = create_encoder_fn() + + if isinstance(encoder, nn.Module): + for param_name, param in encoder.named_parameters(): + assert not param.requires_grad, \ + f"Feature extractor has unfrozen parameters: {param_name}" + + images = rand(batch_size, *encoder.input_dim, dtype=float32) + + features = encoder(images) + assert isinstance(features, Tensor) + assert features.shape == (batch_size, encoder.num_encoding) diff --git a/Tests/ML/histopathology/preprocessing/__init__.py b/Tests/ML/histopathology/preprocessing/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/Tests/ML/histopathology/preprocessing/test_tiling.py b/Tests/ML/histopathology/preprocessing/test_tiling.py new file mode 100644 index 000000000..891ac0029 --- /dev/null +++ b/Tests/ML/histopathology/preprocessing/test_tiling.py @@ -0,0 +1,121 @@ +import numpy as np +import pytest + +from InnerEye.ML.Histopathology.preprocessing.tiling import assemble_tiles_2d, get_1d_padding, \ + pad_for_tiling_2d, tile_array_2d + + +@pytest.mark.parametrize("length,tile_size", + [(8, 4), (9, 4), (8, 3), (4, 4), (3, 4)]) +def test_1d_padding(length: int, tile_size: int) -> None: + pad_pre, pad_post = get_1d_padding(length, tile_size) + + assert pad_pre >= 0 and pad_post >= 0 + assert pad_pre < tile_size and pad_post < tile_size + assert abs(pad_post - pad_pre) <= 1, "Asymmetric padding" + + padded_length = pad_pre + length + pad_post + assert padded_length % tile_size == 0 + + n_tiles = padded_length // tile_size + expected_n_tiles = int(np.ceil(length / tile_size)) + assert n_tiles == expected_n_tiles + + +@pytest.mark.parametrize("width,height", [(8, 6)]) +@pytest.mark.parametrize("tile_size", [3, 4, 5]) +@pytest.mark.parametrize("channels_first", [True, False]) +def test_2d_padding(width: int, height: int, tile_size: int, channels_first: bool) -> None: + channels = 2 + pad_value = 0 + array = np.random.rand(channels, height, width) + + input_array = array if channels_first else array.transpose(1, 2, 0) + padded_array, (offset_w, offset_h) = pad_for_tiling_2d(input_array, tile_size, channels_first, + constant_values=pad_value) + if not channels_first: + padded_array = padded_array.transpose(2, 0, 1) + + padded_channels, padded_height, padded_width = padded_array.shape + assert padded_channels == channels and padded_height >= height and padded_width >= width + assert padded_height % tile_size == 0 and padded_width % tile_size == 0 + assert 0 <= offset_h < tile_size and 0 <= offset_w < tile_size + + crop = padded_array[:, offset_h:offset_h + height, offset_w:offset_w + width] + assert np.array_equal(crop, array) + + # np.array_equiv() broadcasts the shapes + assert np.array_equiv(padded_array[:, :offset_h, :], pad_value) + assert np.array_equiv(padded_array[:, :, :offset_w], pad_value) + assert np.array_equiv(padded_array[:, offset_h + height:, :], pad_value) + assert np.array_equiv(padded_array[:, :, offset_w + width:], pad_value) + + +def _get_2d_meshgrid(width: int, height: int, channels_first: bool = True) -> np.ndarray: + array = np.stack(np.meshgrid(np.arange(width), np.arange(height)), + axis=0 if channels_first else -1) + assert array.shape == ((2, height, width) if channels_first else (height, width, 2)) + return array + + +@pytest.mark.parametrize("width,height", [(8, 6)]) +@pytest.mark.parametrize("tile_size", [3, 4, 5]) +@pytest.mark.parametrize("channels_first", [True, False]) +def test_tile_array_2d_both(width: int, height: int, tile_size: int, channels_first: bool) -> None: + channels = 2 + array = _get_2d_meshgrid(width, height, channels_first) + + padded_array, (offset_w, offset_h) = pad_for_tiling_2d(array, tile_size, channels_first, + constant_values=0) + + tiles, coords = tile_array_2d(array, tile_size, channels_first) + assert tiles.shape[0] == coords.shape[0] + + expected_n_tiles_w = int(np.ceil(width / tile_size)) + expected_n_tiles_h = int(np.ceil(height / tile_size)) + expected_n_tiles = expected_n_tiles_w * expected_n_tiles_h + + if channels_first: + assert tiles.shape == (expected_n_tiles, channels, tile_size, tile_size) + else: + assert tiles.shape == (expected_n_tiles, tile_size, tile_size, channels) + assert coords.shape == (expected_n_tiles, 2) + + for idx in range(tiles.shape[0]): + row = coords[idx, 1] + offset_h + col = coords[idx, 0] + offset_w + if channels_first: + expected_tile = padded_array[:, row:row + tile_size, col:col + tile_size] + else: + expected_tile = padded_array[row:row + tile_size, col:col + tile_size, :] + assert np.array_equal(tiles[idx], expected_tile) + + expected_x = tile_size * (idx % expected_n_tiles_w) - offset_w + expected_y = tile_size * (idx // expected_n_tiles_w) - offset_h + assert tuple(coords[idx]) == (expected_x, expected_y) + + +@pytest.mark.parametrize("width,height", [(8, 6)]) +@pytest.mark.parametrize("tile_size", [3, 4, 5]) +@pytest.mark.parametrize("channels_first", [True, False]) +def test_assemble_tiles_2d(width: int, height: int, tile_size: int, channels_first: bool) -> None: + array = _get_2d_meshgrid(width, height, channels_first) + fill_value = 0 + padded_array, padding_offset = pad_for_tiling_2d(array, tile_size, channels_first, + constant_values=fill_value) + + tiles, coords = tile_array_2d(array, tile_size, channels_first) + + assembled_array, assembly_offset = assemble_tiles_2d(tiles, coords, fill_value=fill_value, + channels_first=channels_first) + assert np.array_equal(assembled_array, padded_array) + assert np.array_equal(assembly_offset, padding_offset) + + for idx in range(tiles.shape[0]): + row = coords[idx, 1] + assembly_offset[1] + col = coords[idx, 0] + assembly_offset[0] + if channels_first: + crop = assembled_array[:, row:row + tile_size, col:col + tile_size] + else: + crop = assembled_array[row:row + tile_size, col:col + tile_size, :] + assert np.array_equal(crop, tiles[idx]) diff --git a/Tests/ML/histopathology/utils/__init__.py b/Tests/ML/histopathology/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/Tests/ML/histopathology/utils/test_metrics_utils.py b/Tests/ML/histopathology/utils/test_metrics_utils.py new file mode 100644 index 000000000..aacaf2ab5 --- /dev/null +++ b/Tests/ML/histopathology/utils/test_metrics_utils.py @@ -0,0 +1,58 @@ + +import math +from typing import List + +import matplotlib +from torch.functional import Tensor + +from InnerEye.ML.Histopathology.utils.metrics_utils import plot_scores_hist, select_k_tiles +from InnerEye.ML.Histopathology.utils.naming import ResultsKey + + +def assert_equal_lists(pred: List, expected: List) -> None: + assert len(pred) == len(expected) + for i, slide in enumerate(pred): + for j, value in enumerate(slide): + if type(value) in [int, float]: + assert math.isclose(value, expected[i][j], rel_tol=1e-06) + elif isinstance(value, List): + for k, idx in enumerate(value): + if type(idx) in [int, float]: + assert math.isclose(idx, expected[i][j][k], rel_tol=1e-06) + elif type(idx) == Tensor: + assert math.isclose(idx.item(), expected[i][j][k].item(), rel_tol=1e-06) + else: + raise TypeError("Unexpected list composition") + + +test_dict = {ResultsKey.SLIDE_ID: [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]], + ResultsKey.IMAGE_PATH: [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]], + ResultsKey.PROB: [Tensor([0.5]), Tensor([0.7]), Tensor([0.4]), Tensor([1.0])], + ResultsKey.TRUE_LABEL: [0, 1, 1, 1], + ResultsKey.BAG_ATTN: + [Tensor([[0.1, 0.0, 0.2, 0.15]]), + Tensor([[0.10, 0.18, 0.15, 0.13]]), + Tensor([[0.25, 0.23, 0.20, 0.21]]), + Tensor([[0.33, 0.31, 0.37, 0.35]])] + } + +def test_select_k_tiles() -> None: + top_tn = select_k_tiles(test_dict, n_slides=1, label=0, n_tiles=2, select=('lowest_pred', 'highest_att')) + assert_equal_lists(top_tn, [(1, 0.5, [3, 4], [Tensor([0.2]), Tensor([0.15])])]) + + nslides = 2 + ntiles = 2 + top_fn = select_k_tiles(test_dict, n_slides=nslides, label=1, n_tiles=ntiles, select=('lowest_pred', 'highest_att')) + bottom_fn = select_k_tiles(test_dict, n_slides=nslides, label=1, n_tiles=ntiles, select=('lowest_pred', 'lowest_att')) + assert_equal_lists(top_fn, [(3, 0.4, [1, 2], [Tensor([0.25]), Tensor([0.23])]), (2, 0.7, [2, 3], [Tensor([0.18]), Tensor([0.15])])]) + assert_equal_lists(bottom_fn, [(3, 0.4, [3, 4], [Tensor([0.20]), Tensor([0.21])]), (2, 0.7, [1, 4], [Tensor([0.10]), Tensor([0.13])])]) + + top_tp = select_k_tiles(test_dict, n_slides=nslides, label=1, n_tiles=ntiles, select=('highest_pred', 'highest_att')) + bottom_tp = select_k_tiles(test_dict, n_slides=nslides, label=1, n_tiles=ntiles, select=('highest_pred', 'lowest_att')) + assert_equal_lists(top_tp, [(4, 1.0, [3, 4], [Tensor([0.37]), Tensor([0.35])]), (2, 0.7, [2, 3], [Tensor([0.18]), Tensor([0.15])])]) + assert_equal_lists(bottom_tp, [(4, 1.0, [2, 1], [Tensor([0.31]), Tensor([0.33])]), (2, 0.7, [1, 4], [Tensor([0.10]), Tensor([0.13])])]) + + +def test_plot_scores_hist() -> None: + fig = plot_scores_hist(test_dict) + assert isinstance(fig, matplotlib.figure.Figure) diff --git a/environment.yml b/environment.yml index f37d86c08..19fbbe3db 100644 --- a/environment.yml +++ b/environment.yml @@ -14,6 +14,7 @@ dependencies: - git+https://github.com/analysiscenter/radio.git@6d53e25#egg=radio - azure-mgmt-resource==12.1.0 - azure-mgmt-datafactory==1.1.0 + - azure-storage-blob==12.6.0 - azureml-mlflow==1.36.0 - azureml-sdk==1.36.0 - azureml-tensorboard==1.36.0 @@ -34,6 +35,7 @@ dependencies: - lightning-bolts==0.3.4 - matplotlib==3.3.0 - mlflow==1.17.0 + - monai==0.6.0 - mypy==0.910 - mypy-extensions==0.4.3 - numba==0.51.2 @@ -68,4 +70,5 @@ dependencies: - tensorboardX==2.1 - torchprof==1.3.3 - torchmetrics==0.4.1 + - umap-learn==0.5.2 - yacs==0.1.8 diff --git a/hi-ml b/hi-ml new file mode 160000 index 000000000..334186321 --- /dev/null +++ b/hi-ml @@ -0,0 +1 @@ +Subproject commit 334186321f6989033f5609880781ba4c299f6f67