diff --git a/CHANGELOG.md b/CHANGELOG.md index c3f195775..198017492 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -47,6 +47,7 @@ jobs that run in AzureML. - ([#634](https://github.com/microsoft/InnerEye-DeepLearning/pull/634)) Add WSI heatmaps and thumbnails to standard test outputs - ([#635](https://github.com/microsoft/InnerEye-DeepLearning/pull/635)) Add tile selection and binary label for online evaluation of PANDA SSL - ([#647](https://github.com/microsoft/InnerEye-DeepLearning/pull/647)) Add class-wise accuracy logging and confusion matrix to DeepMIL +- ([#653](https://github.com/microsoft/InnerEye-DeepLearning/pull/653)) Add dropout to DeepMIL and fix feature extractor setup. ### Changed - ([#588](https://github.com/microsoft/InnerEye-DeepLearning/pull/588)) Replace SciPy with PIL.PngImagePlugin.PngImageFile to load png files. diff --git a/InnerEye/ML/Histopathology/models/deepmil.py b/InnerEye/ML/Histopathology/models/deepmil.py index 8e001e080..6052fbef0 100644 --- a/InnerEye/ML/Histopathology/models/deepmil.py +++ b/InnerEye/ML/Histopathology/models/deepmil.py @@ -46,6 +46,7 @@ def __init__(self, pooling_layer: Callable[[int, int, int], nn.Module], pool_hidden_dim: int = 128, pool_out_dim: int = 1, + dropout_rate: Optional[float] = None, class_weights: Optional[Tensor] = None, l_rate: float = 5e-4, weight_decay: float = 1e-4, @@ -64,6 +65,7 @@ def __init__(self, `torch.nn.Module` constructor accepting input, hidden, and output pooling `int` dimensions. :param pool_hidden_dim: Hidden dimension of pooling layer (default=128). :param pool_out_dim: Output dimension of pooling layer (default=1). + :param dropout_rate: Rate of pre-classifier dropout (0-1). `None` for no dropout (default). :param class_weights: Tensor containing class weights (default=None). :param l_rate: Optimiser learning rate. :param weight_decay: Weight decay parameter for L2 regularisation. @@ -82,6 +84,7 @@ def __init__(self, self.pool_hidden_dim = pool_hidden_dim self.pool_out_dim = pool_out_dim self.pooling_layer = pooling_layer + self.dropout_rate = dropout_rate self.class_weights = class_weights self.encoder = encoder self.num_encoding = self.encoder.num_encoding @@ -130,8 +133,14 @@ def get_pooling(self) -> Tuple[Callable, int]: return pooling_layer, num_features def get_classifier(self) -> Callable: - return nn.Linear(in_features=self.num_pooling, - out_features=self.n_classes) + classifier_layer = nn.Linear(in_features=self.num_pooling, + out_features=self.n_classes) + if self.dropout_rate is None: + return classifier_layer + elif 0 <= self.dropout_rate < 1: + return nn.Sequential(nn.Dropout(self.dropout_rate), classifier_layer) + else: + raise ValueError(f"Dropout rate should be in [0, 1), got {self.dropout_rate}") def get_loss(self) -> Callable: if self.n_classes > 1: @@ -186,13 +195,13 @@ def log_metrics(self, else: log_on_epoch(self, f'{stage}/{metric_name}', metric_object) - def forward(self, images: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore + def forward(self, instances: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore with no_grad(): - H = self.encoder(images) # N X L x 1 x 1 - A, M = self.aggregation_fn(H) # A: K x N | M: K x L - M = M.view(-1, self.num_encoding * self.pool_out_dim) - Y_prob = self.classifier_fn(M) - return Y_prob, A + instance_features = self.encoder(instances) # N X L x 1 x 1 + attentions, bag_features = self.aggregation_fn(instance_features) # K x N | K x L + bag_features = bag_features.view(-1, self.num_encoding * self.pool_out_dim) + bag_logit = self.classifier_fn(bag_features) + return bag_logit, attentions def configure_optimizers(self) -> optim.Optimizer: return optim.Adam(self.parameters(), lr=self.l_rate, weight_decay=self.weight_decay, diff --git a/InnerEye/ML/Histopathology/utils/layer_utils.py b/InnerEye/ML/Histopathology/utils/layer_utils.py index d3b88d3c0..842193563 100644 --- a/InnerEye/ML/Histopathology/utils/layer_utils.py +++ b/InnerEye/ML/Histopathology/utils/layer_utils.py @@ -3,9 +3,9 @@ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------ -from typing import Callable, Tuple +from typing import Tuple -from torch import as_tensor, device, nn, prod, rand +from torch import as_tensor, device, nn, no_grad, prod, rand from torch.hub import load_state_dict_from_url from torchvision.transforms import Normalize @@ -15,15 +15,23 @@ def get_imagenet_preprocessing() -> nn.Module: def setup_feature_extractor(pretrained_model: nn.Module, - input_dim: Tuple[int, int, int]) -> Tuple[Callable, int]: - layers = list(pretrained_model.children())[:-1] - layers.append(nn.Flatten()) # flatten non-batch dims in case of spatial feature maps - feature_extractor = nn.Sequential(*layers) + input_dim: Tuple[int, int, int]) -> Tuple[nn.Module, int]: + try: + # Attempt to auto-detect final classification layer: + num_features: int = pretrained_model.fc.in_features # type: ignore + pretrained_model.fc = nn.Flatten() + feature_extractor = pretrained_model + except AttributeError: + # Otherwise fallback to sequence of child modules: + layers = list(pretrained_model.children())[:-1] + layers.append(nn.Flatten()) # flatten non-batch dims in case of spatial feature maps + feature_extractor = nn.Sequential(*layers) + with no_grad(): + feature_shape = feature_extractor(rand(1, *input_dim)).shape + num_features = int(prod(as_tensor(feature_shape)).item()) # fix weights, no fine-tuning for param in feature_extractor.parameters(): param.requires_grad = False - feature_shape = feature_extractor(rand(1, *input_dim)).shape - num_features = int(prod(as_tensor(feature_shape)).item()) return feature_extractor, num_features diff --git a/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py b/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py index 4328de36b..5e5ee1650 100644 --- a/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py +++ b/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py @@ -8,7 +8,7 @@ their datamodules and configure experiment-specific parameters. """ from pathlib import Path -from typing import Type # noqa +from typing import Optional, Type # noqa import param from torch import nn @@ -27,6 +27,7 @@ class BaseMIL(LightningContainer): # Model parameters: pooling_type: str = param.String(doc="Name of the pooling layer class to use.") + dropout_rate: Optional[float] = param.Number(None, bounds=(0, 1), doc="Pre-classifier dropout rate.") # l_rate, weight_decay, adam_betas are already declared in OptimizerParams superclass # Encoder parameters: @@ -98,6 +99,7 @@ def create_model(self) -> DeepMILModule: label_column=self.data_module.train_dataset.LABEL_COLUMN, n_classes=self.data_module.train_dataset.N_CLASSES, pooling_layer=self.get_pooling_layer(), + dropout_rate=self.dropout_rate, class_weights=self.data_module.class_weights, l_rate=self.l_rate, weight_decay=self.weight_decay, diff --git a/Tests/ML/histopathology/models/test_deepmil.py b/Tests/ML/histopathology/models/test_deepmil.py index 43b5e517b..002ca0c72 100644 --- a/Tests/ML/histopathology/models/test_deepmil.py +++ b/Tests/ML/histopathology/models/test_deepmil.py @@ -4,7 +4,7 @@ # ------------------------------------------------------------------------------------------ import os -from typing import Callable, Dict, List, Type # noqa +from typing import Callable, Dict, List, Optional, Type # noqa import pytest import torch @@ -39,10 +39,11 @@ def get_supervised_imagenet_encoder() -> TileEncoder: @pytest.mark.parametrize("n_classes", [1, 3]) @pytest.mark.parametrize("pooling_layer", [AttentionLayer, GatedAttentionLayer]) -@pytest.mark.parametrize("batch_size", [1, 2]) -@pytest.mark.parametrize("max_bag_size", [1, 3]) -@pytest.mark.parametrize("pool_hidden_dim", [1, 4]) -@pytest.mark.parametrize("pool_out_dim", [1, 5]) +@pytest.mark.parametrize("batch_size", [1, 15]) +@pytest.mark.parametrize("max_bag_size", [1, 7]) +@pytest.mark.parametrize("pool_hidden_dim", [1, 5]) +@pytest.mark.parametrize("pool_out_dim", [1, 6]) +@pytest.mark.parametrize("dropout_rate", [None, 0.5]) def test_lightningmodule( n_classes: int, pooling_layer: Callable[[int, int, int], nn.Module], @@ -50,6 +51,7 @@ def test_lightningmodule( max_bag_size: int, pool_hidden_dim: int, pool_out_dim: int, + dropout_rate: Optional[float], ) -> None: assert n_classes > 0 @@ -63,6 +65,7 @@ def test_lightningmodule( pooling_layer=pooling_layer, pool_hidden_dim=pool_hidden_dim, pool_out_dim=pool_out_dim, + dropout_rate=dropout_rate, ) bag_images = rand([batch_size, max_bag_size, *module.encoder.input_dim]) diff --git a/Tests/ML/histopathology/models/test_encoders.py b/Tests/ML/histopathology/models/test_encoders.py index a9ad82864..29fefc31f 100644 --- a/Tests/ML/histopathology/models/test_encoders.py +++ b/Tests/ML/histopathology/models/test_encoders.py @@ -3,43 +3,68 @@ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------ -from typing import Callable +from typing import Callable, Tuple +import numpy as np import pytest from torch import Tensor, float32, nn, rand from torchvision.models import resnet18 from InnerEye.ML.Histopathology.models.encoders import (TileEncoder, HistoSSLEncoder, ImageNetEncoder, ImageNetSimCLREncoder) +from InnerEye.ML.Histopathology.utils.layer_utils import setup_feature_extractor + +TILE_SIZE = 224 +INPUT_DIMS = (3, TILE_SIZE, TILE_SIZE) def get_supervised_imagenet_encoder() -> TileEncoder: - return ImageNetEncoder(feature_extraction_model=resnet18, tile_size=224) + return ImageNetEncoder(feature_extraction_model=resnet18, tile_size=TILE_SIZE) def get_simclr_imagenet_encoder() -> TileEncoder: - return ImageNetSimCLREncoder(tile_size=224) + return ImageNetSimCLREncoder(tile_size=TILE_SIZE) def get_histo_ssl_encoder() -> TileEncoder: - return HistoSSLEncoder(tile_size=224) - - -@pytest.mark.parametrize("create_encoder_fn", [get_supervised_imagenet_encoder, - get_simclr_imagenet_encoder, - get_histo_ssl_encoder]) -def test_encoder(create_encoder_fn: Callable[[], TileEncoder]) -> None: - batch_size = 10 + return HistoSSLEncoder(tile_size=TILE_SIZE) - encoder = create_encoder_fn() +def _test_encoder(encoder: nn.Module, input_dims: Tuple[int, ...], output_dim: int, + batch_size: int = 5) -> None: if isinstance(encoder, nn.Module): for param_name, param in encoder.named_parameters(): assert not param.requires_grad, \ f"Feature extractor has unfrozen parameters: {param_name}" - images = rand(batch_size, *encoder.input_dim, dtype=float32) + images = rand(batch_size, *input_dims, dtype=float32) features = encoder(images) assert isinstance(features, Tensor) - assert features.shape == (batch_size, encoder.num_encoding) + assert features.shape == (batch_size, output_dim) + + +@pytest.mark.parametrize("create_encoder_fn", [get_supervised_imagenet_encoder, + get_simclr_imagenet_encoder, + get_histo_ssl_encoder]) +def test_encoder(create_encoder_fn: Callable[[], TileEncoder]) -> None: + encoder = create_encoder_fn() + _test_encoder(encoder, input_dims=encoder.input_dim, output_dim=encoder.num_encoding) + + +def _dummy_classifier() -> nn.Module: + input_size = np.prod(INPUT_DIMS) + hidden_dim = 10 + return nn.Sequential( + nn.Flatten(), + nn.Linear(input_size, hidden_dim), + nn.Tanh(), + nn.Linear(hidden_dim, 1) + ) + + +@pytest.mark.parametrize('create_classifier_fn', [resnet18, _dummy_classifier]) +def test_setup_feature_extractor(create_classifier_fn: Callable[[], nn.Module]) -> None: + classifier = create_classifier_fn() + encoder, num_features = setup_feature_extractor(classifier, INPUT_DIMS) + _test_encoder(encoder, input_dims=INPUT_DIMS, output_dim=num_features)