diff --git a/CHANGELOG.md b/CHANGELOG.md index 4a0c10bcc..1c964130e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,6 +39,7 @@ jobs that run in AzureML. - ([#613](https://github.com/microsoft/InnerEye-DeepLearning/pull/613)) Add additional tests for histopathology datasets - ([#616](https://github.com/microsoft/InnerEye-DeepLearning/pull/616)) Add more histopathology configs and tests - ([#621](https://github.com/microsoft/InnerEye-DeepLearning/pull/621)) Add WSI preprocessing functions and enable tiling more generic slide datasets +- ([#634](https://github.com/microsoft/InnerEye-DeepLearning/pull/634)) Add WSI heatmaps and thumbnails to standard test outputs ### Changed - ([#588](https://github.com/microsoft/InnerEye-DeepLearning/pull/588)) Replace SciPy with PIL.PngImagePlugin.PngImageFile to load png files. diff --git a/InnerEye/ML/Histopathology/models/deepmil.py b/InnerEye/ML/Histopathology/models/deepmil.py index 4caf352d4..827d23a03 100644 --- a/InnerEye/ML/Histopathology/models/deepmil.py +++ b/InnerEye/ML/Histopathology/models/deepmil.py @@ -9,17 +9,22 @@ import numpy as np from typing import Any, Callable, Dict, Optional, Tuple, List import torch +import matplotlib.pyplot as plt +import more_itertools as mi from pytorch_lightning import LightningModule from torch import Tensor, argmax, mode, nn, no_grad, optim, round from torchmetrics import AUROC, F1, Accuracy, Precision, Recall from InnerEye.Common import fixed_paths -from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset +from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset, SlidesDataset from InnerEye.ML.Histopathology.models.encoders import TileEncoder -from InnerEye.ML.Histopathology.utils.metrics_utils import select_k_tiles, plot_slide_noxy, plot_scores_hist +from InnerEye.ML.Histopathology.utils.metrics_utils import select_k_tiles, plot_attention_tiles, plot_scores_hist, plot_heatmap_overlay, plot_slide from InnerEye.ML.Histopathology.utils.naming import ResultsKey +from InnerEye.ML.Histopathology.utils.viz_utils import load_image_dict +from InnerEye.ML.Histopathology.utils.naming import SlideKey + RESULTS_COLS = [ResultsKey.SLIDE_ID, ResultsKey.TILE_ID, ResultsKey.IMAGE_PATH, ResultsKey.PROB, ResultsKey.PRED_LABEL, ResultsKey.TRUE_LABEL, ResultsKey.BAG_ATTN] @@ -46,7 +51,9 @@ def __init__(self, weight_decay: float = 1e-4, adam_betas: Tuple[float, float] = (0.9, 0.99), verbose: bool = False, - ) -> None: + slide_dataset: SlidesDataset = None, + tile_size: int = 224, + level: int = 1) -> None: """ :param label_column: Label key for input batch dictionary. :param n_classes: Number of output classes for MIL prediction. @@ -61,6 +68,9 @@ def __init__(self, :param weight_decay: Weight decay parameter for L2 regularisation. :param adam_betas: Beta parameters for Adam optimiser. :param verbose: if True statements about memory usage are output at each step + :param slide_dataset: Slide dataset object, if available. + :param tile_size: The size of each tile (default=224). + :param level: The downsampling level (e.g. 0, 1, 2) of the tiles if available (default=1). """ super().__init__() @@ -79,7 +89,13 @@ def __init__(self, self.weight_decay = weight_decay self.adam_betas = adam_betas + # Slide specific attributes + self.slide_dataset = slide_dataset + self.tile_size = tile_size + self.level = level + self.save_hyperparameters() + self.verbose = verbose self.aggregation_fn, self.num_pooling = self.get_pooling() @@ -288,29 +304,43 @@ def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: # type: ignore print("Selecting tiles ...") fn_top_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('lowest_pred', 'highest_att')) fn_bottom_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('lowest_pred', 'lowest_att')) - tp_top_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('highes_pred', 'highest_att')) + tp_top_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('highest_pred', 'highest_att')) tp_bottom_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('highest_pred', 'lowest_att')) report_cases = {'TP': [tp_top_tiles, tp_bottom_tiles], 'FN': [fn_top_tiles, fn_bottom_tiles]} for key in report_cases.keys(): - print(f"Plotting {key} ...") + print(f"Plotting {key} (tiles, thumbnails, attention heatmaps)...") key_folder_path = outputs_fig_path / f'{key}' Path(key_folder_path).mkdir(parents=True, exist_ok=True) nslides = len(report_cases[key][0]) for i in range(nslides): slide, score, paths, top_attn = report_cases[key][0][i] - fig = plot_slide_noxy(slide, score, paths, top_attn, key + '_top', ncols=4) - figpath = Path(key_folder_path, f'{slide}_top.png') - fig.savefig(figpath, bbox_inches='tight') + fig = plot_attention_tiles(slide, score, paths, top_attn, key + '_top', ncols=4) + self.save_figure(fig=fig, figpath=Path(key_folder_path, f'{slide}_top.png')) slide, score, paths, bottom_attn = report_cases[key][1][i] - fig = plot_slide_noxy(slide, score, paths, bottom_attn, key + '_bottom', ncols=4) - figpath = Path(key_folder_path, f'{slide}_bottom.png') - fig.savefig(figpath, bbox_inches='tight') + fig = plot_attention_tiles(slide, score, paths, bottom_attn, key + '_bottom', ncols=4) + self.save_figure(fig=fig, figpath=Path(key_folder_path, f'{slide}_bottom.png')) + + if self.slide_dataset is not None: + slide_dict = mi.first_true(self.slide_dataset, pred=lambda entry: entry[SlideKey.SLIDE_ID] == slide) # type: ignore + _ = load_image_dict(slide_dict, level=self.level, margin=0) # type: ignore + slide_image = slide_dict[SlideKey.IMAGE] + location_bbox = slide_dict[SlideKey.LOCATION] + + fig = plot_slide(slide_image=slide_image, scale=1.0) + self.save_figure(fig=fig, figpath=Path(key_folder_path, f'{slide}_thumbnail.png')) + fig = plot_heatmap_overlay(slide=slide, slide_image=slide_image, results=results, + location_bbox=location_bbox, tile_size=self.tile_size, level=self.level) + self.save_figure(fig=fig, figpath=Path(key_folder_path, f'{slide}_heatmap.png')) print("Plotting histogram ...") fig = plot_scores_hist(results) - fig.savefig(outputs_fig_path / 'hist_scores.png', bbox_inches='tight') + self.save_figure(fig=fig, figpath=outputs_fig_path / 'hist_scores.png') + + @staticmethod + def save_figure(fig: plt.figure, figpath: Path) -> None: + fig.savefig(figpath, bbox_inches='tight') @staticmethod def normalize_dict_for_df(dict_old: Dict[str, Any], use_gpu: bool) -> Dict: diff --git a/InnerEye/ML/Histopathology/utils/heatmap_utils.py b/InnerEye/ML/Histopathology/utils/heatmap_utils.py new file mode 100644 index 000000000..bbd74f75d --- /dev/null +++ b/InnerEye/ML/Histopathology/utils/heatmap_utils.py @@ -0,0 +1,31 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + +from typing import List +import numpy as np + + +def location_selected_tiles(tile_coords: np.ndarray, + location_bbox: List[int], + level: int) -> np.ndarray: + """ Return the scaled and shifted tile co-ordinates for selected tiles in the slide. + :param tile_coords: XY tile coordinates, assumed to be spaced by multiples of `tile_size` (shape: [N, 2]) in original resolution. + :param location_bbox: Location of the bounding box on the slide in original resolution. + :param level: The downsampling level (e.g. 0, 1, 2) of the tiles if available. + (e.g. PANDA levels are 0 for original, 1 for 4x downsampled, 2 for 16x downsampled). + """ + level_dict = {0: 1, 1: 4, 2: 16} + factor = level_dict[level] + + x_tr, y_tr = location_bbox + tile_xs, tile_ys = tile_coords.T + tile_xs = tile_xs - x_tr + tile_ys = tile_ys - y_tr + tile_xs = tile_xs//factor + tile_ys = tile_ys//factor + + sel_coords = np.transpose([tile_xs.tolist(), tile_ys.tolist()]) + + return sel_coords diff --git a/InnerEye/ML/Histopathology/utils/metrics_utils.py b/InnerEye/ML/Histopathology/utils/metrics_utils.py index 834ac4182..b6f49f081 100644 --- a/InnerEye/ML/Histopathology/utils/metrics_utils.py +++ b/InnerEye/ML/Histopathology/utils/metrics_utils.py @@ -7,9 +7,13 @@ import torch import matplotlib.pyplot as plt from math import ceil +import numpy as np +import matplotlib.patches as patches +import matplotlib.collections as collection from InnerEye.ML.Histopathology.models.transforms import load_pil_image from InnerEye.ML.Histopathology.utils.naming import ResultsKey +from InnerEye.ML.Histopathology.utils.heatmap_utils import location_selected_tiles def select_k_tiles(results: Dict, n_tiles: int = 5, n_slides: int = 5, label: int = 1, @@ -75,7 +79,7 @@ def plot_scores_hist(results: Dict, prob_col: str = ResultsKey.PROB, return fig -def plot_slide_noxy(slide: str, score: float, paths: List, attn: List, case: str, ncols: int = 5, +def plot_attention_tiles(slide: str, score: float, paths: List, attn: List, case: str, ncols: int = 5, size: Tuple = (10, 10)) -> plt.figure: """ :param slide: slide identifier @@ -97,3 +101,66 @@ def plot_slide_noxy(slide: str, score: float, paths: List, attn: List, case: str for i in range(len(axs.ravel())): axs.ravel()[i].set_axis_off() return fig + + +def plot_slide(slide_image: np.ndarray, scale: float) -> plt.figure: + """Plots a slide thumbnail from a given slide image and scale. + :param slide_image: Numpy array of the slide image (shape: [3, H, W]). + :return: matplotlib figure of the slide thumbnail. + """ + fig, ax = plt.subplots() + slide_image = slide_image.transpose(1, 2, 0) + ax.imshow(slide_image) + ax.set_axis_off() + original_size = fig.get_size_inches() + fig.set_size_inches((original_size[0]*scale, original_size[1]*scale)) + return fig + + +def plot_heatmap_overlay(slide: str, + slide_image: np.ndarray, + results: Dict[str, List[Any]], + location_bbox: List[int], + tile_size: int = 224, + level: int = 1) -> plt.figure: + """Plots heatmap of selected tiles (e.g. tiles in a bag) overlay on the corresponding slide. + :param slide: slide identifier. + :param slide_image: Numpy array of the slide image (shape: [3, H, W]). + :param results: Dict containing ResultsKey keys (e.g. slide id) and values as lists of output slides. + :param tile_size: Size of each tile. Default 224. + :param level: Magnification at which tiles are available (e.g. PANDA levels are 0 for original, 1 for 4x downsampled, 2 for 16x downsampled). Default 1. + :param location_bbox: Location of the bounding box of the slide. + :return: matplotlib figure of the heatmap of the given tiles on slide. + """ + fig, ax = plt.subplots() + slide_image = slide_image.transpose(1, 2, 0) + ax.imshow(slide_image) + ax.set_xlim(0, slide_image.shape[1]) + ax.set_ylim(slide_image.shape[0], 0) + + coords = [] + slide_ids = [item[0] for item in results[ResultsKey.SLIDE_ID]] + slide_idx = slide_ids.index(slide) + attentions = results[ResultsKey.BAG_ATTN][slide_idx] + + # for each tile in the bag + for tile_idx in range(len(results[ResultsKey.IMAGE_PATH][slide_idx])): + tile_coords = np.transpose(np.array([results[ResultsKey.TILE_X][slide_idx][tile_idx].cpu().numpy(), + results[ResultsKey.TILE_Y][slide_idx][tile_idx].cpu().numpy()])) + coords.append(tile_coords) + + coords = np.array(coords) + attentions = np.array(attentions.cpu()).reshape(-1) + + sel_coords = location_selected_tiles(tile_coords=coords, location_bbox=location_bbox, level=level) + cmap = plt.cm.get_cmap('Reds') + + tile_xs, tile_ys = sel_coords.T + rects = [patches.Rectangle(xy, tile_size, tile_size) for xy in zip(tile_xs, tile_ys)] + + pc = collection.PatchCollection(rects, match_original=True, cmap=cmap, alpha=.5, edgecolor=None) + pc.set_array(np.array(attentions)) + pc.set_clim([0, 1]) + ax.add_collection(pc) + plt.colorbar(pc, ax=ax) + return fig diff --git a/InnerEye/ML/Histopathology/utils/naming.py b/InnerEye/ML/Histopathology/utils/naming.py index 9f1b237df..5b7273ad9 100644 --- a/InnerEye/ML/Histopathology/utils/naming.py +++ b/InnerEye/ML/Histopathology/utils/naming.py @@ -18,6 +18,7 @@ class SlideKey(str, Enum): ORIGIN = 'origin' FOREGROUND_THRESHOLD = 'foreground_threshold' METADATA = 'metadata' + LOCATION = 'location' class TileKey(str, Enum): diff --git a/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py b/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py index 155a09808..521711b17 100644 --- a/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py +++ b/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py @@ -16,6 +16,7 @@ from health_ml.networks.layers.attention_layers import AttentionLayer, GatedAttentionLayer from InnerEye.ML.lightning_container import LightningContainer +from InnerEye.ML.Histopathology.datasets.base_dataset import SlidesDataset from InnerEye.ML.Histopathology.datamodules.base_module import CacheMode, TilesDataModule from InnerEye.ML.Histopathology.models.deepmil import DeepMILModule from InnerEye.ML.Histopathology.models.encoders import (HistoSSLEncoder, IdentityEncoder, @@ -100,3 +101,6 @@ def create_model(self) -> DeepMILModule: def get_data_module(self) -> TilesDataModule: raise NotImplementedError + + def get_slide_dataset(self) -> SlidesDataset: + raise NotImplementedError diff --git a/InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py b/InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py index 5d75cb3d7..01384469b 100644 --- a/InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py +++ b/InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py @@ -14,11 +14,12 @@ damage response defect classification directly from H&E whole-slide images. arXiv:2107.09405 """ from pathlib import Path -from typing import Any, Dict +from typing import Any, List import os from monai.transforms import Compose from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +from pytorch_lightning.callbacks import Callback from health_ml.networks.layers.attention_layers import GatedAttentionLayer from health_azure.utils import get_workspace @@ -125,9 +126,8 @@ def get_data_module(self) -> TilesDataModule: cross_validation_split_index=self.cross_validation_split_index, ) - def get_trainer_arguments(self) -> Dict[str, Any]: - # These arguments will be passed through to the Lightning trainer. - return {"callbacks": self.callbacks} + def get_callbacks(self) -> List[Callback]: + return super().get_callbacks() + [self.callbacks] def get_path_to_best_checkpoint(self) -> Path: """ diff --git a/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py b/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py index 610b9aa61..59a617b74 100644 --- a/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py +++ b/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py @@ -3,14 +3,15 @@ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------ -from typing import Any, Dict +from typing import Any, List from pathlib import Path import os from monai.transforms import Compose from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +from pytorch_lightning.callbacks import Callback from health_azure.utils import CheckpointDownloader -from health_azure.utils import get_workspace +from health_azure.utils import get_workspace, is_running_in_azure_ml from health_ml.networks.layers.attention_layers import GatedAttentionLayer from InnerEye.Common import fixed_paths from InnerEye.ML.Histopathology.datamodules.panda_module import PandaTilesDataModule @@ -26,8 +27,11 @@ ImageNetEncoder, ImageNetSimCLREncoder, InnerEyeSSLEncoder, + IdentityEncoder ) from InnerEye.ML.configs.histo_configs.classification.BaseMIL import BaseMIL +from InnerEye.ML.Histopathology.datasets.panda_dataset import PandaDataset +from InnerEye.ML.Histopathology.models.deepmil import DeepMILModule class DeepSMILEPanda(BaseMIL): @@ -38,6 +42,8 @@ def __init__(self, **kwargs: Any) -> None: # declared in DatasetParams: local_dataset=Path("/tmp/datasets/PANDA_tiles"), azure_dataset_id="PANDA_tiles", + extra_azure_dataset_ids=["PANDA"], + extra_local_dataset_paths=[Path("/tmp/datasets/PANDA")], # To mount the dataset instead of downloading in AML, pass --use_dataset_mount in the CLI # declared in TrainerParams: num_epochs=200, @@ -48,11 +54,12 @@ def __init__(self, **kwargs: Any) -> None: # declared in OptimizerParams: l_rate=5e-4, weight_decay=1e-4, - adam_betas=(0.9, 0.99), - ) + adam_betas=(0.9, 0.99)) default_kwargs.update(kwargs) super().__init__(**default_kwargs) super().__init__(**default_kwargs) + if not is_running_in_azure_ml(): + self.num_epochs = 1 self.best_checkpoint_filename = "checkpoint_max_val_auroc" self.best_checkpoint_filename_with_suffix = ( self.best_checkpoint_filename + ".ckpt" @@ -109,9 +116,29 @@ def get_data_module(self) -> PandaTilesDataModule: cross_validation_split_index=self.cross_validation_split_index, ) - def get_trainer_arguments(self) -> Dict[str, Any]: - # These arguments will be passed through to the Lightning trainer. - return {"callbacks": self.callbacks} + def create_model(self) -> DeepMILModule: + self.data_module = self.get_data_module() + # Encoding is done in the datamodule, so here we provide instead a dummy + # no-op IdentityEncoder to be used inside the model + self.slide_dataset = self.get_slide_dataset() + self.level = 1 + return DeepMILModule(encoder=IdentityEncoder(input_dim=(self.encoder.num_encoding,)), + label_column=self.data_module.train_dataset.LABEL_COLUMN, + n_classes=self.data_module.train_dataset.N_CLASSES, + pooling_layer=self.get_pooling_layer(), + class_weights=self.data_module.class_weights, + l_rate=self.l_rate, + weight_decay=self.weight_decay, + adam_betas=self.adam_betas, + slide_dataset=self.get_slide_dataset(), + tile_size=self.tile_size, + level=self.level) + + def get_slide_dataset(self) -> PandaDataset: + return PandaDataset(root=self.extra_local_dataset_paths[0]) # type: ignore + + def get_callbacks(self) -> List[Callback]: + return super().get_callbacks() + [self.callbacks] def get_path_to_best_checkpoint(self) -> Path: """ @@ -135,7 +162,7 @@ def get_path_to_best_checkpoint(self) -> Path: if checkpoint_path.is_file(): return checkpoint_path - raise ValueError("Path to best checkpoint not found") + raise ValueError("Path to best checkpoint not found") class PandaImageNetMIL(DeepSMILEPanda): diff --git a/Tests/ML/histopathology/utils/test_metrics_utils.py b/Tests/ML/histopathology/utils/test_metrics_utils.py index a63884477..c41c70280 100644 --- a/Tests/ML/histopathology/utils/test_metrics_utils.py +++ b/Tests/ML/histopathology/utils/test_metrics_utils.py @@ -3,14 +3,24 @@ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------ +from pathlib import Path + import math +import numpy as np from typing import List import matplotlib from torch.functional import Tensor +import pytest -from InnerEye.ML.Histopathology.utils.metrics_utils import plot_scores_hist, select_k_tiles +from InnerEye.ML.Histopathology.utils.metrics_utils import plot_scores_hist, select_k_tiles, plot_slide, plot_heatmap_overlay from InnerEye.ML.Histopathology.utils.naming import ResultsKey +from InnerEye.ML.Histopathology.utils.heatmap_utils import location_selected_tiles +from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path +from InnerEye.Common.output_directories import OutputFolderForTests +from InnerEye.ML.plotting import resize_and_save +from InnerEye.ML.utils.ml_util import set_random_seed +from Tests.ML.util import assert_binary_files_match def assert_equal_lists(pred: List, expected: List) -> None: @@ -37,7 +47,17 @@ def assert_equal_lists(pred: List, expected: List) -> None: [Tensor([[0.1, 0.0, 0.2, 0.15]]), Tensor([[0.10, 0.18, 0.15, 0.13]]), Tensor([[0.25, 0.23, 0.20, 0.21]]), - Tensor([[0.33, 0.31, 0.37, 0.35]])] + Tensor([[0.33, 0.31, 0.37, 0.35]])], + ResultsKey.TILE_X: + [Tensor([200, 200, 424, 424]), + Tensor([200, 200, 424, 424]), + Tensor([200, 200, 424, 424]), + Tensor([200, 200, 424, 424])], + ResultsKey.TILE_Y: + [Tensor([200, 424, 200, 424]), + Tensor([200, 200, 424, 424]), + Tensor([200, 200, 424, 424]), + Tensor([200, 200, 424, 424])] } def test_select_k_tiles() -> None: @@ -57,6 +77,78 @@ def test_select_k_tiles() -> None: assert_equal_lists(bottom_tp, [(4, 1.0, [2, 1], [Tensor([0.31]), Tensor([0.33])]), (2, 0.7, [1, 4], [Tensor([0.10]), Tensor([0.13])])]) -def test_plot_scores_hist() -> None: +def test_plot_scores_hist(test_output_dirs: OutputFolderForTests) -> None: fig = plot_scores_hist(test_dict) assert isinstance(fig, matplotlib.figure.Figure) + file = Path(test_output_dirs.root_dir) / "plot_score_hist.png" + resize_and_save(5, 5, file) + assert file.exists() + expected = full_ml_test_data_path("histo_heatmaps") / "score_hist.png" + # To update the stored results, uncomment this line: + # expected.write_bytes(file.read_bytes()) + assert_binary_files_match(file, expected) + + +@pytest.mark.parametrize("scale", [0.1, 1.2, 2.4, 3.6]) +def test_plot_slide(test_output_dirs: OutputFolderForTests, scale: int) -> None: + set_random_seed(0) + slide_image = np.random.rand(3, 1000, 2000) + fig = plot_slide(slide_image=slide_image, scale=scale) + assert isinstance(fig, matplotlib.figure.Figure) + file = Path(test_output_dirs.root_dir) / "plot_slide.png" + resize_and_save(5, 5, file) + assert file.exists() + expected = full_ml_test_data_path("histo_heatmaps") / f"slide_{scale}.png" + # To update the stored results, uncomment this line: + # expected.write_bytes(file.read_bytes()) + assert_binary_files_match(file, expected) + + +def test_plot_heatmap_overlay(test_output_dirs: OutputFolderForTests) -> None: + set_random_seed(0) + slide_image = np.random.rand(3, 1000, 2000) + location_bbox = [100, 100] + slide = 1 + tile_size = 224 + level = 0 + fig = plot_heatmap_overlay(slide=slide, # type: ignore + slide_image=slide_image, + results=test_dict, # type: ignore + location_bbox=location_bbox, + tile_size=tile_size, + level=level) + assert isinstance(fig, matplotlib.figure.Figure) + file = Path(test_output_dirs.root_dir) / "plot_heatmap_overlay.png" + resize_and_save(5, 5, file) + assert file.exists() + expected = full_ml_test_data_path("histo_heatmaps") / "heatmap_overlay.png" + # To update the stored results, uncomment this line: + # expected.write_bytes(file.read_bytes()) + assert_binary_files_match(file, expected) + +@pytest.mark.parametrize("level", [0, 1, 2]) +def test_location_selected_tiles(level: int) -> None: + set_random_seed(0) + slide = 1 + location_bbox = [100, 100] + slide_image = np.random.rand(3, 1000, 2000) + + coords = [] + slide_ids = [item[0] for item in test_dict[ResultsKey.SLIDE_ID]] # type: ignore + slide_idx = slide_ids.index(slide) + for tile_idx in range(len(test_dict[ResultsKey.IMAGE_PATH][slide_idx])): # type: ignore + tile_coords = np.transpose(np.array([test_dict[ResultsKey.TILE_X][slide_idx][tile_idx].cpu().numpy(), # type: ignore + test_dict[ResultsKey.TILE_Y][slide_idx][tile_idx].cpu().numpy()])) # type: ignore + coords.append(tile_coords) + + coords = np.array(coords) + tile_coords_transformed = location_selected_tiles(tile_coords=coords, + location_bbox=location_bbox, + level=level) + tile_xs, tile_ys = tile_coords_transformed.T + level_dict = {0: 1, 1: 4, 2: 16} + factor = level_dict[level] + assert min(tile_xs) >= 0 + assert max(tile_xs) <= slide_image.shape[2]//factor + assert min(tile_ys) >= 0 + assert max(tile_ys) <= slide_image.shape[1]//factor diff --git a/Tests/ML/test_data/histo_heatmaps/heatmap_overlay.png b/Tests/ML/test_data/histo_heatmaps/heatmap_overlay.png new file mode 100644 index 000000000..64cf40656 --- /dev/null +++ b/Tests/ML/test_data/histo_heatmaps/heatmap_overlay.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3ab9e04b3b6ff098ff0269f8c3ffbf1cad45c8c6635c68e7732a59ea9568e82f +size 314086 diff --git a/Tests/ML/test_data/histo_heatmaps/score_hist.png b/Tests/ML/test_data/histo_heatmaps/score_hist.png new file mode 100644 index 000000000..bced47d8b --- /dev/null +++ b/Tests/ML/test_data/histo_heatmaps/score_hist.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ca95c0017d0a51d75d118e54f21c1e907b3d90dcca822b23622e369267907198 +size 17057 diff --git a/Tests/ML/test_data/histo_heatmaps/slide_0.1.png b/Tests/ML/test_data/histo_heatmaps/slide_0.1.png new file mode 100644 index 000000000..964f81df6 --- /dev/null +++ b/Tests/ML/test_data/histo_heatmaps/slide_0.1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e6867218d67f06f482f1ce4a5c102fee662e445b34c7d09121b2aba8380eb54a +size 474111 diff --git a/Tests/ML/test_data/histo_heatmaps/slide_1.2.png b/Tests/ML/test_data/histo_heatmaps/slide_1.2.png new file mode 100644 index 000000000..964f81df6 --- /dev/null +++ b/Tests/ML/test_data/histo_heatmaps/slide_1.2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e6867218d67f06f482f1ce4a5c102fee662e445b34c7d09121b2aba8380eb54a +size 474111 diff --git a/Tests/ML/test_data/histo_heatmaps/slide_2.4.png b/Tests/ML/test_data/histo_heatmaps/slide_2.4.png new file mode 100644 index 000000000..964f81df6 --- /dev/null +++ b/Tests/ML/test_data/histo_heatmaps/slide_2.4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e6867218d67f06f482f1ce4a5c102fee662e445b34c7d09121b2aba8380eb54a +size 474111 diff --git a/Tests/ML/test_data/histo_heatmaps/slide_3.6.png b/Tests/ML/test_data/histo_heatmaps/slide_3.6.png new file mode 100644 index 000000000..964f81df6 --- /dev/null +++ b/Tests/ML/test_data/histo_heatmaps/slide_3.6.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e6867218d67f06f482f1ce4a5c102fee662e445b34c7d09121b2aba8380eb54a +size 474111