This repository was archived by the owner on Mar 21, 2024. It is now read-only.
Merged
41 commits
3729a8f
fixing PandaInnereSSLMIL
vale-salvatelli Dec 20, 2021
2d927a3
updating checkpoint downloader
vale-salvatelli Dec 21, 2021
b8fd6c8
checkpoint for inference found
vale-salvatelli Dec 21, 2021
7c8e7d2
fixing outputs paths
vale-salvatelli Dec 21, 2021
39829b8
fixing test
vale-salvatelli Jan 11, 2022
6b9ab94
update changelog
vale-salvatelli Jan 11, 2022
137dc99
Merge branch 'main' into vsalva/use_panda_checkpoint
vale-salvatelli Jan 11, 2022
c932ef7
implementing PR feedback, thanks Anton and Daniel
vale-salvatelli Jan 11, 2022
7d8d86c
typo
vale-salvatelli Jan 11, 2022
8dc5a48
updating to latest CRCk checkpoint, new augmentations
vale-salvatelli Jan 11, 2022
9a010c6
moving checkpoint ids to file
vale-salvatelli Jan 11, 2022
bc21f96
Merge branch 'main' into vsalva/use_panda_checkpoint
vale-salvatelli Jan 11, 2022
62e3831
first draft
vale-salvatelli Jan 13, 2022
ffa0dc2
extend test
vale-salvatelli Jan 13, 2022
eff2349
all works with large batch size
vale-salvatelli Jan 19, 2022
5290945
making cpu memory an option
vale-salvatelli Jan 19, 2022
80c6447
clean up chunk size parameter
vale-salvatelli Jan 19, 2022
97ae348
add TODO
vale-salvatelli Jan 19, 2022
da8b44f
fixing conflicts
vale-salvatelli Jan 19, 2022
cd94e9f
update changelog
vale-salvatelli Jan 19, 2022
2498652
making clarity on cachemode vs precache mode
vale-salvatelli Jan 19, 2022
a5ce7f8
fix typo
vale-salvatelli Jan 19, 2022
505635c
update test after refactoring
vale-salvatelli Jan 19, 2022
c891dde
update test after refactoring
vale-salvatelli Jan 19, 2022
d1d7b27
remove typo in tests
vale-salvatelli Jan 20, 2022
e59fef2
change optional type
vale-salvatelli Jan 20, 2022
7644da5
Update InnerEye/ML/configs/histo_configs/classification/DeepSMILEPand…
vale-salvatelli Jan 20, 2022
7cc9c8b
Update InnerEye/ML/configs/histo_configs/classification/DeepSMILEPand…
vale-salvatelli Jan 20, 2022
4550982
Update InnerEye/ML/configs/histo_configs/classification/BaseMIL.py
vale-salvatelli Jan 20, 2022
7461666
change load_image function
vale-salvatelli Jan 20, 2022
953244d
Merge branch 'vsalva/chunk_encoding' of https://github.com/microsoft/…
vale-salvatelli Jan 20, 2022
3e06759
revert some changes to avoid inconsistencies in type
vale-salvatelli Jan 20, 2022
2249ac4
implement PR feedback
vale-salvatelli Jan 20, 2022
198ab9b
making realistic test cases in test_tile_id_coverage
vale-salvatelli Jan 20, 2022
afe1d14
minor fixes
vale-salvatelli Jan 20, 2022
c3e1eef
fix test and extend location cases
vale-salvatelli Jan 24, 2022
62354a6
remove generic to cuda
vale-salvatelli Jan 24, 2022
14c9746
fix naming error GPU
vale-salvatelli Jan 24, 2022
91dfabc
trying adding some reproducibility to failing test
vale-salvatelli Jan 24, 2022
d9757c9
Merge branch 'main' into vsalva/chunk_encoding
vale-salvatelli Jan 24, 2022
ffb9535
Merge branch 'main' into vsalva/chunk_encoding
vale-salvatelli Jan 25, 2022
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -21,6 +21,7 @@ loss.
GPU utilization via Lightning's `GpuStatsMonitor`, switch `monitor_loading` to check batch loading times via
`BatchTimeCallback`, and `pl_profiler` to turn on the Lightning profiler (`simple`, `advanced`, or `pytorch`)
- ([#544](https://github.com/microsoft/InnerEye-DeepLearning/pull/544)) Add documentation for segmentation model evaluation.
- ([#637](https://github.com/microsoft/InnerEye-DeepLearning/pull/637)) Add option to encode in chunks and to load the pre-cached dataset into CPU or GPU memory in the histo pipeline.
- ([#465](https://github.com/microsoft/InnerEye-DeepLearning/pull/465/)) Adding ability to run segmentation inference
module on test data with partial ground truth files. (Also [522](https://github.com/microsoft/InnerEye-DeepLearning/pull/522).)
- ([#502](https://github.com/microsoft/InnerEye-DeepLearning/pull/502)) More flags for fine control of when to run inference.
@@ -133,7 +134,7 @@ in inference-only runs when using lightning containers.

### Deprecated

- ([#633](https://github.com/microsoft/InnerEye-DeepLearning/pull/633)) Model fields `recovery_checkpoint_save_interval` and `recovery_checkpoints_save_last_k` have been retired.
Recovery checkpoint handling is now controlled by `autosave_every_n_val_epochs`.


51 changes: 34 additions & 17 deletions InnerEye/ML/Histopathology/datamodules/base_module.py
@@ -3,7 +3,7 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

import pickle
import torch
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Optional, Sequence, Tuple, Union
@@ -23,13 +23,17 @@ class CacheMode(Enum):
MEMORY = 'memory'
DISK = 'disk'


class CacheLocation(Enum):
NONE = 'none'
CPU = 'cpu'
SAME = 'same'
class TilesDataModule(LightningDataModule):
"""Base class to load the tiles of a dataset as train, val, test sets"""

def __init__(self, root_path: Path, max_bag_size: int = 0, batch_size: int = 1,
seed: Optional[int] = None, transform: Optional[Callable] = None,
cache_mode: CacheMode = CacheMode.NONE, save_precache: bool = False,
cache_mode: CacheMode = CacheMode.NONE,
precache_location: CacheLocation = CacheLocation.NONE,
cache_dir: Optional[Path] = None,
number_of_cross_validation_splits: int = 0,
cross_validation_split_index: int = 0) -> None:
@@ -46,19 +50,24 @@ def __init__(self, root_path: Path, max_bag_size: int = 0, batch_size: int = 1,
:param cache_mode: The type of caching to perform, i.e. whether the results of all
transforms up to the first randomised one should be computed only once and reused in
subsequent iterations:
- `MEMORY`: the entire transformed dataset is kept in memory for fastest access;
- `DISK`: each transformed sample is saved to disk and loaded on-demand;
- `NONE` (default): no caching is performed.
:param save_precache: Whether to pre-cache the entire transformed dataset upfront and save
- `MEMORY`: MONAI CacheDataset is used, the entire transformed dataset is kept in memory for fastest access;
- `DISK`: MONAI PersistentDataset is used, each transformed sample is saved to disk and loaded on-demand;
- `NONE` (default): standard MONAI dataset is used, no caching is performed.
:param precache_location: Whether to pre-cache the entire transformed dataset upfront and save
it to disk. This is done once in `prepare_data()` only on the local rank-0 process, so
multiple processes can afterwards access the same cache without contention in DDP settings.
multiple processes can afterwards access the same cache without contention in DDP settings. This parameter also controls
whether the cache is reloaded into CPU or GPU memory:
- `NONE` (default): no pre-caching is performed;
- `CPU`: each transformed sample is saved to disk and, if cache_mode is `MEMORY`, reloaded into CPU memory;
- `SAME`: each transformed sample is saved to disk and, if cache_mode is `MEMORY`, reloaded onto the same device it was saved from.
If cache_mode is `DISK`, precache_location `CPU` and `SAME` are equivalent.
:param cache_dir: The directory onto which to cache data if caching is enabled.
:param number_of_cross_validation_splits: Number of folds to perform.
:param cross_validation_split_index: Index of the cross validation split to be performed.
"""
if save_precache and cache_mode is CacheMode.NONE:
if precache_location is not CacheLocation.NONE and cache_mode is CacheMode.NONE:
raise ValueError("Can only pre-cache if caching is enabled")
if save_precache and cache_dir is None:
if precache_location is not CacheLocation.NONE and cache_dir is None:
raise ValueError("A cache directory is required for pre-caching")
if cache_mode is CacheMode.DISK and cache_dir is None:
raise ValueError("A cache directory is required for on-disk caching")
@@ -68,7 +77,7 @@ def __init__(self, root_path: Path, max_bag_size: int = 0, batch_size: int = 1,
self.max_bag_size = max_bag_size
self.transform = transform
self.cache_mode = cache_mode
self.save_precache = save_precache
self.precache_location = precache_location
self.cache_dir = cache_dir
self.batch_size = batch_size
self.number_of_cross_validation_splits = number_of_cross_validation_splits
@@ -82,22 +91,29 @@ def get_splits(self) -> Tuple[TilesDataset, TilesDataset, TilesDataset]:
raise NotImplementedError

def prepare_data(self) -> None:
if self.save_precache:
if self.precache_location != CacheLocation.NONE:
self._load_dataset(self.train_dataset, stage='train', shuffle=True)
self._load_dataset(self.val_dataset, stage='val', shuffle=True)
self._load_dataset(self.test_dataset, stage='test', shuffle=True)

def _dataset_pickle_path(self, stage: str) -> Optional[Path]:
if self.cache_dir is None:
return None
return self.cache_dir / f"{stage}_dataset.pkl"
return self.cache_dir / f"{stage}_dataset.pt"

def _load_dataset(self, tiles_dataset: TilesDataset, stage: str, shuffle: bool) -> Dataset:
dataset_pickle_path = self._dataset_pickle_path(stage)

if dataset_pickle_path and dataset_pickle_path.exists():
if dataset_pickle_path and dataset_pickle_path.is_file():
if self.precache_location == CacheLocation.CPU:
memory_location = torch.device('cpu')
print(f"Loading dataset from {dataset_pickle_path} into {memory_location}")
else:
# by default torch.load will reload on the same device it was saved from
memory_location = None # type: ignore

with dataset_pickle_path.open('rb') as f:
return pickle.load(f)
return torch.load(f, map_location=memory_location)

generator = _create_generator(self.seed)
bag_dataset = BagDataset(tiles_dataset, # type: ignore
@@ -112,10 +128,11 @@ def _load_dataset(self, tiles_dataset: TilesDataset, stage: str, shuffle: bool)
transformed_bag_dataset = self._get_transformed_dataset(bag_dataset, transform) # type: ignore
generator.set_state(generator_state)

# Dataset is saved if cache_dir is set, regardless of CacheMode
if dataset_pickle_path:
dataset_pickle_path.parent.mkdir(parents=True, exist_ok=True)
with dataset_pickle_path.open('wb') as f:
pickle.dump(transformed_bag_dataset, f)
torch.save(transformed_bag_dataset, f)

return transformed_bag_dataset

@@ -125,7 +142,7 @@ def _get_transformed_dataset(self, base_dataset: BagDataset,
dataset = CacheDataset(base_dataset, transform, num_workers=1) # type: ignore
elif self.cache_mode is CacheMode.DISK:
dataset = PersistentDataset(base_dataset, transform, cache_dir=self.cache_dir) # type: ignore
if self.save_precache:
if self.precache_location != CacheLocation.NONE:
import tqdm # TODO: Make optional

for i in tqdm.trange(len(dataset), desc="Loading dataset"):
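For readers skimming the datamodule changes above: the behaviour selected by `precache_location` effectively comes down to the `map_location` argument of `torch.load`. The following standalone sketch (not part of the PR; the file name and sample data are made up) illustrates the difference between reloading a cached sample into CPU memory versus onto whatever device it was saved from:

```python
import tempfile
from pathlib import Path

import torch

# A stand-in for one pre-cached, already-transformed sample (hypothetical data).
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
sample = {"features": torch.randn(56, 2048, device=device), "label": torch.tensor(1)}

cache_path = Path(tempfile.mkdtemp()) / "train_dataset.pt"
with cache_path.open("wb") as f:
    torch.save(sample, f)  # analogous to what prepare_data() does once on the rank-0 process

# CacheLocation.CPU: force everything back into CPU memory on reload.
cpu_copy = torch.load(cache_path, map_location=torch.device("cpu"))
print(cpu_copy["features"].device)   # always cpu

# CacheLocation.SAME: map_location=None keeps the device the tensors were saved from.
same_copy = torch.load(cache_path, map_location=None)
print(same_copy["features"].device)  # cpu or cuda:0, matching where it was saved
```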
6 changes: 3 additions & 3 deletions InnerEye/ML/Histopathology/models/deepmil.py
@@ -51,7 +51,7 @@ def __init__(self,
weight_decay: float = 1e-4,
adam_betas: Tuple[float, float] = (0.9, 0.99),
verbose: bool = False,
slide_dataset: SlidesDataset = None,
tile_size: int = 224,
level: int = 1) -> None:
"""
@@ -324,13 +324,13 @@ def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: # type: ignore

if self.slide_dataset is not None:
slide_dict = mi.first_true(self.slide_dataset, pred=lambda entry: entry[SlideKey.SLIDE_ID] == slide) # type: ignore
_ = load_image_dict(slide_dict, level=self.level, margin=0) # type: ignore
slide_image = slide_dict[SlideKey.IMAGE]
location_bbox = slide_dict[SlideKey.LOCATION]

fig = plot_slide(slide_image=slide_image, scale=1.0)
self.save_figure(fig=fig, figpath=Path(key_folder_path, f'{slide}_thumbnail.png'))
fig = plot_heatmap_overlay(slide=slide, slide_image=slide_image, results=results,
location_bbox=location_bbox, tile_size=self.tile_size, level=self.level)
self.save_figure(fig=fig, figpath=Path(key_folder_path, f'{slide}_heatmap.png'))

24 changes: 21 additions & 3 deletions InnerEye/ML/Histopathology/models/transforms.py
@@ -6,8 +6,9 @@
from pathlib import Path
from typing import Mapping, Sequence, Union

import PIL.Image
import torch
import numpy as np
import PIL
from monai.config.type_definitions import KeysCollection
from monai.transforms.transform import MapTransform
from torchvision.transforms.functional import to_tensor
@@ -19,7 +20,9 @@

def load_pil_image(image_path: PathOrString) -> PIL.Image.Image:
"""Load a PIL image in RGB format from the given path"""
return PIL.Image.open(image_path).convert('RGB')
with PIL.PngImagePlugin.PngImageFile(image_path) as pil_png:
image = np.asarray(pil_png)
return image


def load_image_as_tensor(image_path: PathOrString) -> torch.Tensor:
@@ -86,19 +89,34 @@ class EncodeTilesBatchd(MapTransform):
def __init__(self,
keys: KeysCollection,
encoder: TileEncoder,
allow_missing_keys: bool = False) -> None:
allow_missing_keys: bool = False,
chunk_size: int = 0) -> None:
"""
:param keys: Key(s) for the image path(s) in the input dictionary.
:param encoder: The tile encoder to use for feature extraction.
:param allow_missing_keys: If `False` (default), raises an exception when an input
dictionary is missing any of the specified keys.
:param chunk_size: If > 0, extracts features in chunks of `chunk_size` tiles.
"""
super().__init__(keys, allow_missing_keys)
self.encoder = encoder
self.chunk_size = chunk_size

@torch.no_grad()
def _encode_tiles(self, images: torch.Tensor) -> torch.Tensor:
device = next(self.encoder.parameters()).device
if self.chunk_size > 0:
embeddings = []
chunks = torch.split(images, self.chunk_size)
# TODO parallelize encoding - keep metadata and images aligned
for chunk in chunks:
chunk_embeddings = self._encode_images(chunk, device)
embeddings.append(chunk_embeddings)
return torch.cat(embeddings)
else:
return self._encode_images(images, device)

def _encode_images(self, images: torch.Tensor, device: torch.device) -> torch.Tensor:
images = images.to(device)
embeddings = self.encoder(images)
del images
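The core idea of the new `chunk_size` parameter in `EncodeTilesBatchd._encode_tiles` is to bound peak GPU memory by splitting a bag of tiles before encoding and concatenating the embeddings afterwards. A minimal, self-contained sketch of that pattern (the helper function and the toy linear "encoder" are illustrative only, not part of the PR):

```python
import torch
from torch import nn


def encode_in_chunks(images: torch.Tensor, encoder: nn.Module, chunk_size: int) -> torch.Tensor:
    """Encode a bag of tile images in chunks of `chunk_size` to bound peak memory use."""
    device = next(encoder.parameters()).device
    if chunk_size <= 0:
        return encoder(images.to(device))
    embeddings = []
    # torch.split preserves order, so tile metadata stays aligned with the outputs.
    for chunk in torch.split(images, chunk_size):
        embeddings.append(encoder(chunk.to(device)))
    return torch.cat(embeddings)


# Toy usage: 56 flattened tiles, a trivial stand-in encoder, chunks of 16 tiles.
tiles = torch.randn(56, 3 * 224 * 224)
toy_encoder = nn.Linear(3 * 224 * 224, 128)
with torch.no_grad():
    features = encode_in_chunks(tiles, toy_encoder, chunk_size=16)
print(features.shape)  # torch.Size([56, 128])
```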
4 changes: 2 additions & 2 deletions InnerEye/ML/Histopathology/utils/metrics_utils.py
@@ -8,7 +8,7 @@
import matplotlib.pyplot as plt
from math import ceil
import numpy as np
import matplotlib.patches as patches
import matplotlib.collections as collection

from InnerEye.ML.Histopathology.models.transforms import load_pil_image
@@ -117,7 +117,7 @@ def plot_slide(slide_image: np.ndarray, scale: float) -> plt.figure:
return fig


def plot_heatmap_overlay(slide: str,
slide_image: np.ndarray,
results: Dict[str, List[Any]],
location_bbox: List[int],
12 changes: 8 additions & 4 deletions InnerEye/ML/configs/histo_configs/classification/BaseMIL.py
@@ -8,7 +8,7 @@
their datamodules and configure experiment-specific parameters.
"""
from pathlib import Path
from typing import Type
from typing import Type # noqa

import param
from torch import nn
@@ -17,7 +17,7 @@
from health_ml.networks.layers.attention_layers import AttentionLayer, GatedAttentionLayer
from InnerEye.ML.lightning_container import LightningContainer
from InnerEye.ML.Histopathology.datasets.base_dataset import SlidesDataset
from InnerEye.ML.Histopathology.datamodules.base_module import CacheMode, TilesDataModule
from InnerEye.ML.Histopathology.datamodules.base_module import CacheMode, CacheLocation, TilesDataModule
from InnerEye.ML.Histopathology.models.deepmil import DeepMILModule
from InnerEye.ML.Histopathology.models.encoders import (HistoSSLEncoder, IdentityEncoder,
ImageNetEncoder, ImageNetSimCLREncoder,
@@ -44,8 +44,12 @@ class BaseMIL(LightningContainer):
cache_mode: CacheMode = param.ClassSelector(default=CacheMode.MEMORY, class_=CacheMode,
doc="The type of caching to perform: "
"'memory' (default), 'disk', or 'none'.")
save_precache: bool = param.Boolean(True, doc="Whether to pre-cache the entire transformed "
"dataset upfront and save it to disk.")
precache_location: CacheLocation = param.ClassSelector(default=CacheLocation.NONE, class_=CacheLocation,
doc="Whether to pre-cache the entire transformed dataset upfront "
"and save it to disk, and whether to reload it into CPU or GPU memory. "
"Options: `none` (default), `cpu`, `same`.")
encoding_chunk_size: int = param.Integer(0, doc="If > 0, performs encoding in chunks, by loading "
"encoding_chunk_size tiles per chunk")
# local_dataset (used as data module root_path) is declared in DatasetParams superclass

@property
InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py
@@ -39,6 +39,9 @@ def __init__(self, **kwargs: Any) -> None:
default_kwargs = dict(
# declared in BaseMIL:
pooling_type=GatedAttentionLayer.__name__,
# average number of tiles is 56 for PANDA
encoding_chunk_size=60,

# declared in DatasetParams:
local_dataset=Path("/tmp/datasets/PANDA_tiles"),
azure_dataset_id="PANDA_tiles",
@@ -48,9 +51,11 @@
# declared in TrainerParams:
num_epochs=200,
# use_mixed_precision = True,

# declared in WorkflowParams:
number_of_cross_validation_splits=5,
cross_validation_split_index=0,

# declared in OptimizerParams:
l_rate=5e-4,
weight_decay=1e-4,
@@ -101,7 +106,7 @@ def get_data_module(self) -> PandaTilesDataModule:
transform = Compose(
[
LoadTilesBatchd(image_key, progress=True),
EncodeTilesBatchd(image_key, self.encoder),
EncodeTilesBatchd(image_key, self.encoder, chunk_size=self.encoding_chunk_size),
]
)
return PandaTilesDataModule(
@@ -110,7 +115,7 @@ def get_data_module(self) -> PandaTilesDataModule:
batch_size=self.batch_size,
transform=transform,
cache_mode=self.cache_mode,
save_precache=self.save_precache,
precache_location=self.precache_location,
cache_dir=self.cache_dir,
number_of_cross_validation_splits=self.number_of_cross_validation_splits,
cross_validation_split_index=self.cross_validation_split_index,
@@ -134,7 +139,7 @@ def create_model(self) -> DeepMILModule:
tile_size=self.tile_size,
level=self.level)

def get_slide_dataset(self) -> PandaDataset:
return PandaDataset(root=self.extra_local_dataset_paths[0]) # type: ignore

def get_callbacks(self) -> List[Callback]:
@@ -162,8 +167,7 @@ def get_path_to_best_checkpoint(self) -> Path:
if checkpoint_path.is_file():
return checkpoint_path

raise ValueError("Path to best checkpoint not found")

class PandaImageNetMIL(DeepSMILEPanda):
def __init__(self, **kwargs: Any) -> None:
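Putting the two new `BaseMIL` options together, a hedged sketch of how they could be set on the PANDA container from Python. This assumes the InnerEye repository is importable and that the module path `InnerEye.ML.configs.histo_configs.classification.DeepSMILEPanda` matches the file edited above (inferred from the diff, not confirmed); running `get_data_module()` additionally needs the PANDA tiles dataset at the configured `local_dataset` path.

```python
from InnerEye.ML.Histopathology.datamodules.base_module import CacheLocation, CacheMode
from InnerEye.ML.configs.histo_configs.classification.DeepSMILEPanda import PandaImageNetMIL

container = PandaImageNetMIL()
# The caching/encoding knobs added or reworked in this PR:
container.cache_mode = CacheMode.MEMORY          # keep transformed samples in memory
container.precache_location = CacheLocation.CPU  # pre-cache to disk once, reload into CPU memory
container.encoding_chunk_size = 60               # roughly one PANDA bag (~56 tiles) per chunk

# get_data_module() passes these values through to PandaTilesDataModule;
# it requires the PANDA tiles dataset to be available locally.
data_module = container.get_data_module()
```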
2 changes: 1 addition & 1 deletion InnerEye/ML/utils/io_util.py
@@ -461,7 +461,7 @@ def load_labels_from_dataset_source(dataset_source: PatientDatasetSource, check_

def load_image(path: PathOrString, image_type: Optional[Type] = float) -> ImageWithHeader:
"""
Loads an image with extension numpy or nifti
Loads an image with extension numpy or nifti or png
For HDF5 path suffix
For images |<dataset_name>|<channel index>
For segmentation binary |<dataset_name>|<channel index>