From 845910a6e595b23fcc0284953b960ea31a4be1a2 Mon Sep 17 00:00:00 2001
From: Daniel Coelho de Castro
Date: Tue, 1 Feb 2022 16:31:24 +0000
Subject: [PATCH 01/15] Add dropout to DeepMILModule

---
 InnerEye/ML/Histopathology/models/deepmil.py | 22 ++++++++++++++------
 1 file changed, 16 insertions(+), 6 deletions(-)

diff --git a/InnerEye/ML/Histopathology/models/deepmil.py b/InnerEye/ML/Histopathology/models/deepmil.py
index 14ca78c00..a30acd37d 100644
--- a/InnerEye/ML/Histopathology/models/deepmil.py
+++ b/InnerEye/ML/Histopathology/models/deepmil.py
@@ -46,6 +46,7 @@ def __init__(self,
                  pooling_layer: Callable[[int, int, int], nn.Module],
                  pool_hidden_dim: int = 128,
                  pool_out_dim: int = 1,
+                 dropout_rate: Optional[float] = None,
                  class_weights: Optional[Tensor] = None,
                  l_rate: float = 5e-4,
                  weight_decay: float = 1e-4,
@@ -63,6 +64,7 @@ def __init__(self,
             `torch.nn.Module` constructor accepting input, hidden, and output pooling `int` dimensions.
         :param pool_hidden_dim: Hidden dimension of pooling layer (default=128).
         :param pool_out_dim: Output dimension of pooling layer (default=1).
+        :param dropout_rate: Rate of pre-classifier dropout. `None` for no dropout (default).
         :param class_weights: Tensor containing class weights (default=None).
         :param l_rate: Optimiser learning rate.
         :param weight_decay: Weight decay parameter for L2 regularisation.
@@ -80,6 +82,7 @@ def __init__(self,
         self.pool_hidden_dim = pool_hidden_dim
         self.pool_out_dim = pool_out_dim
         self.pooling_layer = pooling_layer
+        self.dropout_rate = dropout_rate
         self.class_weights = class_weights
         self.encoder = encoder
         self.num_encoding = self.encoder.num_encoding
@@ -99,6 +102,7 @@ def __init__(self,
         self.verbose = verbose
 
         self.aggregation_fn, self.num_pooling = self.get_pooling()
+        self.dropout = self.get_dropout()
         self.classifier_fn = self.get_classifier()
         self.loss_fn = self.get_loss()
         self.activation_fn = self.get_activation()
@@ -115,6 +119,11 @@ def get_pooling(self) -> Tuple[Callable, int]:
         num_features = self.num_encoding*self.pool_out_dim
         return pooling_layer, num_features
 
+    def get_dropout(self) -> Callable:
+        if self.dropout_rate is None:
+            return nn.Identity()
+        return nn.Dropout(self.dropout_rate)
+
     def get_classifier(self) -> Callable:
         return nn.Linear(in_features=self.num_pooling,
                          out_features=self.n_classes)
@@ -164,13 +173,14 @@ def log_metrics(self,
         for metric_name, metric_object in self.get_metrics_dict(stage).items():
             self.log(f'{stage}/{metric_name}', metric_object, on_epoch=True, on_step=False, logger=True, sync_dist=True)
 
-    def forward(self, images: Tensor) -> Tuple[Tensor, Tensor]:  # type: ignore
+    def forward(self, instances: Tensor) -> Tuple[Tensor, Tensor]:  # type: ignore
         with no_grad():
-            H = self.encoder(images)                    # N X L x 1 x 1
-        A, M = self.aggregation_fn(H)                   # A: K x N | M: K x L
-        M = M.view(-1, self.num_encoding * self.pool_out_dim)
-        Y_prob = self.classifier_fn(M)
-        return Y_prob, A
+            instance_features = self.encoder(instances)                    # N X L x 1 x 1
+        attentions, bag_features = self.aggregation_fn(instance_features)  # K x N | K x L
+        bag_features = bag_features.view(-1, self.num_encoding * self.pool_out_dim)
+        bag_features_dropout = self.dropout(bag_features)
+        bag_logit = self.classifier_fn(bag_features_dropout)
+        return bag_logit, attentions
 
     def configure_optimizers(self) -> optim.Optimizer:
         return optim.Adam(self.parameters(), lr=self.l_rate, weight_decay=self.weight_decay,
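Usage sketch (illustration only, not part of the patch series): the new
`dropout_rate` argument slots in alongside the existing `DeepMILModule` keyword
arguments. The encoder choice and the `label_column` value below are assumptions
for the example, mirroring the tests added in patch 05:

    from torchvision.models import resnet18
    from health_ml.networks.layers.attention_layers import AttentionLayer
    from InnerEye.ML.Histopathology.models.deepmil import DeepMILModule
    from InnerEye.ML.Histopathology.models.encoders import ImageNetEncoder

    module = DeepMILModule(encoder=ImageNetEncoder(feature_extraction_model=resnet18, tile_size=224),
                           label_column="label",  # assumed column name
                           n_classes=1,
                           pooling_layer=AttentionLayer,
                           dropout_rate=0.5)      # None (the default) disables dropout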
From 34c2089f55a983690b3a71c3b8b200457dbffc99 Mon Sep 17 00:00:00 2001
From: Daniel Coelho de Castro
Date: Tue, 1 Feb 2022 16:36:23 +0000
Subject: [PATCH 02/15] Fix feature extractor setup for torchvision models

---
 .../ML/Histopathology/utils/layer_utils.py | 20 ++++++++++++++------
 1 file changed, 14 insertions(+), 6 deletions(-)

diff --git a/InnerEye/ML/Histopathology/utils/layer_utils.py b/InnerEye/ML/Histopathology/utils/layer_utils.py
index d3b88d3c0..50fe99fdc 100644
--- a/InnerEye/ML/Histopathology/utils/layer_utils.py
+++ b/InnerEye/ML/Histopathology/utils/layer_utils.py
@@ -5,7 +5,7 @@
 
 from typing import Callable, Tuple
 
-from torch import as_tensor, device, nn, prod, rand
+from torch import as_tensor, device, nn, no_grad, prod, rand
 from torch.hub import load_state_dict_from_url
 from torchvision.transforms import Normalize
 
@@ -16,14 +16,22 @@ def get_imagenet_preprocessing() -> nn.Module:
 
 def setup_feature_extractor(pretrained_model: nn.Module,
                             input_dim: Tuple[int, int, int]) -> Tuple[Callable, int]:
-    layers = list(pretrained_model.children())[:-1]
-    layers.append(nn.Flatten())  # flatten non-batch dims in case of spatial feature maps
-    feature_extractor = nn.Sequential(*layers)
+    try:
+        # Attempt to auto-detect final classification layer:
+        num_features: int = pretrained_model.fc.in_features  # type: ignore
+        pretrained_model.fc = nn.Flatten()
+        feature_extractor = pretrained_model
+    except AttributeError:
+        # Otherwise fallback to sequence of child modules:
+        layers = list(pretrained_model.children())[:-1]
+        layers.append(nn.Flatten())  # flatten non-batch dims in case of spatial feature maps
+        feature_extractor = nn.Sequential(*layers)
+        with no_grad():
+            feature_shape = feature_extractor(rand(1, *input_dim)).shape
+        num_features = int(prod(as_tensor(feature_shape)).item())
 
     # fix weights, no fine-tuning
     for param in feature_extractor.parameters():
         param.requires_grad = False
-    feature_shape = feature_extractor(rand(1, *input_dim)).shape
-    num_features = int(prod(as_tensor(feature_shape)).item())
 
     return feature_extractor, num_features
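Usage sketch (illustration only, not part of the patch series): for torchvision
ResNets the `fc` attribute exists, so the new code reads `fc.in_features` and
skips the forward-pass shape probe; models without `fc` take the fallback path.

    from torchvision.models import resnet18
    from InnerEye.ML.Histopathology.utils.layer_utils import setup_feature_extractor

    feature_extractor, num_features = setup_feature_extractor(resnet18(), input_dim=(3, 224, 224))
    assert num_features == 512  # resnet18's fc.in_features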
From 31961cb78e74e3e268c5b583c5fb9eab28ff3090 Mon Sep 17 00:00:00 2001
From: Daniel Coelho de Castro
Date: Fri, 4 Feb 2022 12:45:55 +0000
Subject: [PATCH 03/15] Refactor DeepMIL dropout into classifier

---
 InnerEye/ML/Histopathology/models/deepmil.py | 21 ++++++++++-----------
 1 file changed, 10 insertions(+), 11 deletions(-)

diff --git a/InnerEye/ML/Histopathology/models/deepmil.py b/InnerEye/ML/Histopathology/models/deepmil.py
index a30acd37d..ff903bc09 100644
--- a/InnerEye/ML/Histopathology/models/deepmil.py
+++ b/InnerEye/ML/Histopathology/models/deepmil.py
@@ -64,7 +64,7 @@ def __init__(self,
             `torch.nn.Module` constructor accepting input, hidden, and output pooling `int` dimensions.
         :param pool_hidden_dim: Hidden dimension of pooling layer (default=128).
         :param pool_out_dim: Output dimension of pooling layer (default=1).
-        :param dropout_rate: Rate of pre-classifier dropout. `None` for no dropout (default).
+        :param dropout_rate: Rate of pre-classifier dropout (0-1). `None` for no dropout (default).
         :param class_weights: Tensor containing class weights (default=None).
         :param l_rate: Optimiser learning rate.
         :param weight_decay: Weight decay parameter for L2 regularisation.
@@ -102,7 +102,6 @@ def __init__(self,
         self.verbose = verbose
 
         self.aggregation_fn, self.num_pooling = self.get_pooling()
-        self.dropout = self.get_dropout()
         self.classifier_fn = self.get_classifier()
         self.loss_fn = self.get_loss()
         self.activation_fn = self.get_activation()
@@ -119,14 +118,15 @@ def get_pooling(self) -> Tuple[Callable, int]:
         num_features = self.num_encoding*self.pool_out_dim
         return pooling_layer, num_features
 
-    def get_dropout(self) -> Callable:
-        if self.dropout_rate is None:
-            return nn.Identity()
-        return nn.Dropout(self.dropout_rate)
-
     def get_classifier(self) -> Callable:
-        return nn.Linear(in_features=self.num_pooling,
-                         out_features=self.n_classes)
+        classifier_layer = nn.Linear(in_features=self.num_pooling,
+                                     out_features=self.n_classes)
+        if self.dropout_rate is None:
+            return classifier_layer
+        elif 0 <= self.dropout_rate < 1:
+            return nn.Sequential(nn.Dropout(self.dropout_rate), classifier_layer)
+        else:
+            raise ValueError(f"Dropout rate should be in [0, 1), got {self.dropout_rate}")
 
     def get_loss(self) -> Callable:
         if self.n_classes > 1:
@@ -178,8 +178,7 @@ def forward(self, instances: Tensor) -> Tuple[Tensor, Tensor]:  # type: ignore
             instance_features = self.encoder(instances)                    # N X L x 1 x 1
         attentions, bag_features = self.aggregation_fn(instance_features)  # K x N | K x L
         bag_features = bag_features.view(-1, self.num_encoding * self.pool_out_dim)
-        bag_features_dropout = self.dropout(bag_features)
-        bag_logit = self.classifier_fn(bag_features_dropout)
+        bag_logit = self.classifier_fn(bag_features)
         return bag_logit, attentions
 
     def configure_optimizers(self) -> optim.Optimizer:
From 8de20ec262c82f852c009b4dde41febef8304242 Mon Sep 17 00:00:00 2001
From: Daniel Coelho de Castro
Date: Fri, 4 Feb 2022 14:26:39 +0000
Subject: [PATCH 04/15] Refactor and add tests for feature extractor setup

---
 .../ML/Histopathology/utils/layer_utils.py     |  4 +-
 .../ML/histopathology/models/test_encoders.py | 53 ++++++++++++++-----
 2 files changed, 41 insertions(+), 16 deletions(-)

diff --git a/InnerEye/ML/Histopathology/utils/layer_utils.py b/InnerEye/ML/Histopathology/utils/layer_utils.py
index 50fe99fdc..842193563 100644
--- a/InnerEye/ML/Histopathology/utils/layer_utils.py
+++ b/InnerEye/ML/Histopathology/utils/layer_utils.py
@@ -3,7 +3,7 @@
 # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
 # ------------------------------------------------------------------------------------------
 
-from typing import Callable, Tuple
+from typing import Tuple
 
 from torch import as_tensor, device, nn, no_grad, prod, rand
 from torch.hub import load_state_dict_from_url
@@ -15,7 +15,7 @@ def get_imagenet_preprocessing() -> nn.Module:
 
 
 def setup_feature_extractor(pretrained_model: nn.Module,
-                            input_dim: Tuple[int, int, int]) -> Tuple[Callable, int]:
+                            input_dim: Tuple[int, int, int]) -> Tuple[nn.Module, int]:
     try:
         # Attempt to auto-detect final classification layer:
         num_features: int = pretrained_model.fc.in_features  # type: ignore
diff --git a/Tests/ML/histopathology/models/test_encoders.py b/Tests/ML/histopathology/models/test_encoders.py
index a9ad82864..29fefc31f 100644
--- a/Tests/ML/histopathology/models/test_encoders.py
+++ b/Tests/ML/histopathology/models/test_encoders.py
@@ -3,43 +3,68 @@
 # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
 # ------------------------------------------------------------------------------------------
 
-from typing import Callable
+from typing import Callable, Tuple
 
+import numpy as np
 import pytest
 from torch import Tensor, float32, nn, rand
 from torchvision.models import resnet18
 
 from InnerEye.ML.Histopathology.models.encoders import (TileEncoder, HistoSSLEncoder,
                                                         ImageNetEncoder, ImageNetSimCLREncoder)
+from InnerEye.ML.Histopathology.utils.layer_utils import setup_feature_extractor
+
+TILE_SIZE = 224
+INPUT_DIMS = (3, TILE_SIZE, TILE_SIZE)
 
 
 def get_supervised_imagenet_encoder() -> TileEncoder:
-    return ImageNetEncoder(feature_extraction_model=resnet18, tile_size=224)
+    return ImageNetEncoder(feature_extraction_model=resnet18, tile_size=TILE_SIZE)
 
 
 def get_simclr_imagenet_encoder() -> TileEncoder:
-    return ImageNetSimCLREncoder(tile_size=224)
+    return ImageNetSimCLREncoder(tile_size=TILE_SIZE)
 
 
 def get_histo_ssl_encoder() -> TileEncoder:
-    return HistoSSLEncoder(tile_size=224)
-
-
-@pytest.mark.parametrize("create_encoder_fn", [get_supervised_imagenet_encoder,
-                                               get_simclr_imagenet_encoder,
-                                               get_histo_ssl_encoder])
-def test_encoder(create_encoder_fn: Callable[[], TileEncoder]) -> None:
-    batch_size = 10
+    return HistoSSLEncoder(tile_size=TILE_SIZE)
 
-    encoder = create_encoder_fn()
 
+def _test_encoder(encoder: nn.Module, input_dims: Tuple[int, ...], output_dim: int,
+                  batch_size: int = 5) -> None:
     if isinstance(encoder, nn.Module):
         for param_name, param in encoder.named_parameters():
             assert not param.requires_grad, \
                 f"Feature extractor has unfrozen parameters: {param_name}"
 
-    images = rand(batch_size, *encoder.input_dim, dtype=float32)
+    images = rand(batch_size, *input_dims, dtype=float32)
 
     features = encoder(images)
     assert isinstance(features, Tensor)
-    assert features.shape == (batch_size, encoder.num_encoding)
+    assert features.shape == (batch_size, output_dim)
+
+
+@pytest.mark.parametrize("create_encoder_fn", [get_supervised_imagenet_encoder,
+                                               get_simclr_imagenet_encoder,
+                                               get_histo_ssl_encoder])
+def test_encoder(create_encoder_fn: Callable[[], TileEncoder]) -> None:
+    encoder = create_encoder_fn()
+    _test_encoder(encoder, input_dims=encoder.input_dim, output_dim=encoder.num_encoding)
+
+
+def _dummy_classifier() -> nn.Module:
+    input_size = np.prod(INPUT_DIMS)
+    hidden_dim = 10
+    return nn.Sequential(
+        nn.Flatten(),
+        nn.Linear(input_size, hidden_dim),
+        nn.Tanh(),
+        nn.Linear(hidden_dim, 1)
+    )
+
+
+@pytest.mark.parametrize('create_classifier_fn', [resnet18, _dummy_classifier])
+def test_setup_feature_extractor(create_classifier_fn: Callable[[], nn.Module]) -> None:
+    classifier = create_classifier_fn()
+    encoder, num_features = setup_feature_extractor(classifier, INPUT_DIMS)
+    _test_encoder(encoder, input_dims=INPUT_DIMS, output_dim=num_features)
From 2541eee750a778076d143a083ef286d19c85f816 Mon Sep 17 00:00:00 2001
From: Daniel Coelho de Castro
Date: Fri, 4 Feb 2022 14:26:52 +0000
Subject: [PATCH 05/15] Add dropout to DeepMIL tests

---
 Tests/ML/histopathology/models/test_deepmil.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/Tests/ML/histopathology/models/test_deepmil.py b/Tests/ML/histopathology/models/test_deepmil.py
index 797b3080c..b030915e6 100644
--- a/Tests/ML/histopathology/models/test_deepmil.py
+++ b/Tests/ML/histopathology/models/test_deepmil.py
@@ -4,7 +4,7 @@
 # ------------------------------------------------------------------------------------------
 
 import os
-from typing import Callable, Dict, List, Type  # noqa
+from typing import Callable, Dict, List, Optional, Type  # noqa
 
 import pytest
 from torch import Tensor, argmax, nn, rand, randint, randn, round, stack, allclose
@@ -42,6 +42,7 @@ def get_supervised_imagenet_encoder() -> TileEncoder:
 @pytest.mark.parametrize("max_bag_size", [1, 7])
 @pytest.mark.parametrize("pool_hidden_dim", [1, 5])
 @pytest.mark.parametrize("pool_out_dim", [1, 6])
+@pytest.mark.parametrize("dropout_rate", [None, 0.5])
 def test_lightningmodule(
         n_classes: int,
         pooling_layer: Callable[[int, int, int], nn.Module],
@@ -49,6 +50,7 @@ def test_lightningmodule(
         max_bag_size: int,
         pool_hidden_dim: int,
         pool_out_dim: int,
+        dropout_rate: Optional[float],
 ) -> None:
     assert n_classes > 0
 
@@ -62,6 +64,7 @@ def test_lightningmodule(
                            pooling_layer=pooling_layer,
                            pool_hidden_dim=pool_hidden_dim,
                            pool_out_dim=pool_out_dim,
+                           dropout_rate=dropout_rate,
                            )
 
     bag_images = rand([batch_size, max_bag_size, *module.encoder.input_dim])

From 4231a71f3969001b100e5958bffe27ca51e6f9a0 Mon Sep 17 00:00:00 2001
From: Daniel Coelho de Castro
Date: Fri, 4 Feb 2022 14:52:33 +0000
Subject: [PATCH 06/15] Add subsampling transform

---
 .../ML/Histopathology/models/transforms.py | 32 ++++++++++++++++++++++++++++++++-
 1 file changed, 31 insertions(+), 1 deletion(-)

diff --git a/InnerEye/ML/Histopathology/models/transforms.py b/InnerEye/ML/Histopathology/models/transforms.py
index 1d61fdcb3..ae21bad85 100644
--- a/InnerEye/ML/Histopathology/models/transforms.py
+++ b/InnerEye/ML/Histopathology/models/transforms.py
@@ -10,7 +10,7 @@
 import numpy as np
 import PIL
 from monai.config.type_definitions import KeysCollection
-from monai.transforms.transform import MapTransform
+from monai.transforms.transform import MapTransform, Randomizable
 from torchvision.transforms.functional import to_tensor
 
 from InnerEye.ML.Histopathology.models.encoders import TileEncoder
@@ -128,3 +128,33 @@ def __call__(self, data: Mapping) -> Mapping:
         for key in self.key_iterator(out_data):
             out_data[key] = self._encode_tiles(data[key])
         return out_data
+
+
+def take_indices(data: Sequence, indices: np.ndarray) -> Sequence:
+    if isinstance(data, (np.ndarray, torch.Tensor)):
+        return data[indices]
+    elif isinstance(data, Sequence):
+        return [data[i] for i in indices]
+    else:
+        raise ValueError(f"Data of type {type(data)} is not indexable")
+
+
+class Subsampled(MapTransform, Randomizable):
+    """Dictionary transform to randomly subsample the data down to a fixed maximum length"""
+
+    def __init__(self, keys: KeysCollection, max_size: int) -> None:
+        super().__init__(keys, allow_missing_keys=False)
+        self.max_size = max_size
+        self._indices: np.ndarray
+
+    def randomize(self, total_size: int) -> None:
+        subsample_size = min(self.max_size, total_size)
+        self._indices = self.R.choice(total_size, size=subsample_size, replace=False)
+
+    def __call__(self, data: Mapping) -> Mapping:
+        out_data = dict(data)  # create shallow copy
+        size = len(data[self.keys[0]])
+        self.randomize(size)
+        for key in self.key_iterator(out_data):
+            out_data[key] = take_indices(data[key], self._indices)
+        return out_data
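Usage sketch (illustration only, not part of the patch series): `Subsampled`
draws one index set per call and applies it to every listed key, so parallel
fields stay aligned. The key names and sizes are assumptions for the example:

    import numpy as np
    from InnerEye.ML.Histopathology.models.transforms import Subsampled

    bag = {'image': np.random.randn(500, 8), 'coords': np.random.randn(500, 2)}
    subsample = Subsampled(keys=['image', 'coords'], max_size=100)
    out = subsample(bag)  # both arrays keep the same 100 randomly chosen rows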
From 46ff8ce5f468f1de5f415e7413610f1905710100 Mon Sep 17 00:00:00 2001
From: Daniel Coelho de Castro
Date: Fri, 4 Feb 2022 15:08:43 +0000
Subject: [PATCH 07/15] Add option to allow_missing_keys for Subsampled

---
 InnerEye/ML/Histopathology/models/transforms.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/InnerEye/ML/Histopathology/models/transforms.py b/InnerEye/ML/Histopathology/models/transforms.py
index ae21bad85..9e9927f9b 100644
--- a/InnerEye/ML/Histopathology/models/transforms.py
+++ b/InnerEye/ML/Histopathology/models/transforms.py
@@ -142,8 +142,9 @@ def take_indices(data: Sequence, indices: np.ndarray) -> Sequence:
 class Subsampled(MapTransform, Randomizable):
     """Dictionary transform to randomly subsample the data down to a fixed maximum length"""
 
-    def __init__(self, keys: KeysCollection, max_size: int) -> None:
-        super().__init__(keys, allow_missing_keys=False)
+    def __init__(self, keys: KeysCollection, max_size: int,
+                 allow_missing_keys: bool = False) -> None:
+        super().__init__(keys, allow_missing_keys=allow_missing_keys)
         self.max_size = max_size
         self._indices: np.ndarray

From 51f932b4221f1fb4958d42edc97b97916cd654f0 Mon Sep 17 00:00:00 2001
From: Daniel Coelho de Castro
Date: Fri, 4 Feb 2022 15:50:59 +0000
Subject: [PATCH 08/15] Add dropout param to BaseMIL

---
 InnerEye/ML/configs/histo_configs/classification/BaseMIL.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py b/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py
index 4328de36b..5e5ee1650 100644
--- a/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py
+++ b/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py
@@ -8,7 +8,7 @@ their datamodules and configure experiment-specific parameters.
 """
 from pathlib import Path
-from typing import Type  # noqa
+from typing import Optional, Type  # noqa
 
 import param
 from torch import nn
 
@@ -27,6 +27,7 @@ class BaseMIL(LightningContainer):
     # Model parameters:
     pooling_type: str = param.String(doc="Name of the pooling layer class to use.")
+    dropout_rate: Optional[float] = param.Number(None, bounds=(0, 1), doc="Pre-classifier dropout rate.")
     # l_rate, weight_decay, adam_betas are already declared in OptimizerParams superclass
 
    # Encoder parameters:
@@ -98,6 +99,7 @@ def create_model(self) -> DeepMILModule:
                              label_column=self.data_module.train_dataset.LABEL_COLUMN,
                              n_classes=self.data_module.train_dataset.N_CLASSES,
                              pooling_layer=self.get_pooling_layer(),
+                             dropout_rate=self.dropout_rate,
                              class_weights=self.data_module.class_weights,
                              l_rate=self.l_rate,
                              weight_decay=self.weight_decay,
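Usage sketch (illustration only, not part of the patch series): a BaseMIL-derived
container can now set dropout declaratively; the subclass shown is hypothetical
and assumes the container accepts its params as constructor keywords:

    from health_ml.networks.layers.attention_layers import GatedAttentionLayer
    from InnerEye.ML.configs.histo_configs.classification.BaseMIL import BaseMIL

    class MyMILContainer(BaseMIL):  # hypothetical config class
        def __init__(self) -> None:
            super().__init__(pooling_type=GatedAttentionLayer.__name__,
                             dropout_rate=0.25)  # validated against bounds (0, 1)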
From 51f932b4221f1fb4958d42edc97b97916cd654f0 Mon Sep 17 00:00:00 2001
From: Daniel Coelho de Castro
Date: Fri, 4 Feb 2022 18:16:28 +0000
Subject: [PATCH 09/15] Add docstring and tests for Subsampled

---
 .../ML/Histopathology/models/transforms.py   | 10 +++-
 .../histopathology/models/test_transforms.py | 57 ++++++++++++++++++-
 2 files changed, 65 insertions(+), 2 deletions(-)

diff --git a/InnerEye/ML/Histopathology/models/transforms.py b/InnerEye/ML/Histopathology/models/transforms.py
index 9e9927f9b..6e9b5aa79 100644
--- a/InnerEye/ML/Histopathology/models/transforms.py
+++ b/InnerEye/ML/Histopathology/models/transforms.py
@@ -92,7 +92,7 @@ def __init__(self,
                  allow_missing_keys: bool = False,
                  chunk_size: int = 0) -> None:
         """
-        :param keys: Key(s) for the image path(s) in the input dictionary.
+        :param keys: Key(s) for the image tensor(s) in the input dictionary.
         :param encoder: The tile encoder to use for feature extraction.
         :param allow_missing_keys: If `False` (default), raises an exception when an input
             dictionary is missing any of the specified keys.
@@ -144,6 +144,14 @@ class Subsampled(MapTransform, Randomizable):
 
     def __init__(self, keys: KeysCollection, max_size: int,
                  allow_missing_keys: bool = False) -> None:
+        """
+        :param keys: Key(s) for all batch elements that must be subsampled.
+        :param max_size: Each specified array, tensor, or sequence will be subsampled uniformly at
+            random down to `max_size` along its first dimension. If shorter, the elements are merely
+            shuffled.
+        :param allow_missing_keys: If `False` (default), raises an exception when an input
+            dictionary is missing any of the specified keys.
+        """
         super().__init__(keys, allow_missing_keys=allow_missing_keys)
         self.max_size = max_size
         self._indices: np.ndarray
diff --git a/Tests/ML/histopathology/models/test_transforms.py b/Tests/ML/histopathology/models/test_transforms.py
index 17aa8c60b..6852aaaf0 100644
--- a/Tests/ML/histopathology/models/test_transforms.py
+++ b/Tests/ML/histopathology/models/test_transforms.py
@@ -7,6 +7,7 @@
 from pathlib import Path
 from typing import Callable, Sequence, Union
 
+import numpy as np
 import pytest
 import torch
 from monai.data.dataset import CacheDataset, Dataset, PersistentDataset
@@ -19,7 +20,7 @@
 from InnerEye.ML.Histopathology.datasets.default_paths import TCGA_CRCK_DATASET_DIR
 from InnerEye.ML.Histopathology.datasets.tcga_crck_tiles_dataset import TcgaCrck_TilesDataset
 from InnerEye.ML.Histopathology.models.encoders import ImageNetEncoder
-from InnerEye.ML.Histopathology.models.transforms import EncodeTilesBatchd, LoadTiled, LoadTilesBatchd
+from InnerEye.ML.Histopathology.models.transforms import EncodeTilesBatchd, LoadTiled, LoadTilesBatchd, Subsampled
 from Tests.ML.util import assert_dicts_equal
 
@@ -153,3 +154,57 @@ def test_encode_tiles(tmp_path: Path, use_gpu: bool, chunk_size: int) -> None:
                                       bagged_subset,
                                       transform=transform,
                                       cache_subdir="TCGA-CRCk_embed_cache")
+
+
+@pytest.mark.parametrize('include_non_indexable', [True, False])
+@pytest.mark.parametrize('allow_missing_keys', [True, False])
+def test_subsample(include_non_indexable: bool, allow_missing_keys: bool) -> None:
+    batch_size = 5
+    max_size = batch_size // 2
+    data = {
+        'array_1d': np.random.randn(batch_size),
+        'array_2d': np.random.randn(batch_size, 4),
+        'tensor_1d': torch.randn(batch_size),
+        'tensor_2d': torch.randn(batch_size, 4),
+        'list': torch.randn(batch_size).tolist(),
+        'indices': list(range(batch_size)),
+        'non-indexable': 42,
+    }
+
+    keys_to_subsample = list(data.keys())
+    if not include_non_indexable:
+        keys_to_subsample.remove('non-indexable')
+    keys_to_subsample.append('missing-key')
+
+    subsampling = Subsampled(keys_to_subsample, max_size=max_size,
+                             allow_missing_keys=allow_missing_keys)
+
+    if include_non_indexable:
+        with pytest.raises(ValueError):
+            sub_data = subsampling(data)
+        return
+    elif not allow_missing_keys:
+        with pytest.raises(KeyError):
+            sub_data = subsampling(data)
+        return
+    else:
+        sub_data = subsampling(data)
+
+    assert set(sub_data.keys()) == set(data.keys())
+
+    # Check lengths before and after subsampling
+    for key in keys_to_subsample:
+        if key not in data:
+            continue  # Skip missing keys
+        assert len(data[key]) == batch_size  # type: ignore
+        assert len(sub_data[key]) == min(max_size, batch_size)  # type: ignore
+
+    # Check contents of subsampled elements
+    for key in ['tensor_1d', 'tensor_2d', 'array_1d', 'array_2d', 'list']:
+        for idx, elem in zip(sub_data['indices'], sub_data[key]):
+            assert np.array_equal(elem, data[key][idx])  # type: ignore
+
+    # Check that subsampling is random, i.e. subsequent calls shouldn't give identical results
+    sub_data2 = subsampling(data)
+    for key in ['tensor_1d', 'tensor_2d', 'array_1d', 'array_2d', 'list']:
+        assert not np.array_equal(sub_data[key], sub_data2[key])  # type: ignore

From f47c64742521dfca70035d0676d333a197f9d6af Mon Sep 17 00:00:00 2001
From: Daniel Coelho de Castro
Date: Fri, 4 Feb 2022 18:26:44 +0000
Subject: [PATCH 10/15] Update changelog

---
 CHANGELOG.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c3f195775..198017492 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -47,6 +47,7 @@ jobs that run in AzureML.
 - ([#634](https://github.com/microsoft/InnerEye-DeepLearning/pull/634)) Add WSI heatmaps and thumbnails to standard test outputs
 - ([#635](https://github.com/microsoft/InnerEye-DeepLearning/pull/635)) Add tile selection and binary label for online evaluation of PANDA SSL
 - ([#647](https://github.com/microsoft/InnerEye-DeepLearning/pull/647)) Add class-wise accuracy logging and confusion matrix to DeepMIL
+- ([#653](https://github.com/microsoft/InnerEye-DeepLearning/pull/653)) Add dropout to DeepMIL and fix feature extractor setup.
 
 ### Changed
 - ([#588](https://github.com/microsoft/InnerEye-DeepLearning/pull/588)) Replace SciPy with PIL.PngImagePlugin.PngImageFile to load png files.

From 2cf274401dee14faf7aaa01db57623145ed87d37 Mon Sep 17 00:00:00 2001
From: Daniel Coelho de Castro
Date: Tue, 8 Feb 2022 15:16:18 +0000
Subject: [PATCH 11/15] Update to hi-ml with mean pooling

---
 hi-ml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/hi-ml b/hi-ml
index a33c1ed07..6a510b96d 160000
--- a/hi-ml
+++ b/hi-ml
@@ -1 +1 @@
-Subproject commit a33c1ed07da8a42486dec9f939cd59eea4b2583e
+Subproject commit 6a510b96d5abcbdf54e49e0c465436f97de58a56

From 8e2a2402b8018fb8d94d63ac9f22bc6f20cd85a8 Mon Sep 17 00:00:00 2001
From: Daniel Coelho de Castro
Date: Mon, 14 Feb 2022 15:02:41 +0000
Subject: [PATCH 12/15] Enable mean pooling in DeepMIL

---
 InnerEye/ML/Histopathology/models/deepmil.py                | 2 +-
 InnerEye/ML/configs/histo_configs/classification/BaseMIL.py | 4 +++-
 hi-ml                                                       | 2 +-
 3 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/InnerEye/ML/Histopathology/models/deepmil.py b/InnerEye/ML/Histopathology/models/deepmil.py
index 6052fbef0..30210f455 100644
--- a/InnerEye/ML/Histopathology/models/deepmil.py
+++ b/InnerEye/ML/Histopathology/models/deepmil.py
@@ -199,7 +199,7 @@ def forward(self, instances: Tensor) -> Tuple[Tensor, Tensor]:  # type: ignore
         with no_grad():
             instance_features = self.encoder(instances)                    # N X L x 1 x 1
         attentions, bag_features = self.aggregation_fn(instance_features)  # K x N | K x L
-        bag_features = bag_features.view(-1, self.num_encoding * self.pool_out_dim)
+        bag_features = bag_features.view(1, -1)
         bag_logit = self.classifier_fn(bag_features)
         return bag_logit, attentions
 
diff --git a/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py b/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py
index 5e5ee1650..f8d3ac7b1 100644
--- a/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py
+++ b/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py
@@ -14,7 +14,7 @@
 from torch import nn
 from torchvision.models.resnet import resnet18
 
-from health_ml.networks.layers.attention_layers import AttentionLayer, GatedAttentionLayer
+from health_ml.networks.layers.attention_layers import AttentionLayer, GatedAttentionLayer, MeanPoolingLayer
 from InnerEye.ML.lightning_container import LightningContainer
 from InnerEye.ML.Histopathology.datasets.base_dataset import SlidesDataset
 from InnerEye.ML.Histopathology.datamodules.base_module import CacheMode, CacheLocation, TilesDataModule
@@ -88,6 +88,8 @@ def get_pooling_layer(self) -> Type[nn.Module]:
             return AttentionLayer
         elif self.pooling_type == GatedAttentionLayer.__name__:
             return GatedAttentionLayer
+        elif self.pooling_type == MeanPoolingLayer.__name__:
+            return MeanPoolingLayer
         else:
             raise ValueError(f"Unsupported pooling type: {self.pooling_type}")
 
diff --git a/hi-ml b/hi-ml
index 6a510b96d..eb1d160be 160000
--- a/hi-ml
+++ b/hi-ml
@@ -1 +1 @@
-Subproject commit 6a510b96d5abcbdf54e49e0c465436f97de58a56
+Subproject commit eb1d160be246e28e3adc2bd5038d70d3ac0528de
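Usage sketch (illustration only, not part of the patch series): with the hi-ml
update, mean pooling is selected by class name on any BaseMIL config, just like
the attention variants; `container` stands for some BaseMIL-derived instance:

    from health_ml.networks.layers.attention_layers import MeanPoolingLayer

    container.pooling_type = MeanPoolingLayer.__name__  # i.e. "MeanPoolingLayer"
    pooling_cls = container.get_pooling_layer()         # resolves to MeanPoolingLayer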
From d053c0e8a2aba1598e72a6e086308aeda2fe32fd Mon Sep 17 00:00:00 2001
From: Daniel Coelho de Castro
Date: Mon, 14 Feb 2022 15:03:09 +0000
Subject: [PATCH 13/15] Add/refactor mean pooling tests

---
 .../ML/histopathology/models/test_deepmil.py | 54 ++++++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 46 insertions(+), 8 deletions(-)

diff --git a/Tests/ML/histopathology/models/test_deepmil.py b/Tests/ML/histopathology/models/test_deepmil.py
index 002ca0c72..e71e1e3d6 100644
--- a/Tests/ML/histopathology/models/test_deepmil.py
+++ b/Tests/ML/histopathology/models/test_deepmil.py
@@ -14,6 +14,7 @@
 from health_ml.networks.layers.attention_layers import (
     AttentionLayer,
     GatedAttentionLayer,
+    MeanPoolingLayer,
 )
 
 from InnerEye.ML.lightning_container import LightningContainer
@@ -37,14 +38,7 @@ def get_supervised_imagenet_encoder() -> TileEncoder:
     return ImageNetEncoder(feature_extraction_model=resnet18, tile_size=224)
 
 
-@pytest.mark.parametrize("n_classes", [1, 3])
-@pytest.mark.parametrize("pooling_layer", [AttentionLayer, GatedAttentionLayer])
-@pytest.mark.parametrize("batch_size", [1, 15])
-@pytest.mark.parametrize("max_bag_size", [1, 7])
-@pytest.mark.parametrize("pool_hidden_dim", [1, 5])
-@pytest.mark.parametrize("pool_out_dim", [1, 6])
-@pytest.mark.parametrize("dropout_rate", [None, 0.5])
-def test_lightningmodule(
+def _test_lightningmodule(
         n_classes: int,
         pooling_layer: Callable[[int, int, int], nn.Module],
         batch_size: int,
@@ -119,6 +113,50 @@ def test_lightningmodule(
     assert torch.all(score <= 1)
+
+
+@pytest.mark.parametrize("n_classes", [1, 3])
+@pytest.mark.parametrize("pooling_layer", [AttentionLayer, GatedAttentionLayer])
+@pytest.mark.parametrize("batch_size", [1, 15])
+@pytest.mark.parametrize("max_bag_size", [1, 7])
+@pytest.mark.parametrize("pool_hidden_dim", [1, 5])
+@pytest.mark.parametrize("pool_out_dim", [1, 6])
+@pytest.mark.parametrize("dropout_rate", [None, 0.5])
+def test_lightningmodule_attention(
+        n_classes: int,
+        pooling_layer: Callable[[int, int, int], nn.Module],
+        batch_size: int,
+        max_bag_size: int,
+        pool_hidden_dim: int,
+        pool_out_dim: int,
+        dropout_rate: Optional[float],
+) -> None:
+    _test_lightningmodule(n_classes=n_classes,
+                          pooling_layer=pooling_layer,
+                          batch_size=batch_size,
+                          max_bag_size=max_bag_size,
+                          pool_hidden_dim=pool_hidden_dim,
+                          pool_out_dim=pool_out_dim,
+                          dropout_rate=dropout_rate)
+
+
+@pytest.mark.parametrize("n_classes", [1, 3])
+@pytest.mark.parametrize("batch_size", [1, 15])
+@pytest.mark.parametrize("max_bag_size", [1, 7])
+@pytest.mark.parametrize("dropout_rate", [None, 0.5])
+def test_lightningmodule_mean_pooling(
+        n_classes: int,
+        batch_size: int,
+        max_bag_size: int,
+        dropout_rate: Optional[float],
+) -> None:
+    _test_lightningmodule(n_classes=n_classes,
+                          pooling_layer=MeanPoolingLayer,
+                          batch_size=batch_size,
+                          max_bag_size=max_bag_size,
+                          pool_hidden_dim=1,
+                          pool_out_dim=1,
+                          dropout_rate=dropout_rate)
+
+
 def move_batch_to_expected_device(batch: Dict[str, List], use_gpu: bool) -> Dict:
     device = "cuda" if use_gpu else "cpu"
     return {
From ce2f478fe89c1e9168d883e479b589f8477ee573 Mon Sep 17 00:00:00 2001
From: Daniel Coelho de Castro
Date: Fri, 18 Feb 2022 14:59:15 +0000
Subject: [PATCH 14/15] Update changelog

---
 CHANGELOG.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index aa0c76e0a..5bfcd107f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -49,6 +49,7 @@ jobs that run in AzureML.
 - ([#647](https://github.com/microsoft/InnerEye-DeepLearning/pull/647)) Add class-wise accuracy logging and confusion matrix to DeepMIL
 - ([#653](https://github.com/microsoft/InnerEye-DeepLearning/pull/653)) Add dropout to DeepMIL and fix feature extractor setup.
 - ([#650](https://github.com/microsoft/InnerEye-DeepLearning/pull/650)) Enable fine-tuning in DeepMIL using PANDA as the classification task.
+- ([#656](https://github.com/microsoft/InnerEye-DeepLearning/pull/656)) Add subsampling transform and support for MIL mean pooling.
 
 ### Changed
 - ([#588](https://github.com/microsoft/InnerEye-DeepLearning/pull/588)) Replace SciPy with PIL.PngImagePlugin.PngImageFile to load png files.

From 12a814f012b9c8b583250b3592179524aa52fd38 Mon Sep 17 00:00:00 2001
From: Daniel Coelho de Castro
Date: Fri, 18 Feb 2022 15:16:16 +0000
Subject: [PATCH 15/15] Update to latest hi-ml with mean pooling

---
 hi-ml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/hi-ml b/hi-ml
index eb1d160be..2bc397b47 160000
--- a/hi-ml
+++ b/hi-ml
@@ -1 +1 @@
-Subproject commit eb1d160be246e28e3adc2bd5038d70d3ac0528de
+Subproject commit 2bc397b4707b56fecca624ce81e6883e0170b24b