diff --git a/CHANGELOG.md b/CHANGELOG.md index de0ae157a..4bedef1db 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -48,6 +48,7 @@ jobs that run in AzureML. - ([#647](https://github.com/microsoft/InnerEye-DeepLearning/pull/647)) Add class-wise accuracy logging and confusion matrix to DeepMIL - ([#653](https://github.com/microsoft/InnerEye-DeepLearning/pull/653)) Add dropout to DeepMIL and fix feature extractor setup. - ([#650](https://github.com/microsoft/InnerEye-DeepLearning/pull/650)) Enable fine-tuning in DeepMIL using PANDA as the classification task. +- ([#656](https://github.com/microsoft/InnerEye-DeepLearning/pull/656)) Add subsampling transform and support for MIL mean pooling. ### Changed - ([#659](https://github.com/microsoft/InnerEye-DeepLearning/pull/659)) Update cudatoolkit version from 11.1 to 11.3. diff --git a/InnerEye/ML/Histopathology/models/deepmil.py b/InnerEye/ML/Histopathology/models/deepmil.py index afd88c17f..fc30c2ec3 100644 --- a/InnerEye/ML/Histopathology/models/deepmil.py +++ b/InnerEye/ML/Histopathology/models/deepmil.py @@ -204,7 +204,7 @@ def forward(self, instances: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore with set_grad_enabled(self.is_finetune): instance_features = self.encoder(instances) # N X L x 1 x 1 attentions, bag_features = self.aggregation_fn(instance_features) # K x N | K x L - bag_features = bag_features.view(-1, self.num_encoding * self.pool_out_dim) + bag_features = bag_features.view(1, -1) bag_logit = self.classifier_fn(bag_features) return bag_logit, attentions diff --git a/InnerEye/ML/Histopathology/models/transforms.py b/InnerEye/ML/Histopathology/models/transforms.py index 1d61fdcb3..6e9b5aa79 100644 --- a/InnerEye/ML/Histopathology/models/transforms.py +++ b/InnerEye/ML/Histopathology/models/transforms.py @@ -10,7 +10,7 @@ import numpy as np import PIL from monai.config.type_definitions import KeysCollection -from monai.transforms.transform import MapTransform +from monai.transforms.transform import MapTransform, Randomizable from torchvision.transforms.functional import to_tensor from InnerEye.ML.Histopathology.models.encoders import TileEncoder @@ -92,7 +92,7 @@ def __init__(self, allow_missing_keys: bool = False, chunk_size: int = 0) -> None: """ - :param keys: Key(s) for the image path(s) in the input dictionary. + :param keys: Key(s) for the image tensor(s) in the input dictionary. :param encoder: The tile encoder to use for feature extraction. :param allow_missing_keys: If `False` (default), raises an exception when an input dictionary is missing any of the specified keys. @@ -128,3 +128,42 @@ def __call__(self, data: Mapping) -> Mapping: for key in self.key_iterator(out_data): out_data[key] = self._encode_tiles(data[key]) return out_data + + +def take_indices(data: Sequence, indices: np.ndarray) -> Sequence: + if isinstance(data, (np.ndarray, torch.Tensor)): + return data[indices] + elif isinstance(data, Sequence): + return [data[i] for i in indices] + else: + raise ValueError(f"Data of type {type(data)} is not indexable") + + +class Subsampled(MapTransform, Randomizable): + """Dictionary transform to randomly subsample the data down to a fixed maximum length""" + + def __init__(self, keys: KeysCollection, max_size: int, + allow_missing_keys: bool = False) -> None: + """ + :param keys: Key(s) for all batch elements that must be subsampled. + :param max_size: Each specified array, tensor, or sequence will be subsampled uniformly at + random down to `max_size` along their first dimension. If shorter, the elements are merely + shuffled. + :param allow_missing_keys: If `False` (default), raises an exception when an input + dictionary is missing any of the specified keys. + """ + super().__init__(keys, allow_missing_keys=allow_missing_keys) + self.max_size = max_size + self._indices: np.ndarray + + def randomize(self, total_size: int) -> None: + subsample_size = min(self.max_size, total_size) + self._indices = self.R.choice(total_size, size=subsample_size) + + def __call__(self, data: Mapping) -> Mapping: + out_data = dict(data) # create shallow copy + size = len(data[self.keys[0]]) + self.randomize(size) + for key in self.key_iterator(out_data): + out_data[key] = take_indices(data[key], self._indices) + return out_data diff --git a/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py b/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py index 4406a3de2..9b2b3df37 100644 --- a/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py +++ b/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py @@ -14,7 +14,7 @@ from torch import nn from torchvision.models.resnet import resnet18 -from health_ml.networks.layers.attention_layers import AttentionLayer, GatedAttentionLayer +from health_ml.networks.layers.attention_layers import AttentionLayer, GatedAttentionLayer, MeanPoolingLayer from InnerEye.ML.lightning_container import LightningContainer from InnerEye.ML.Histopathology.datasets.base_dataset import SlidesDataset from InnerEye.ML.Histopathology.datamodules.base_module import CacheMode, CacheLocation, TilesDataModule @@ -90,6 +90,8 @@ def get_pooling_layer(self) -> Type[nn.Module]: return AttentionLayer elif self.pooling_type == GatedAttentionLayer.__name__: return GatedAttentionLayer + elif self.pooling_type == MeanPoolingLayer.__name__: + return MeanPoolingLayer else: raise ValueError(f"Unsupported pooling type: {self.pooling_type}") diff --git a/Tests/ML/histopathology/models/test_deepmil.py b/Tests/ML/histopathology/models/test_deepmil.py index 002ca0c72..e71e1e3d6 100644 --- a/Tests/ML/histopathology/models/test_deepmil.py +++ b/Tests/ML/histopathology/models/test_deepmil.py @@ -14,6 +14,7 @@ from health_ml.networks.layers.attention_layers import ( AttentionLayer, GatedAttentionLayer, + MeanPoolingLayer, ) from InnerEye.ML.lightning_container import LightningContainer @@ -37,14 +38,7 @@ def get_supervised_imagenet_encoder() -> TileEncoder: return ImageNetEncoder(feature_extraction_model=resnet18, tile_size=224) -@pytest.mark.parametrize("n_classes", [1, 3]) -@pytest.mark.parametrize("pooling_layer", [AttentionLayer, GatedAttentionLayer]) -@pytest.mark.parametrize("batch_size", [1, 15]) -@pytest.mark.parametrize("max_bag_size", [1, 7]) -@pytest.mark.parametrize("pool_hidden_dim", [1, 5]) -@pytest.mark.parametrize("pool_out_dim", [1, 6]) -@pytest.mark.parametrize("dropout_rate", [None, 0.5]) -def test_lightningmodule( +def _test_lightningmodule( n_classes: int, pooling_layer: Callable[[int, int, int], nn.Module], batch_size: int, @@ -119,6 +113,50 @@ def test_lightningmodule( assert torch.all(score <= 1) +@pytest.mark.parametrize("n_classes", [1, 3]) +@pytest.mark.parametrize("pooling_layer", [AttentionLayer, GatedAttentionLayer]) +@pytest.mark.parametrize("batch_size", [1, 15]) +@pytest.mark.parametrize("max_bag_size", [1, 7]) +@pytest.mark.parametrize("pool_hidden_dim", [1, 5]) +@pytest.mark.parametrize("pool_out_dim", [1, 6]) +@pytest.mark.parametrize("dropout_rate", [None, 0.5]) +def test_lightningmodule_attention( + n_classes: int, + pooling_layer: Callable[[int, int, int], nn.Module], + batch_size: int, + max_bag_size: int, + pool_hidden_dim: int, + pool_out_dim: int, + dropout_rate: Optional[float], +) -> None: + _test_lightningmodule(n_classes=n_classes, + pooling_layer=pooling_layer, + batch_size=batch_size, + max_bag_size=max_bag_size, + pool_hidden_dim=pool_hidden_dim, + pool_out_dim=pool_out_dim, + dropout_rate=dropout_rate) + + +@pytest.mark.parametrize("n_classes", [1, 3]) +@pytest.mark.parametrize("batch_size", [1, 15]) +@pytest.mark.parametrize("max_bag_size", [1, 7]) +@pytest.mark.parametrize("dropout_rate", [None, 0.5]) +def test_lightningmodule_mean_pooling( + n_classes: int, + batch_size: int, + max_bag_size: int, + dropout_rate: Optional[float], +) -> None: + _test_lightningmodule(n_classes=n_classes, + pooling_layer=MeanPoolingLayer, + batch_size=batch_size, + max_bag_size=max_bag_size, + pool_hidden_dim=1, + pool_out_dim=1, + dropout_rate=dropout_rate) + + def move_batch_to_expected_device(batch: Dict[str, List], use_gpu: bool) -> Dict: device = "cuda" if use_gpu else "cpu" return { diff --git a/Tests/ML/histopathology/models/test_transforms.py b/Tests/ML/histopathology/models/test_transforms.py index 17aa8c60b..6852aaaf0 100644 --- a/Tests/ML/histopathology/models/test_transforms.py +++ b/Tests/ML/histopathology/models/test_transforms.py @@ -7,6 +7,7 @@ from pathlib import Path from typing import Callable, Sequence, Union +import numpy as np import pytest import torch from monai.data.dataset import CacheDataset, Dataset, PersistentDataset @@ -19,7 +20,7 @@ from InnerEye.ML.Histopathology.datasets.default_paths import TCGA_CRCK_DATASET_DIR from InnerEye.ML.Histopathology.datasets.tcga_crck_tiles_dataset import TcgaCrck_TilesDataset from InnerEye.ML.Histopathology.models.encoders import ImageNetEncoder -from InnerEye.ML.Histopathology.models.transforms import EncodeTilesBatchd, LoadTiled, LoadTilesBatchd +from InnerEye.ML.Histopathology.models.transforms import EncodeTilesBatchd, LoadTiled, LoadTilesBatchd, Subsampled from Tests.ML.util import assert_dicts_equal @@ -153,3 +154,57 @@ def test_encode_tiles(tmp_path: Path, use_gpu: bool, chunk_size: int) -> None: bagged_subset, transform=transform, cache_subdir="TCGA-CRCk_embed_cache") + + +@pytest.mark.parametrize('include_non_indexable', [True, False]) +@pytest.mark.parametrize('allow_missing_keys', [True, False]) +def test_subsample(include_non_indexable: bool, allow_missing_keys: bool) -> None: + batch_size = 5 + max_size = batch_size // 2 + data = { + 'array_1d': np.random.randn(batch_size), + 'array_2d': np.random.randn(batch_size, 4), + 'tensor_1d': torch.randn(batch_size), + 'tensor_2d': torch.randn(batch_size, 4), + 'list': torch.randn(batch_size).tolist(), + 'indices': list(range(batch_size)), + 'non-indexable': 42, + } + + keys_to_subsample = list(data.keys()) + if not include_non_indexable: + keys_to_subsample.remove('non-indexable') + keys_to_subsample.append('missing-key') + + subsampling = Subsampled(keys_to_subsample, max_size=max_size, + allow_missing_keys=allow_missing_keys) + + if include_non_indexable: + with pytest.raises(ValueError): + sub_data = subsampling(data) + return + elif not allow_missing_keys: + with pytest.raises(KeyError): + sub_data = subsampling(data) + return + else: + sub_data = subsampling(data) + + assert set(sub_data.keys()) == set(data.keys()) + + # Check lenghts before and after subsampling + for key in keys_to_subsample: + if key not in data: + continue # Skip missing keys + assert len(data[key]) == batch_size # type: ignore + assert len(sub_data[key]) == min(max_size, batch_size) # type: ignore + + # Check contents of subsampled elements + for key in ['tensor_1d', 'tensor_2d', 'array_1d', 'array_2d', 'list']: + for idx, elem in zip(sub_data['indices'], sub_data[key]): + assert np.array_equal(elem, data[key][idx]) # type: ignore + + # Check that subsampling is random, i.e. subsequent calls shouldn't give identical results + sub_data2 = subsampling(data) + for key in ['tensor_1d', 'tensor_2d', 'array_1d', 'array_2d', 'list']: + assert not np.array_equal(sub_data[key], sub_data2[key]) # type: ignore diff --git a/hi-ml b/hi-ml index a33c1ed07..2bc397b47 160000 --- a/hi-ml +++ b/hi-ml @@ -1 +1 @@ -Subproject commit a33c1ed07da8a42486dec9f939cd59eea4b2583e +Subproject commit 2bc397b4707b56fecca624ce81e6883e0170b24b