Skip to content
This repository was archived by the owner on Mar 21, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
845910a
Add dropout to DeepMILModule
Feb 1, 2022
34c2089
Fix feature extractor setup for torchvision models
Feb 1, 2022
31961cb
Refactor DeepMIL dropout into classifier
dccastro Feb 4, 2022
8de20ec
Refactor and add tests for feature extractor setup
dccastro Feb 4, 2022
2541eee
Add dropout to DeepMIL tests
dccastro Feb 4, 2022
b192888
Merge remote-tracking branch 'origin/main' into dacoelh/deepmil_dropout
dccastro Feb 4, 2022
4231a71
Add subsampling transform
dccastro Feb 4, 2022
248a864
Add option to allow_missing_keys for Subsampled
dccastro Feb 4, 2022
46ff8ce
Add dropout param to BaseMIL
dccastro Feb 4, 2022
26c62a3
Merge branch 'dacoelh/deepmil_dropout' into dacoelh/subsample_tiles
dccastro Feb 4, 2022
51f932b
Add docstring and tests for Subsampled
dccastro Feb 4, 2022
eb1a08c
Merge remote-tracking branch 'origin/main' into dacoelh/deepmil_dropout
dccastro Feb 4, 2022
f47c647
Update changelog
dccastro Feb 4, 2022
34f5378
Merge branch 'dacoelh/deepmil_dropout' into dacoelh/subsample_tiles
dccastro Feb 4, 2022
055d54c
Merge remote-tracking branch 'origin/main' into dacoelh/subsample_tiles
dccastro Feb 7, 2022
2cf2744
Update to hi-ml with mean pooling
dccastro Feb 8, 2022
8e2a240
Enable mean pooling in DeepMIL
dccastro Feb 14, 2022
d053c0e
Add/refactor mean pooling tests
dccastro Feb 14, 2022
6952188
Merge remote-tracking branch 'origin/main' into dacoelh/subsample_tiles
dccastro Feb 14, 2022
ce2f478
Update changelog
dccastro Feb 18, 2022
12a814f
Update to latest hi-ml with mean pooling
dccastro Feb 18, 2022
8fb1bbb
Merge remote-tracking branch 'origin/main' into dacoelh/subsample_tiles
dccastro Feb 18, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ jobs that run in AzureML.
- ([#647](https://github.com/microsoft/InnerEye-DeepLearning/pull/647)) Add class-wise accuracy logging and confusion matrix to DeepMIL
- ([#653](https://github.com/microsoft/InnerEye-DeepLearning/pull/653)) Add dropout to DeepMIL and fix feature extractor setup.
- ([#650](https://github.com/microsoft/InnerEye-DeepLearning/pull/650)) Enable fine-tuning in DeepMIL using PANDA as the classification task.
- ([#656](https://github.com/microsoft/InnerEye-DeepLearning/pull/656)) Add subsampling transform and support for MIL mean pooling.

### Changed
- ([#659](https://github.com/microsoft/InnerEye-DeepLearning/pull/659)) Update cudatoolkit version from 11.1 to 11.3.
Expand Down
2 changes: 1 addition & 1 deletion InnerEye/ML/Histopathology/models/deepmil.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def forward(self, instances: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore
with set_grad_enabled(self.is_finetune):
instance_features = self.encoder(instances) # N X L x 1 x 1
attentions, bag_features = self.aggregation_fn(instance_features) # K x N | K x L
bag_features = bag_features.view(-1, self.num_encoding * self.pool_out_dim)
bag_features = bag_features.view(1, -1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why was this changed here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this is a more robust way to reshape the outputs, without relying on num_encoding and pool_out_dim provided by the encoder and pooling components. In particular, MeanPoolingLayer ignores all arguments passed to it, so this operation would have failed if we expected a different shape here.

bag_logit = self.classifier_fn(bag_features)
return bag_logit, attentions

Expand Down
43 changes: 41 additions & 2 deletions InnerEye/ML/Histopathology/models/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np
import PIL
from monai.config.type_definitions import KeysCollection
from monai.transforms.transform import MapTransform
from monai.transforms.transform import MapTransform, Randomizable
from torchvision.transforms.functional import to_tensor

from InnerEye.ML.Histopathology.models.encoders import TileEncoder
Expand Down Expand Up @@ -92,7 +92,7 @@ def __init__(self,
allow_missing_keys: bool = False,
chunk_size: int = 0) -> None:
"""
:param keys: Key(s) for the image path(s) in the input dictionary.
:param keys: Key(s) for the image tensor(s) in the input dictionary.
:param encoder: The tile encoder to use for feature extraction.
:param allow_missing_keys: If `False` (default), raises an exception when an input
dictionary is missing any of the specified keys.
Expand Down Expand Up @@ -128,3 +128,42 @@ def __call__(self, data: Mapping) -> Mapping:
for key in self.key_iterator(out_data):
out_data[key] = self._encode_tiles(data[key])
return out_data


def take_indices(data: Sequence, indices: np.ndarray) -> Sequence:
if isinstance(data, (np.ndarray, torch.Tensor)):
return data[indices]
elif isinstance(data, Sequence):
return [data[i] for i in indices]
else:
raise ValueError(f"Data of type {type(data)} is not indexable")


class Subsampled(MapTransform, Randomizable):
"""Dictionary transform to randomly subsample the data down to a fixed maximum length"""

def __init__(self, keys: KeysCollection, max_size: int,
allow_missing_keys: bool = False) -> None:
"""
:param keys: Key(s) for all batch elements that must be subsampled.
:param max_size: Each specified array, tensor, or sequence will be subsampled uniformly at
random down to `max_size` along their first dimension. If shorter, the elements are merely
shuffled.
:param allow_missing_keys: If `False` (default), raises an exception when an input
dictionary is missing any of the specified keys.
"""
super().__init__(keys, allow_missing_keys=allow_missing_keys)
self.max_size = max_size
self._indices: np.ndarray

def randomize(self, total_size: int) -> None:
subsample_size = min(self.max_size, total_size)
self._indices = self.R.choice(total_size, size=subsample_size)

def __call__(self, data: Mapping) -> Mapping:
out_data = dict(data) # create shallow copy
size = len(data[self.keys[0]])
self.randomize(size)
for key in self.key_iterator(out_data):
out_data[key] = take_indices(data[key], self._indices)
return out_data
4 changes: 3 additions & 1 deletion InnerEye/ML/configs/histo_configs/classification/BaseMIL.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from torch import nn
from torchvision.models.resnet import resnet18

from health_ml.networks.layers.attention_layers import AttentionLayer, GatedAttentionLayer
from health_ml.networks.layers.attention_layers import AttentionLayer, GatedAttentionLayer, MeanPoolingLayer
from InnerEye.ML.lightning_container import LightningContainer
from InnerEye.ML.Histopathology.datasets.base_dataset import SlidesDataset
from InnerEye.ML.Histopathology.datamodules.base_module import CacheMode, CacheLocation, TilesDataModule
Expand Down Expand Up @@ -90,6 +90,8 @@ def get_pooling_layer(self) -> Type[nn.Module]:
return AttentionLayer
elif self.pooling_type == GatedAttentionLayer.__name__:
return GatedAttentionLayer
elif self.pooling_type == MeanPoolingLayer.__name__:
return MeanPoolingLayer
else:
raise ValueError(f"Unsupported pooling type: {self.pooling_type}")

Expand Down
54 changes: 46 additions & 8 deletions Tests/ML/histopathology/models/test_deepmil.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from health_ml.networks.layers.attention_layers import (
AttentionLayer,
GatedAttentionLayer,
MeanPoolingLayer,
)

from InnerEye.ML.lightning_container import LightningContainer
Expand All @@ -37,14 +38,7 @@ def get_supervised_imagenet_encoder() -> TileEncoder:
return ImageNetEncoder(feature_extraction_model=resnet18, tile_size=224)


@pytest.mark.parametrize("n_classes", [1, 3])
@pytest.mark.parametrize("pooling_layer", [AttentionLayer, GatedAttentionLayer])
@pytest.mark.parametrize("batch_size", [1, 15])
@pytest.mark.parametrize("max_bag_size", [1, 7])
@pytest.mark.parametrize("pool_hidden_dim", [1, 5])
@pytest.mark.parametrize("pool_out_dim", [1, 6])
@pytest.mark.parametrize("dropout_rate", [None, 0.5])
def test_lightningmodule(
def _test_lightningmodule(
n_classes: int,
pooling_layer: Callable[[int, int, int], nn.Module],
batch_size: int,
Expand Down Expand Up @@ -119,6 +113,50 @@ def test_lightningmodule(
assert torch.all(score <= 1)


@pytest.mark.parametrize("n_classes", [1, 3])
@pytest.mark.parametrize("pooling_layer", [AttentionLayer, GatedAttentionLayer])
@pytest.mark.parametrize("batch_size", [1, 15])
@pytest.mark.parametrize("max_bag_size", [1, 7])
@pytest.mark.parametrize("pool_hidden_dim", [1, 5])
@pytest.mark.parametrize("pool_out_dim", [1, 6])
@pytest.mark.parametrize("dropout_rate", [None, 0.5])
def test_lightningmodule_attention(
n_classes: int,
pooling_layer: Callable[[int, int, int], nn.Module],
batch_size: int,
max_bag_size: int,
pool_hidden_dim: int,
pool_out_dim: int,
dropout_rate: Optional[float],
) -> None:
_test_lightningmodule(n_classes=n_classes,
pooling_layer=pooling_layer,
batch_size=batch_size,
max_bag_size=max_bag_size,
pool_hidden_dim=pool_hidden_dim,
pool_out_dim=pool_out_dim,
dropout_rate=dropout_rate)


@pytest.mark.parametrize("n_classes", [1, 3])
@pytest.mark.parametrize("batch_size", [1, 15])
@pytest.mark.parametrize("max_bag_size", [1, 7])
@pytest.mark.parametrize("dropout_rate", [None, 0.5])
def test_lightningmodule_mean_pooling(
n_classes: int,
batch_size: int,
max_bag_size: int,
dropout_rate: Optional[float],
) -> None:
_test_lightningmodule(n_classes=n_classes,
pooling_layer=MeanPoolingLayer,
batch_size=batch_size,
max_bag_size=max_bag_size,
pool_hidden_dim=1,
pool_out_dim=1,
dropout_rate=dropout_rate)


def move_batch_to_expected_device(batch: Dict[str, List], use_gpu: bool) -> Dict:
device = "cuda" if use_gpu else "cpu"
return {
Expand Down
57 changes: 56 additions & 1 deletion Tests/ML/histopathology/models/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pathlib import Path
from typing import Callable, Sequence, Union

import numpy as np
import pytest
import torch
from monai.data.dataset import CacheDataset, Dataset, PersistentDataset
Expand All @@ -19,7 +20,7 @@
from InnerEye.ML.Histopathology.datasets.default_paths import TCGA_CRCK_DATASET_DIR
from InnerEye.ML.Histopathology.datasets.tcga_crck_tiles_dataset import TcgaCrck_TilesDataset
from InnerEye.ML.Histopathology.models.encoders import ImageNetEncoder
from InnerEye.ML.Histopathology.models.transforms import EncodeTilesBatchd, LoadTiled, LoadTilesBatchd
from InnerEye.ML.Histopathology.models.transforms import EncodeTilesBatchd, LoadTiled, LoadTilesBatchd, Subsampled
from Tests.ML.util import assert_dicts_equal


Expand Down Expand Up @@ -153,3 +154,57 @@ def test_encode_tiles(tmp_path: Path, use_gpu: bool, chunk_size: int) -> None:
bagged_subset,
transform=transform,
cache_subdir="TCGA-CRCk_embed_cache")


@pytest.mark.parametrize('include_non_indexable', [True, False])
@pytest.mark.parametrize('allow_missing_keys', [True, False])
def test_subsample(include_non_indexable: bool, allow_missing_keys: bool) -> None:
batch_size = 5
max_size = batch_size // 2
data = {
'array_1d': np.random.randn(batch_size),
'array_2d': np.random.randn(batch_size, 4),
'tensor_1d': torch.randn(batch_size),
'tensor_2d': torch.randn(batch_size, 4),
'list': torch.randn(batch_size).tolist(),
'indices': list(range(batch_size)),
'non-indexable': 42,
}

keys_to_subsample = list(data.keys())
if not include_non_indexable:
keys_to_subsample.remove('non-indexable')
keys_to_subsample.append('missing-key')

subsampling = Subsampled(keys_to_subsample, max_size=max_size,
allow_missing_keys=allow_missing_keys)

if include_non_indexable:
with pytest.raises(ValueError):
sub_data = subsampling(data)
return
elif not allow_missing_keys:
with pytest.raises(KeyError):
sub_data = subsampling(data)
return
else:
sub_data = subsampling(data)

assert set(sub_data.keys()) == set(data.keys())

# Check lenghts before and after subsampling
for key in keys_to_subsample:
if key not in data:
continue # Skip missing keys
assert len(data[key]) == batch_size # type: ignore
assert len(sub_data[key]) == min(max_size, batch_size) # type: ignore

# Check contents of subsampled elements
for key in ['tensor_1d', 'tensor_2d', 'array_1d', 'array_2d', 'list']:
for idx, elem in zip(sub_data['indices'], sub_data[key]):
assert np.array_equal(elem, data[key][idx]) # type: ignore

# Check that subsampling is random, i.e. subsequent calls shouldn't give identical results
sub_data2 = subsampling(data)
for key in ['tensor_1d', 'tensor_2d', 'array_1d', 'array_2d', 'list']:
assert not np.array_equal(sub_data[key], sub_data2[key]) # type: ignore
2 changes: 1 addition & 1 deletion hi-ml
Submodule hi-ml updated 164 files