From f654ce3307d38495e62ce526779aa0df612cbcee Mon Sep 17 00:00:00 2001 From: t-hsharma Date: Thu, 6 Jan 2022 16:20:55 +0000 Subject: [PATCH 01/23] heatmap and thumbnail output, deepmil module for panda --- InnerEye/ML/Histopathology/models/deepmil.py | 112 ++++++++++++++- .../ML/Histopathology/utils/heatmap_utils.py | 135 ++++++++++++++++++ .../ML/Histopathology/utils/metrics_utils.py | 55 +++++++ .../classification/DeepSMILECrck.py | 4 +- .../classification/DeepSMILEPanda.py | 27 +++- 5 files changed, 326 insertions(+), 7 deletions(-) create mode 100644 InnerEye/ML/Histopathology/utils/heatmap_utils.py diff --git a/InnerEye/ML/Histopathology/models/deepmil.py b/InnerEye/ML/Histopathology/models/deepmil.py index e42eacb84..097498f69 100644 --- a/InnerEye/ML/Histopathology/models/deepmil.py +++ b/InnerEye/ML/Histopathology/models/deepmil.py @@ -17,9 +17,14 @@ from InnerEye.Common import fixed_paths from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset from InnerEye.ML.Histopathology.models.encoders import TileEncoder -from InnerEye.ML.Histopathology.utils.metrics_utils import select_k_tiles, plot_slide_noxy, plot_scores_hist +from InnerEye.ML.Histopathology.utils.metrics_utils import select_k_tiles, plot_slide_noxy, plot_scores_hist, plot_heatmap_slide, plot_slide from InnerEye.ML.Histopathology.utils.naming import ResultsKey +from InnerEye.ML.Histopathology.datasets.panda_dataset import PandaDataset +from monai.data.dataset import Dataset +from InnerEye.ML.Histopathology.utils.viz_utils import load_image_dict +from InnerEye.ML.Histopathology.utils.naming import SlideKey + RESULTS_COLS = [ResultsKey.SLIDE_ID, ResultsKey.TILE_ID, ResultsKey.IMAGE_PATH, ResultsKey.PROB, ResultsKey.PRED_LABEL, ResultsKey.TRUE_LABEL, ResultsKey.BAG_ATTN] @@ -105,7 +110,7 @@ def get_classifier(self) -> Callable: def get_loss(self) -> Callable: if self.n_classes > 1: - return nn.CrossEntropyLoss(weight=self.class_weights) + return nn.CrossEntropyLoss(weight=self.class_weights.float()) # type: ignore else: pos_weight = None if self.class_weights is not None: @@ -282,7 +287,7 @@ def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: # type: ignore print("Selecting tiles ...") fn_top_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('lowest_pred', 'highest_att')) fn_bottom_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('lowest_pred', 'lowest_att')) - tp_top_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('highes_pred', 'highest_att')) + tp_top_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('highest_pred', 'highest_att')) tp_bottom_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('highest_pred', 'lowest_att')) report_cases = {'TP': [tp_top_tiles, tp_bottom_tiles], 'FN': [fn_top_tiles, fn_bottom_tiles]} @@ -331,3 +336,104 @@ def move_list_to_device(list_encoded_features: List, use_gpu: bool) -> List: feature = feature.squeeze(0).to(device) features_list.append(feature) return features_list + + +class DeepMILModule_Panda(DeepMILModule): + """ + Child class of `DeepMILModule` for deep multiple-instance learning on PANDA dataset + """ + def __init__(self, + panda_dir: str, + tile_size: int = 224, + level: int = 1, + **kwargs: Any) -> None: + self.panda_dir = panda_dir + self.tile_size = tile_size + self.level = level + super().__init__(**kwargs) + + def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: # type: ignore + # outputs object consists of a list 
of dictionaries (of metadata and results, including encoded features) + # It can be indexed as outputs[batch_idx][batch_key][bag_idx][tile_idx] + # example of batch_key ResultsKey.SLIDE_ID_COL + # for batch keys that contains multiple values for slides e.g. ResultsKey.BAG_ATTN_COL + # outputs[batch_idx][batch_key][bag_idx][tile_idx] + # contains the tile value + + # collate the batches + results: Dict[str, List[Any]] = {} + [results.update({col: []}) for col in outputs[0].keys()] + for key in results.keys(): + for batch_id in range(len(outputs)): + results[key] += outputs[batch_id][key] + + print("Saving outputs ...") + # collate at slide level + list_slide_dicts = [] + list_encoded_features = [] + # any column can be used here, the assumption is that the first dimension is the N of slides + for slide_idx in range(len(results[ResultsKey.SLIDE_ID])): + slide_dict = dict() + for key in results.keys(): + if key not in [ResultsKey.IMAGE, ResultsKey.LOSS]: + slide_dict[key] = results[key][slide_idx] + list_slide_dicts.append(slide_dict) + list_encoded_features.append(results[ResultsKey.IMAGE][slide_idx]) + + print(f"Metrics results will be output to {fixed_paths.repository_root_directory()}/outputs") + csv_filename = fixed_paths.repository_root_directory() / Path('outputs/test_output.csv') + encoded_features_filename = fixed_paths.repository_root_directory() / Path('outputs/test_encoded_features.pickle') + + # Collect the list of dictionaries in a list of pandas dataframe and save + df_list = [] + for slide_dict in list_slide_dicts: + slide_dict = self.normalize_dict_for_df(slide_dict, use_gpu=False) + df_list.append(pd.DataFrame.from_dict(slide_dict)) + df = pd.concat(df_list, ignore_index=True) + df.to_csv(csv_filename, mode='w', header=True) + + # Collect all features in a list and save + features_list = self.move_list_to_device(list_encoded_features, use_gpu=False) + torch.save(features_list, encoded_features_filename) + + panda_dataset = Dataset(PandaDataset(root=self.panda_dir)) + + print("Selecting tiles ...") + fn_top_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('lowest_pred', 'highest_att')) + fn_bottom_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('lowest_pred', 'lowest_att')) + tp_top_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('highest_pred', 'highest_att')) + tp_bottom_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('highest_pred', 'lowest_att')) + report_cases = {'TP': [tp_top_tiles, tp_bottom_tiles], 'FN': [fn_top_tiles, fn_bottom_tiles]} + + for key in report_cases.keys(): + print(f"Plotting {key} (tiles, thumbnails, attention heatmaps)...") + output_path = Path(fixed_paths.repository_root_directory(), f'outputs/fig/{key}/') + Path(output_path).mkdir(parents=True, exist_ok=True) + nslides = len(report_cases[key][0]) + for i in range(nslides): + slide, score, paths, top_attn = report_cases[key][0][i] + fig = plot_slide_noxy(slide, score, paths, top_attn, key + '_top', ncols=4) + figpath = Path(output_path, f'{slide}_top.png') + fig.savefig(figpath, bbox_inches='tight') + + slide, score, paths, bottom_attn = report_cases[key][1][i] + fig = plot_slide_noxy(slide, score, paths, bottom_attn, key + '_bottom', ncols=4) + figpath = Path(output_path, f'{slide}_bottom.png') + fig.savefig(figpath, bbox_inches='tight') + + slide_dict = list(filter(lambda entry: entry[SlideKey.SLIDE_ID] == slide, panda_dataset))[0] # type: ignore + load_image_dict(slide_dict, level=self.level, 
margin=0) + slide_image = slide_dict[SlideKey.IMAGE] + + fig = plot_slide(slide_image=slide_image, scale=1.0) + figpath = Path(output_path, f'{slide}_thumbnail.png') + fig.savefig(figpath, bbox_inches='tight') + + fig = plot_heatmap_slide(slide=slide, slide_image=slide_image, results=results) + figpath = Path(output_path, f'{slide}_heatmap.png') + fig.savefig(figpath, bbox_inches='tight') + + print("Plotting histogram ...") + fig = plot_scores_hist(results) + output_path = Path(fixed_paths.repository_root_directory(), 'outputs/fig/hist_scores.png') + fig.savefig(output_path, bbox_inches='tight') diff --git a/InnerEye/ML/Histopathology/utils/heatmap_utils.py b/InnerEye/ML/Histopathology/utils/heatmap_utils.py new file mode 100644 index 000000000..985ea175b --- /dev/null +++ b/InnerEye/ML/Histopathology/utils/heatmap_utils.py @@ -0,0 +1,135 @@ +import io +from typing import Any, Optional, Sequence, Tuple + +import matplotlib.pyplot as plt +import numpy as np +import PIL.Image +from matplotlib.axes import Axes +from matplotlib.figure import Figure +from matplotlib.image import AxesImage + + +def assemble_heatmap(tile_coords: np.ndarray, tile_values: np.ndarray, tile_size: int, level: int, + fill_value: float = np.nan, pad: int = 0) -> Tuple[np.ndarray, np.ndarray]: + """Assembles a 2D heatmap from sequences of tile coordinates and values. + + :param tile_coords: XY tile coordinates, assumed to be spaced by multiples of `tile_size` (shape: [N, 2]). + :param tile_values: Scalar values of the tiles (shape: [N]). + :param tile_size: Size of each tile; must be >0. + :param level: Magnification at which tiles are available (e.g. PANDA levels are 0 for original, 1 for 4x downsampled, 2 for 16x downsampled). + :param fill_value: Value to assign to empty elements (default: `NaN`). + :param pad: If positive, pad the heatmap by `pad` elements on all sides (default: no padding). + :return: A tuple containing: + - `heatmap`: The 2D heatmap with the smallest dimensions to contain all given tiles, with + optional padding. + - `origin`: The lowest XY coordinates in the space of `tile_coords`. If `pad > 0`, this is + offset to match the padded margin. + """ + if tile_coords.shape[0] != tile_values.shape[0]: + raise ValueError(f"Tile coordinates and values must have the same length, " + f"got {tile_coords.shape[0]} and {tile_values.shape[0]}") + + level_dict = {"0": 1, "1": 4, "2": 16} + factor = level_dict[str(level)] + tile_coords_scaled = tile_coords//factor + tile_xs, tile_ys = tile_coords_scaled.T + + tile_xs = tile_xs - tile_size # top-left corner from top-right corner + x_min, x_max = min(tile_xs), max(tile_xs) + y_min, y_max = min(tile_ys), max(tile_ys) + + n_tiles_x = (x_max - x_min) // tile_size + 1 + n_tiles_y = (y_max - y_min) // tile_size + 1 + heatmap = np.full((n_tiles_y, n_tiles_x), fill_value) + + tile_js = (tile_xs - x_min) // tile_size + tile_is = (tile_ys - y_min) // tile_size + heatmap[tile_is, tile_js] = tile_values + origin = np.array([x_min, y_min]) + + if pad > 0: + heatmap = np.pad(heatmap, pad, mode='constant', constant_values=fill_value) + origin -= tile_size * pad # offset the origin to match the padded margin + + return heatmap, origin + + +def plot_heatmap(heatmap: np.ndarray, tile_size: int, origin: Sequence[int], ax: Optional[Axes] = None, **imshow_kwargs: Any) -> AxesImage: + """Plot a 2D heatmap to overlay on the slide. + + :param heatmap: The 2D scalar heatmap. + :param tile_size: Size of each tile. + :param origin: XY coordinates of the heatmap's top-left corner. 
+ :param ax: Axes onto which to plot the heatmap (default: current axes). + :param imshow_kwargs: Kwargs for `plt.imshow()` (e.g. `alpha`, `cmap`, `interpolation`). + :return: The output of `plt.imshow()` to allow e.g. plotting a colorbar. + """ + if ax is None: + ax = plt.gca() + heatmap_width = tile_size * heatmap.shape[1] + heatmap_height = tile_size * heatmap.shape[0] + offset = tile_size * 0.5 + extent = ( + origin[0] - offset, # left + origin[0] + heatmap_width - offset, # right + origin[1] + heatmap_height - offset, # bottom + origin[1] - offset # top + ) + h = ax.imshow(heatmap, extent=extent, **imshow_kwargs) + cb = plt.colorbar(h, ax=ax) + cb.set_alpha(1) + cb.draw_all() + return ax + + +def figure_to_image(fig: Optional[Figure] = None, **savefig_kwargs: Any) -> PIL.Image: + """Converts a Matplotlib figure into an image for logging and display. + + :param fig: Input Matplotlib figure object (default: current figure). + :param savefig_kwargs: Kwargs for `fig.savefig()` (e.g. `dpi`, `bbox_inches`). + :return: Rasterised PIL image in RGBA format. + """ + if fig is None: + fig = plt.gcf() + buffer = io.BytesIO() + fig.savefig(buffer, format='png', **savefig_kwargs) + buffer.seek(0) + image = PIL.Image.open(buffer).convert('RGBA') + buffer.close() + return image + + +def plot_slide_from_tiles(tiles: np.array, + tile_coords: np.array, + level: int, tile_size: int, + width: int, height: int, + ax: Optional[Axes] = None, + **imshow_kwargs: Any) -> AxesImage: + """Reconstructs a slide given the tiles at a certain magnification. + + :param tiles: Tiles of the slide (shape: [N, H, W, C]). + :param tile_coords: Top-right coordinates of tiles (shape: [N, 2]). + :param level: Magnification at which tiles are available (e.g. PANDA levels are 0 for original, 1 for 4x downsampled, 2 for 16x downsampled). + :param tile_size: Size of each tile. + :param width: Width of slide. + :param height: Height of slide. + :param ax: Axes onto which to plot the heatmap (default: current axes). 
+ """ + level_dict = {"0": 1, "1": 4, "2": 16} + factor = level_dict[str(level)] + tile_coords_scaled = tile_coords//factor + xs, ys = tile_coords_scaled.T + x_min = min(xs) - tile_size # top-left corner from top-right corner + y_min = min(ys) + offset = tile_size * 0.5 + if ax is None: + ax = plt.gca() + for i in range(tiles.shape[0]): + x, y = tile_coords_scaled[i] + x = x - tile_size # top-left corner from top-right corner + x = x - x_min + y = y - y_min + ax.imshow(tiles[i], extent=(x-offset, x+tile_size-offset, y+tile_size-offset, y-offset), **imshow_kwargs) + ax.set_xlim(0, width) + ax.set_ylim(height, 0) + return ax diff --git a/InnerEye/ML/Histopathology/utils/metrics_utils.py b/InnerEye/ML/Histopathology/utils/metrics_utils.py index 834ac4182..049b4bcaa 100644 --- a/InnerEye/ML/Histopathology/utils/metrics_utils.py +++ b/InnerEye/ML/Histopathology/utils/metrics_utils.py @@ -7,9 +7,11 @@ import torch import matplotlib.pyplot as plt from math import ceil +import numpy as np from InnerEye.ML.Histopathology.models.transforms import load_pil_image from InnerEye.ML.Histopathology.utils.naming import ResultsKey +from InnerEye.ML.Histopathology.utils.heatmap_utils import plot_heatmap, assemble_heatmap def select_k_tiles(results: Dict, n_tiles: int = 5, n_slides: int = 5, label: int = 1, @@ -97,3 +99,56 @@ def plot_slide_noxy(slide: str, score: float, paths: List, attn: List, case: str for i in range(len(axs.ravel())): axs.ravel()[i].set_axis_off() return fig + + +def plot_slide(slide_image: np.array, scale: float) -> plt.figure: + """ + :param slide_image: Numpy array of the slide image (shape: [3, H, W]). + :return: matplotlib figure of the slide thumbnail. + """ + fig, ax = plt.subplots() + slide_image = slide_image.transpose(1, 2, 0) + ax.imshow(slide_image) + ax.set_axis_off() + original_size = fig.get_size_inches() + fig.set_size_inches((original_size[0]*scale, original_size[1]*scale)) + return fig + + +def plot_heatmap_slide(slide: str, slide_image: np.array, results: Dict, tile_size: int = 224, level: int = 1, origin: List = [0, 0]) -> plt.figure: + """ + :param slide: slide identifier. + :param slide_image: Numpy array of the slide image (shape: [3, H, W]). + :param results: List that contains slide_level dicts. + :param tile_size: Size of each tile. Default 224. + :param level: Magnification at which tiles are available (e.g. PANDA levels are 0 for original, 1 for 4x downsampled, 2 for 16x downsampled). + :param origin: XY coordinates of the heatmap's top-left corner. Default [0, 0]. + :return: matplotlib figure of the heatmap of the given tiles on slide. 
+ """ + fig, ax = plt.subplots() + slide_image = slide_image.transpose(1, 2, 0) + tiles = [] + coords = [] + + slide_ids = [item[0] for item in results[ResultsKey.SLIDE_ID]] + slide_idx = slide_ids.index(slide) + attentions = results[ResultsKey.BAG_ATTN][slide_idx] + + # for each tile in the bag + for tile_idx in range(len(results[ResultsKey.IMAGE_PATH][slide_idx])): + image_path = results[ResultsKey.IMAGE_PATH][slide_idx][tile_idx] + tile_coords = np.transpose(np.array([results[ResultsKey.TILE_X][slide_idx][tile_idx].cpu().numpy(), + results[ResultsKey.TILE_Y][slide_idx][tile_idx].cpu().numpy()])) + tiles.append(np.array(load_pil_image(image_path))) + coords.append(tile_coords) + + tiles = np.array(tiles) + coords = np.array(coords) + attentions = np.array(attentions.cpu()).reshape(-1) + ax.imshow(slide_image) + ax.set_xlim(0, slide_image.shape[1]) + ax.set_ylim(slide_image.shape[0], 0) + heatmap, _ = assemble_heatmap(tile_coords=coords, tile_values=attentions, tile_size=tile_size, level=level) + plot_heatmap(100*heatmap, tile_size, origin=origin, alpha=.5, clim=(0, 100)) + ax.set_title("Attention (%)") + return fig diff --git a/InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py b/InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py index df2394f68..46cffae4c 100644 --- a/InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py +++ b/InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py @@ -109,7 +109,9 @@ def get_data_module(self) -> TilesDataModule: def get_trainer_arguments(self) -> Dict[str, Any]: # These arguments will be passed through to the Lightning trainer. - return {"callbacks": self.callbacks} + kw_args = super().get_trainer_arguments() + kw_args["callbacks"] = self.callbacks + return kw_args def get_path_to_best_checkpoint(self) -> Path: """ diff --git a/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py b/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py index 17f61921f..026b42414 100644 --- a/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py +++ b/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py @@ -15,6 +15,7 @@ from InnerEye.Common import fixed_paths from InnerEye.ML.Histopathology.datamodules.panda_module import PandaTilesDataModule from InnerEye.ML.Histopathology.datasets.panda_tiles_dataset import PandaTilesDataset +from InnerEye.ML.Histopathology.models.deepmil import DeepMILModule_Panda from InnerEye.ML.Histopathology.models.transforms import ( EncodeTilesBatchd, @@ -28,7 +29,7 @@ ) from InnerEye.ML.configs.histo_configs.classification.BaseMIL import BaseMIL from InnerEye.ML.configs.histo_configs.run_ids import innereye_ssl_checkpoint - +from InnerEye.ML.Histopathology.models.encoders import IdentityEncoder class DeepSMILEPanda(BaseMIL): def __init__(self, **kwargs: Any) -> None: @@ -40,7 +41,7 @@ def __init__(self, **kwargs: Any) -> None: azure_dataset_id="PANDA_tiles", # To mount the dataset instead of downloading in AML, pass --use_dataset_mount in the CLI # declared in TrainerParams: - num_epochs=200, + num_epochs=2, recovery_checkpoint_save_interval=10, recovery_checkpoints_save_last_k=-1, # declared in WorkflowParams: @@ -67,6 +68,8 @@ def __init__(self, **kwargs: Any) -> None: mode="max", ) self.callbacks = best_checkpoint_callback + self.slide_dir = "tmp/datasets/PANDA" + self.magnification_level = 1 @property def cache_dir(self) -> Path: @@ -108,9 +111,27 @@ def get_data_module(self) -> PandaTilesDataModule: 
cross_validation_split_index=self.cross_validation_split_index, ) + def create_model(self) -> DeepMILModule_Panda: + self.data_module = self.get_data_module() + # Encoding is done in the datamodule, so here we provide instead a dummy + # no-op IdentityEncoder to be used inside the model + return DeepMILModule_Panda(encoder=IdentityEncoder(input_dim=(self.encoder.num_encoding,)), + label_column=self.data_module.train_dataset.LABEL_COLUMN, + n_classes=self.data_module.train_dataset.N_CLASSES, + pooling_layer=self.get_pooling_layer(), + class_weights=self.data_module.class_weights, + l_rate=self.l_rate, + weight_decay=self.weight_decay, + adam_betas=self.adam_betas, + panda_dir=self.slide_dir, + tile_size=self.tile_size, + level=self.magnification_level) + def get_trainer_arguments(self) -> Dict[str, Any]: # These arguments will be passed through to the Lightning trainer. - return {"callbacks": self.callbacks} + kw_args = super().get_trainer_arguments() + kw_args["callbacks"] = self.callbacks + return kw_args def get_path_to_best_checkpoint(self) -> Path: """ From 3de38e253b60fcfe26a68881966cf92251acfb6f Mon Sep 17 00:00:00 2001 From: t-hsharma Date: Fri, 7 Jan 2022 16:39:32 +0000 Subject: [PATCH 02/23] deepmil module panda subclass --- InnerEye/ML/Histopathology/models/deepmil.py | 29 +++++++++++++---- .../ML/Histopathology/utils/metrics_utils.py | 2 ++ .../classification/DeepSMILEPanda.py | 31 +++++++++---------- 3 files changed, 40 insertions(+), 22 deletions(-) diff --git a/InnerEye/ML/Histopathology/models/deepmil.py b/InnerEye/ML/Histopathology/models/deepmil.py index 097498f69..66487d2ef 100644 --- a/InnerEye/ML/Histopathology/models/deepmil.py +++ b/InnerEye/ML/Histopathology/models/deepmil.py @@ -21,6 +21,7 @@ from InnerEye.ML.Histopathology.utils.naming import ResultsKey from InnerEye.ML.Histopathology.datasets.panda_dataset import PandaDataset +from InnerEye.ML.Histopathology.datasets.panda_tiles_dataset import PandaTilesDataset from monai.data.dataset import Dataset from InnerEye.ML.Histopathology.utils.viz_utils import load_image_dict from InnerEye.ML.Histopathology.utils.naming import SlideKey @@ -284,6 +285,11 @@ def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: # type: ignore features_list = self.move_list_to_device(list_encoded_features, use_gpu=False) torch.save(features_list, encoded_features_filename) + panda_dir = "/tmp/datasets/PANDA" + panda_tiles_dir = "/tmp/datasets/PANDA_tiles" + panda_dataset = Dataset(PandaDataset(root=panda_dir)) + panda_tiles_dataset = PandaTilesDataset(root=panda_tiles_dir) + print("Selecting tiles ...") fn_top_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('lowest_pred', 'highest_att')) fn_bottom_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('lowest_pred', 'lowest_att')) @@ -292,7 +298,7 @@ def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: # type: ignore report_cases = {'TP': [tp_top_tiles, tp_bottom_tiles], 'FN': [fn_top_tiles, fn_bottom_tiles]} for key in report_cases.keys(): - print(f"Plotting {key} ...") + print(f"Plotting {key} (tiles, thumbnails, attention heatmaps)...") output_path = Path(fixed_paths.repository_root_directory(), f'outputs/fig/{key}/') Path(output_path).mkdir(parents=True, exist_ok=True) nslides = len(report_cases[key][0]) @@ -307,6 +313,18 @@ def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: # type: ignore figpath = Path(output_path, f'{slide}_bottom.png') fig.savefig(figpath, bbox_inches='tight') + slide_dict = 
list(filter(lambda entry: entry[SlideKey.SLIDE_ID] == slide, panda_dataset))[0] # type: ignore + load_image_dict(slide_dict, level=slide_dict['level'], margin=0) + slide_image = slide_dict[SlideKey.IMAGE] + + fig = plot_slide(slide_image=slide_image, scale=1.0) + figpath = Path(output_path, f'{slide}_thumbnail.png') + fig.savefig(figpath, bbox_inches='tight') + + fig = plot_heatmap_slide(slide=slide, slide_image=slide_image, results=results) + figpath = Path(output_path, f'{slide}_heatmap.png') + fig.savefig(figpath, bbox_inches='tight') + print("Plotting histogram ...") fig = plot_scores_hist(results) output_path = Path(fixed_paths.repository_root_directory(), 'outputs/fig/hist_scores.png') @@ -345,12 +363,11 @@ class DeepMILModule_Panda(DeepMILModule): def __init__(self, panda_dir: str, tile_size: int = 224, - level: int = 1, - **kwargs: Any) -> None: + **kwargs: Any) -> None: + super().__init__(**kwargs) self.panda_dir = panda_dir + print("in child class") self.tile_size = tile_size - self.level = level - super().__init__(**kwargs) def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: # type: ignore # outputs object consists of a list of dictionaries (of metadata and results, including encoded features) @@ -422,7 +439,7 @@ def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: # type: ignore fig.savefig(figpath, bbox_inches='tight') slide_dict = list(filter(lambda entry: entry[SlideKey.SLIDE_ID] == slide, panda_dataset))[0] # type: ignore - load_image_dict(slide_dict, level=self.level, margin=0) + load_image_dict(slide_dict, level=slide_dict['level'], margin=0) slide_image = slide_dict[SlideKey.IMAGE] fig = plot_slide(slide_image=slide_image, scale=1.0) diff --git a/InnerEye/ML/Histopathology/utils/metrics_utils.py b/InnerEye/ML/Histopathology/utils/metrics_utils.py index 049b4bcaa..6a4ca74ab 100644 --- a/InnerEye/ML/Histopathology/utils/metrics_utils.py +++ b/InnerEye/ML/Histopathology/utils/metrics_utils.py @@ -142,6 +142,8 @@ def plot_heatmap_slide(slide: str, slide_image: np.array, results: Dict, tile_si tiles.append(np.array(load_pil_image(image_path))) coords.append(tile_coords) + #for all tiles in the slide + tiles = np.array(tiles) coords = np.array(coords) attentions = np.array(attentions.cpu()).reshape(-1) diff --git a/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py b/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py index 026b42414..d72d035d1 100644 --- a/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py +++ b/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py @@ -68,7 +68,7 @@ def __init__(self, **kwargs: Any) -> None: mode="max", ) self.callbacks = best_checkpoint_callback - self.slide_dir = "tmp/datasets/PANDA" + self.slide_dir = "/tmp/datasets/PANDA" self.magnification_level = 1 @property @@ -111,21 +111,20 @@ def get_data_module(self) -> PandaTilesDataModule: cross_validation_split_index=self.cross_validation_split_index, ) - def create_model(self) -> DeepMILModule_Panda: - self.data_module = self.get_data_module() - # Encoding is done in the datamodule, so here we provide instead a dummy - # no-op IdentityEncoder to be used inside the model - return DeepMILModule_Panda(encoder=IdentityEncoder(input_dim=(self.encoder.num_encoding,)), - label_column=self.data_module.train_dataset.LABEL_COLUMN, - n_classes=self.data_module.train_dataset.N_CLASSES, - pooling_layer=self.get_pooling_layer(), - class_weights=self.data_module.class_weights, - l_rate=self.l_rate, - 
weight_decay=self.weight_decay, - adam_betas=self.adam_betas, - panda_dir=self.slide_dir, - tile_size=self.tile_size, - level=self.magnification_level) + # def create_model(self) -> DeepMILModule_Panda: + # self.data_module = self.get_data_module() + # # Encoding is done in the datamodule, so here we provide instead a dummy + # # no-op IdentityEncoder to be used inside the model + # return DeepMILModule_Panda(panda_dir=self.slide_dir, + # tile_size=self.tile_size, + # encoder=IdentityEncoder(input_dim=(self.encoder.num_encoding,)), + # label_column=self.data_module.train_dataset.LABEL_COLUMN, + # n_classes=self.data_module.train_dataset.N_CLASSES, + # pooling_layer=self.get_pooling_layer(), + # class_weights=self.data_module.class_weights, + # l_rate=self.l_rate, + # weight_decay=self.weight_decay, + # adam_betas=self.adam_betas) def get_trainer_arguments(self) -> Dict[str, Any]: # These arguments will be passed through to the Lightning trainer. From 5647e71b5e618b65c70ef886f191ce2216411125 Mon Sep 17 00:00:00 2001 From: t-hsharma Date: Mon, 10 Jan 2022 11:14:46 +0000 Subject: [PATCH 03/23] heatmap of selected tiles with correct location --- InnerEye/ML/Histopathology/models/deepmil.py | 11 +++-- .../ML/Histopathology/utils/heatmap_utils.py | 46 ++++++++++++++++++- .../ML/Histopathology/utils/metrics_utils.py | 23 ++++------ 3 files changed, 59 insertions(+), 21 deletions(-) diff --git a/InnerEye/ML/Histopathology/models/deepmil.py b/InnerEye/ML/Histopathology/models/deepmil.py index 66487d2ef..f702e7b34 100644 --- a/InnerEye/ML/Histopathology/models/deepmil.py +++ b/InnerEye/ML/Histopathology/models/deepmil.py @@ -21,7 +21,6 @@ from InnerEye.ML.Histopathology.utils.naming import ResultsKey from InnerEye.ML.Histopathology.datasets.panda_dataset import PandaDataset -from InnerEye.ML.Histopathology.datasets.panda_tiles_dataset import PandaTilesDataset from monai.data.dataset import Dataset from InnerEye.ML.Histopathology.utils.viz_utils import load_image_dict from InnerEye.ML.Histopathology.utils.naming import SlideKey @@ -286,9 +285,9 @@ def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: # type: ignore torch.save(features_list, encoded_features_filename) panda_dir = "/tmp/datasets/PANDA" - panda_tiles_dir = "/tmp/datasets/PANDA_tiles" + tile_size = 224 + level = 1 panda_dataset = Dataset(PandaDataset(root=panda_dir)) - panda_tiles_dataset = PandaTilesDataset(root=panda_tiles_dir) print("Selecting tiles ...") fn_top_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('lowest_pred', 'highest_att')) @@ -314,14 +313,16 @@ def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: # type: ignore fig.savefig(figpath, bbox_inches='tight') slide_dict = list(filter(lambda entry: entry[SlideKey.SLIDE_ID] == slide, panda_dataset))[0] # type: ignore - load_image_dict(slide_dict, level=slide_dict['level'], margin=0) + load_image_dict(slide_dict, level=level, margin=0) slide_image = slide_dict[SlideKey.IMAGE] fig = plot_slide(slide_image=slide_image, scale=1.0) figpath = Path(output_path, f'{slide}_thumbnail.png') fig.savefig(figpath, bbox_inches='tight') - fig = plot_heatmap_slide(slide=slide, slide_image=slide_image, results=results) + location_bbox = slide_dict['location'] + fig = plot_heatmap_slide(slide=slide, slide_image=slide_image, results=results, + location_bbox=location_bbox, tile_size=tile_size, level=slide_dict['level']) figpath = Path(output_path, f'{slide}_heatmap.png') fig.savefig(figpath, bbox_inches='tight') diff --git 
a/InnerEye/ML/Histopathology/utils/heatmap_utils.py b/InnerEye/ML/Histopathology/utils/heatmap_utils.py index 985ea175b..4d6782b35 100644 --- a/InnerEye/ML/Histopathology/utils/heatmap_utils.py +++ b/InnerEye/ML/Histopathology/utils/heatmap_utils.py @@ -1,5 +1,5 @@ import io -from typing import Any, Optional, Sequence, Tuple +from typing import Any, Optional, Sequence, Tuple, List import matplotlib.pyplot as plt import numpy as np @@ -7,6 +7,8 @@ from matplotlib.axes import Axes from matplotlib.figure import Figure from matplotlib.image import AxesImage +import matplotlib.patches as patches +import matplotlib.collections as collection def assemble_heatmap(tile_coords: np.ndarray, tile_values: np.ndarray, tile_size: int, level: int, @@ -133,3 +135,45 @@ def plot_slide_from_tiles(tiles: np.array, ax.set_xlim(0, width) ax.set_ylim(height, 0) return ax + + +def plot_heatmap_selected_tiles(tile_coords: np.array, + tile_values: np.ndarray, + location_bbox: List[int], + tile_size: int, + level: int, + ax: Optional[Axes] = None) -> AxesImage: + """Plots a 2D heatmap for selected tiles to overlay on the slide. + :param tile_coords: XY tile coordinates, assumed to be spaced by multiples of `tile_size` (shape: [N, 2]). + :param tile_values: Scalar values of the tiles (shape: [N]). + :param location_bbox: Location of the bounding box of the slide. + :param level: Magnification at which tiles are available (e.g. PANDA levels are 0 for original, 1 for 4x downsampled, 2 for 16x downsampled). + :param tile_size: Size of each tile. + :param ax: Axes onto which to plot the heatmap (default: current axes). + """ + if ax is None: + ax = plt.gca() + + level_dict = {"0": 1, "1": 4, "2": 16} + factor = level_dict[str(level)] + x_tr, y_tr = location_bbox + x_min = x_tr//factor + y_min = y_tr//factor + + tile_coords_scaled = tile_coords//factor + tile_xs, tile_ys = tile_coords_scaled.T + tile_xs = tile_xs - x_min + tile_ys = tile_ys - y_min + cmap = plt.cm.get_cmap('jet') + sel_coords = np.transpose([tile_xs.tolist(), tile_ys.tolist()]) + rects = [] + for i in range(sel_coords.shape[0]): + rect = patches.Rectangle((sel_coords[i][0], sel_coords[i][1]), tile_size, tile_size) + rects.append(rect) + + pc = collection.PatchCollection(rects, match_original=True, cmap=cmap, alpha=.5, edgecolor=None) + pc.set_array(np.array(tile_values)) + pc.set_clim([0, 1]) + ax.add_collection(pc) + plt.colorbar(pc, ax=ax) + return ax diff --git a/InnerEye/ML/Histopathology/utils/metrics_utils.py b/InnerEye/ML/Histopathology/utils/metrics_utils.py index 6a4ca74ab..709560a9c 100644 --- a/InnerEye/ML/Histopathology/utils/metrics_utils.py +++ b/InnerEye/ML/Histopathology/utils/metrics_utils.py @@ -11,7 +11,7 @@ from InnerEye.ML.Histopathology.models.transforms import load_pil_image from InnerEye.ML.Histopathology.utils.naming import ResultsKey -from InnerEye.ML.Histopathology.utils.heatmap_utils import plot_heatmap, assemble_heatmap +from InnerEye.ML.Histopathology.utils.heatmap_utils import plot_heatmap_selected_tiles def select_k_tiles(results: Dict, n_tiles: int = 5, n_slides: int = 5, label: int = 1, @@ -102,7 +102,7 @@ def plot_slide_noxy(slide: str, score: float, paths: List, attn: List, case: str def plot_slide(slide_image: np.array, scale: float) -> plt.figure: - """ + """Plots a slide thumbnail from a given slide image and scale. :param slide_image: Numpy array of the slide image (shape: [3, H, W]). :return: matplotlib figure of the slide thumbnail. 
""" @@ -115,19 +115,18 @@ def plot_slide(slide_image: np.array, scale: float) -> plt.figure: return fig -def plot_heatmap_slide(slide: str, slide_image: np.array, results: Dict, tile_size: int = 224, level: int = 1, origin: List = [0, 0]) -> plt.figure: - """ +def plot_heatmap_slide(slide: str, slide_image: np.array, results: Dict, location_bbox: List[int], tile_size: int = 224, level: int = 1) -> plt.figure: + """Plots heatmap of selected tiles overlay on a slide. :param slide: slide identifier. :param slide_image: Numpy array of the slide image (shape: [3, H, W]). :param results: List that contains slide_level dicts. :param tile_size: Size of each tile. Default 224. - :param level: Magnification at which tiles are available (e.g. PANDA levels are 0 for original, 1 for 4x downsampled, 2 for 16x downsampled). - :param origin: XY coordinates of the heatmap's top-left corner. Default [0, 0]. + :param level: Magnification at which tiles are available (e.g. PANDA levels are 0 for original, 1 for 4x downsampled, 2 for 16x downsampled). Default 1. + :param location_bbox: Location of the bounding box of the slide. :return: matplotlib figure of the heatmap of the given tiles on slide. """ fig, ax = plt.subplots() slide_image = slide_image.transpose(1, 2, 0) - tiles = [] coords = [] slide_ids = [item[0] for item in results[ResultsKey.SLIDE_ID]] @@ -136,21 +135,15 @@ def plot_heatmap_slide(slide: str, slide_image: np.array, results: Dict, tile_si # for each tile in the bag for tile_idx in range(len(results[ResultsKey.IMAGE_PATH][slide_idx])): - image_path = results[ResultsKey.IMAGE_PATH][slide_idx][tile_idx] tile_coords = np.transpose(np.array([results[ResultsKey.TILE_X][slide_idx][tile_idx].cpu().numpy(), results[ResultsKey.TILE_Y][slide_idx][tile_idx].cpu().numpy()])) - tiles.append(np.array(load_pil_image(image_path))) coords.append(tile_coords) - #for all tiles in the slide - - tiles = np.array(tiles) coords = np.array(coords) attentions = np.array(attentions.cpu()).reshape(-1) ax.imshow(slide_image) ax.set_xlim(0, slide_image.shape[1]) ax.set_ylim(slide_image.shape[0], 0) - heatmap, _ = assemble_heatmap(tile_coords=coords, tile_values=attentions, tile_size=tile_size, level=level) - plot_heatmap(100*heatmap, tile_size, origin=origin, alpha=.5, clim=(0, 100)) - ax.set_title("Attention (%)") + + plot_heatmap_selected_tiles(tile_coords=coords, tile_values=attentions, location_bbox=location_bbox, tile_size=tile_size, level=level) return fig From 3e2748ab887a24ab5af3f3dd6daefa1eadf3ae07 Mon Sep 17 00:00:00 2001 From: t-hsharma Date: Tue, 11 Jan 2022 10:11:54 +0000 Subject: [PATCH 04/23] scaled and shifted rectangle coordinates --- InnerEye/ML/Histopathology/utils/heatmap_utils.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/InnerEye/ML/Histopathology/utils/heatmap_utils.py b/InnerEye/ML/Histopathology/utils/heatmap_utils.py index 4d6782b35..03a457ce9 100644 --- a/InnerEye/ML/Histopathology/utils/heatmap_utils.py +++ b/InnerEye/ML/Histopathology/utils/heatmap_utils.py @@ -157,13 +157,12 @@ def plot_heatmap_selected_tiles(tile_coords: np.array, level_dict = {"0": 1, "1": 4, "2": 16} factor = level_dict[str(level)] x_tr, y_tr = location_bbox - x_min = x_tr//factor - y_min = y_tr//factor + tile_xs, tile_ys = tile_coords.T + tile_xs = tile_xs - x_tr + tile_ys = tile_ys - y_tr + tile_xs = tile_xs//factor + tile_ys = tile_ys//factor - tile_coords_scaled = tile_coords//factor - tile_xs, tile_ys = tile_coords_scaled.T - tile_xs = tile_xs - x_min - tile_ys = tile_ys - y_min 
cmap = plt.cm.get_cmap('jet') sel_coords = np.transpose([tile_xs.tolist(), tile_ys.tolist()]) rects = [] From e4f614fc9807fc8ab01ac2dad5ad27f7cf43ed01 Mon Sep 17 00:00:00 2001 From: t-hsharma Date: Tue, 11 Jan 2022 15:52:14 +0000 Subject: [PATCH 05/23] slide dataset at container level --- InnerEye/ML/Histopathology/models/deepmil.py | 145 +++--------------- .../classification/DeepSMILEPanda.py | 44 ++++-- 2 files changed, 52 insertions(+), 137 deletions(-) diff --git a/InnerEye/ML/Histopathology/models/deepmil.py b/InnerEye/ML/Histopathology/models/deepmil.py index f702e7b34..a7cc0daed 100644 --- a/InnerEye/ML/Histopathology/models/deepmil.py +++ b/InnerEye/ML/Histopathology/models/deepmil.py @@ -20,7 +20,6 @@ from InnerEye.ML.Histopathology.utils.metrics_utils import select_k_tiles, plot_slide_noxy, plot_scores_hist, plot_heatmap_slide, plot_slide from InnerEye.ML.Histopathology.utils.naming import ResultsKey -from InnerEye.ML.Histopathology.datasets.panda_dataset import PandaDataset from monai.data.dataset import Dataset from InnerEye.ML.Histopathology.utils.viz_utils import load_image_dict from InnerEye.ML.Histopathology.utils.naming import SlideKey @@ -51,6 +50,9 @@ def __init__(self, weight_decay: float = 1e-4, adam_betas: Tuple[float, float] = (0.9, 0.99), verbose: bool = False, + slide_dataset: Dataset = None, + tile_size: int = 224, + level: int = 1, ) -> None: """ :param label_column: Label key for input batch dictionary. @@ -66,6 +68,9 @@ def __init__(self, :param weight_decay: Weight decay parameter for L2 regularisation. :param adam_betas: Beta parameters for Adam optimiser. :param verbose: if True statements about memory usage are output at each step + :param slide_dataset: Slide dataset object, if available (default=None). + :param tile_size: The size of each tile (default=224). + :param level: The magnification at which tiles are available (default=1). 
""" super().__init__() @@ -83,6 +88,10 @@ def __init__(self, self.l_rate = l_rate self.weight_decay = weight_decay self.adam_betas = adam_betas + + self.slide_dataset = slide_dataset + self.tile_size = tile_size + self.level = level self.save_hyperparameters() self.verbose = verbose @@ -284,11 +293,6 @@ def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: # type: ignore features_list = self.move_list_to_device(list_encoded_features, use_gpu=False) torch.save(features_list, encoded_features_filename) - panda_dir = "/tmp/datasets/PANDA" - tile_size = 224 - level = 1 - panda_dataset = Dataset(PandaDataset(root=panda_dir)) - print("Selecting tiles ...") fn_top_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('lowest_pred', 'highest_att')) fn_bottom_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('lowest_pred', 'lowest_att')) @@ -311,20 +315,21 @@ def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: # type: ignore fig = plot_slide_noxy(slide, score, paths, bottom_attn, key + '_bottom', ncols=4) figpath = Path(output_path, f'{slide}_bottom.png') fig.savefig(figpath, bbox_inches='tight') - - slide_dict = list(filter(lambda entry: entry[SlideKey.SLIDE_ID] == slide, panda_dataset))[0] # type: ignore - load_image_dict(slide_dict, level=level, margin=0) - slide_image = slide_dict[SlideKey.IMAGE] - - fig = plot_slide(slide_image=slide_image, scale=1.0) - figpath = Path(output_path, f'{slide}_thumbnail.png') - fig.savefig(figpath, bbox_inches='tight') - - location_bbox = slide_dict['location'] - fig = plot_heatmap_slide(slide=slide, slide_image=slide_image, results=results, - location_bbox=location_bbox, tile_size=tile_size, level=slide_dict['level']) - figpath = Path(output_path, f'{slide}_heatmap.png') - fig.savefig(figpath, bbox_inches='tight') + print("length=", len(self.slide_dataset)) + if self.slide_dataset is not None: + slide_dict = list(filter(lambda entry: entry[SlideKey.SLIDE_ID] == slide, self.slide_dataset))[0] # type: ignore + load_image_dict(slide_dict, level=self.level, margin=0) + slide_image = slide_dict[SlideKey.IMAGE] + location_bbox = slide_dict['location'] + + fig = plot_slide(slide_image=slide_image, scale=1.0) + figpath = Path(output_path, f'{slide}_thumbnail.png') + fig.savefig(figpath, bbox_inches='tight') + + fig = plot_heatmap_slide(slide=slide, slide_image=slide_image, results=results, + location_bbox=location_bbox, tile_size=self.tile_size, level=slide_dict['level']) + figpath = Path(output_path, f'{slide}_heatmap.png') + fig.savefig(figpath, bbox_inches='tight') print("Plotting histogram ...") fig = plot_scores_hist(results) @@ -355,103 +360,3 @@ def move_list_to_device(list_encoded_features: List, use_gpu: bool) -> List: feature = feature.squeeze(0).to(device) features_list.append(feature) return features_list - - -class DeepMILModule_Panda(DeepMILModule): - """ - Child class of `DeepMILModule` for deep multiple-instance learning on PANDA dataset - """ - def __init__(self, - panda_dir: str, - tile_size: int = 224, - **kwargs: Any) -> None: - super().__init__(**kwargs) - self.panda_dir = panda_dir - print("in child class") - self.tile_size = tile_size - - def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: # type: ignore - # outputs object consists of a list of dictionaries (of metadata and results, including encoded features) - # It can be indexed as outputs[batch_idx][batch_key][bag_idx][tile_idx] - # example of batch_key ResultsKey.SLIDE_ID_COL - # for batch keys that contains 
multiple values for slides e.g. ResultsKey.BAG_ATTN_COL - # outputs[batch_idx][batch_key][bag_idx][tile_idx] - # contains the tile value - - # collate the batches - results: Dict[str, List[Any]] = {} - [results.update({col: []}) for col in outputs[0].keys()] - for key in results.keys(): - for batch_id in range(len(outputs)): - results[key] += outputs[batch_id][key] - - print("Saving outputs ...") - # collate at slide level - list_slide_dicts = [] - list_encoded_features = [] - # any column can be used here, the assumption is that the first dimension is the N of slides - for slide_idx in range(len(results[ResultsKey.SLIDE_ID])): - slide_dict = dict() - for key in results.keys(): - if key not in [ResultsKey.IMAGE, ResultsKey.LOSS]: - slide_dict[key] = results[key][slide_idx] - list_slide_dicts.append(slide_dict) - list_encoded_features.append(results[ResultsKey.IMAGE][slide_idx]) - - print(f"Metrics results will be output to {fixed_paths.repository_root_directory()}/outputs") - csv_filename = fixed_paths.repository_root_directory() / Path('outputs/test_output.csv') - encoded_features_filename = fixed_paths.repository_root_directory() / Path('outputs/test_encoded_features.pickle') - - # Collect the list of dictionaries in a list of pandas dataframe and save - df_list = [] - for slide_dict in list_slide_dicts: - slide_dict = self.normalize_dict_for_df(slide_dict, use_gpu=False) - df_list.append(pd.DataFrame.from_dict(slide_dict)) - df = pd.concat(df_list, ignore_index=True) - df.to_csv(csv_filename, mode='w', header=True) - - # Collect all features in a list and save - features_list = self.move_list_to_device(list_encoded_features, use_gpu=False) - torch.save(features_list, encoded_features_filename) - - panda_dataset = Dataset(PandaDataset(root=self.panda_dir)) - - print("Selecting tiles ...") - fn_top_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('lowest_pred', 'highest_att')) - fn_bottom_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('lowest_pred', 'lowest_att')) - tp_top_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('highest_pred', 'highest_att')) - tp_bottom_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('highest_pred', 'lowest_att')) - report_cases = {'TP': [tp_top_tiles, tp_bottom_tiles], 'FN': [fn_top_tiles, fn_bottom_tiles]} - - for key in report_cases.keys(): - print(f"Plotting {key} (tiles, thumbnails, attention heatmaps)...") - output_path = Path(fixed_paths.repository_root_directory(), f'outputs/fig/{key}/') - Path(output_path).mkdir(parents=True, exist_ok=True) - nslides = len(report_cases[key][0]) - for i in range(nslides): - slide, score, paths, top_attn = report_cases[key][0][i] - fig = plot_slide_noxy(slide, score, paths, top_attn, key + '_top', ncols=4) - figpath = Path(output_path, f'{slide}_top.png') - fig.savefig(figpath, bbox_inches='tight') - - slide, score, paths, bottom_attn = report_cases[key][1][i] - fig = plot_slide_noxy(slide, score, paths, bottom_attn, key + '_bottom', ncols=4) - figpath = Path(output_path, f'{slide}_bottom.png') - fig.savefig(figpath, bbox_inches='tight') - - slide_dict = list(filter(lambda entry: entry[SlideKey.SLIDE_ID] == slide, panda_dataset))[0] # type: ignore - load_image_dict(slide_dict, level=slide_dict['level'], margin=0) - slide_image = slide_dict[SlideKey.IMAGE] - - fig = plot_slide(slide_image=slide_image, scale=1.0) - figpath = Path(output_path, f'{slide}_thumbnail.png') - fig.savefig(figpath, bbox_inches='tight') - - fig = 
plot_heatmap_slide(slide=slide, slide_image=slide_image, results=results) - figpath = Path(output_path, f'{slide}_heatmap.png') - fig.savefig(figpath, bbox_inches='tight') - - print("Plotting histogram ...") - fig = plot_scores_hist(results) - output_path = Path(fixed_paths.repository_root_directory(), 'outputs/fig/hist_scores.png') - fig.savefig(output_path, bbox_inches='tight') diff --git a/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py b/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py index d72d035d1..9068fc380 100644 --- a/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py +++ b/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py @@ -15,7 +15,7 @@ from InnerEye.Common import fixed_paths from InnerEye.ML.Histopathology.datamodules.panda_module import PandaTilesDataModule from InnerEye.ML.Histopathology.datasets.panda_tiles_dataset import PandaTilesDataset -from InnerEye.ML.Histopathology.models.deepmil import DeepMILModule_Panda +from InnerEye.ML.Histopathology.models.deepmil import DeepMILModule from InnerEye.ML.Histopathology.models.transforms import ( EncodeTilesBatchd, @@ -30,6 +30,9 @@ from InnerEye.ML.configs.histo_configs.classification.BaseMIL import BaseMIL from InnerEye.ML.configs.histo_configs.run_ids import innereye_ssl_checkpoint from InnerEye.ML.Histopathology.models.encoders import IdentityEncoder +from InnerEye.ML.Histopathology.datasets.panda_dataset import PandaDataset +from monai.data.dataset import Dataset + class DeepSMILEPanda(BaseMIL): def __init__(self, **kwargs: Any) -> None: @@ -41,7 +44,7 @@ def __init__(self, **kwargs: Any) -> None: azure_dataset_id="PANDA_tiles", # To mount the dataset instead of downloading in AML, pass --use_dataset_mount in the CLI # declared in TrainerParams: - num_epochs=2, + num_epochs=1, recovery_checkpoint_save_interval=10, recovery_checkpoints_save_last_k=-1, # declared in WorkflowParams: @@ -51,6 +54,8 @@ def __init__(self, **kwargs: Any) -> None: l_rate=5e-4, weight_decay=1e-4, adam_betas=(0.9, 0.99), + extra_azure_dataset_ids=["PANDA"], + extra_local_dataset_paths=[Path("/tmp/datasets/PANDA")] ) default_kwargs.update(kwargs) super().__init__(**default_kwargs) @@ -68,8 +73,11 @@ def __init__(self, **kwargs: Any) -> None: mode="max", ) self.callbacks = best_checkpoint_callback - self.slide_dir = "/tmp/datasets/PANDA" + self.slide_dir = str(self.extra_local_dataset_paths[0]) + print("slide_dir=", self.slide_dir) self.magnification_level = 1 + self.panda_slide_dataset = Dataset(PandaDataset(root=self.slide_dir)) + print("length=", len(self.panda_slide_dataset)) @property def cache_dir(self) -> Path: @@ -111,20 +119,22 @@ def get_data_module(self) -> PandaTilesDataModule: cross_validation_split_index=self.cross_validation_split_index, ) - # def create_model(self) -> DeepMILModule_Panda: - # self.data_module = self.get_data_module() - # # Encoding is done in the datamodule, so here we provide instead a dummy - # # no-op IdentityEncoder to be used inside the model - # return DeepMILModule_Panda(panda_dir=self.slide_dir, - # tile_size=self.tile_size, - # encoder=IdentityEncoder(input_dim=(self.encoder.num_encoding,)), - # label_column=self.data_module.train_dataset.LABEL_COLUMN, - # n_classes=self.data_module.train_dataset.N_CLASSES, - # pooling_layer=self.get_pooling_layer(), - # class_weights=self.data_module.class_weights, - # l_rate=self.l_rate, - # weight_decay=self.weight_decay, - # adam_betas=self.adam_betas) + def create_model(self) -> DeepMILModule: + 
self.data_module = self.get_data_module() + print("in create model") + # Encoding is done in the datamodule, so here we provide instead a dummy + # no-op IdentityEncoder to be used inside the model + return DeepMILModule(encoder=IdentityEncoder(input_dim=(self.encoder.num_encoding,)), + label_column=self.data_module.train_dataset.LABEL_COLUMN, + n_classes=self.data_module.train_dataset.N_CLASSES, + pooling_layer=self.get_pooling_layer(), + class_weights=self.data_module.class_weights, + l_rate=self.l_rate, + weight_decay=self.weight_decay, + adam_betas=self.adam_betas, + slide_dataset=self.panda_slide_dataset, + tile_size=self.tile_size, + level=self.magnification_level) def get_trainer_arguments(self) -> Dict[str, Any]: # These arguments will be passed through to the Lightning trainer. From 9d303eab5fbe8e8990fb0a5155ef1e9fea7f1d3b Mon Sep 17 00:00:00 2001 From: t-hsharma Date: Wed, 12 Jan 2022 10:11:40 +0000 Subject: [PATCH 06/23] Instantiate dataset class in basemil --- InnerEye/ML/Histopathology/models/deepmil.py | 16 ++++----- .../histo_configs/classification/BaseMIL.py | 20 +++++++++-- .../classification/DeepSMILEPanda.py | 34 ++++--------------- 3 files changed, 32 insertions(+), 38 deletions(-) diff --git a/InnerEye/ML/Histopathology/models/deepmil.py b/InnerEye/ML/Histopathology/models/deepmil.py index a7cc0daed..741a901d0 100644 --- a/InnerEye/ML/Histopathology/models/deepmil.py +++ b/InnerEye/ML/Histopathology/models/deepmil.py @@ -42,6 +42,7 @@ def __init__(self, label_column: str, n_classes: int, encoder: TileEncoder, + slide_dataset: Dataset, pooling_layer: Callable[[int, int, int], nn.Module], pool_hidden_dim: int = 128, pool_out_dim: int = 1, @@ -50,10 +51,8 @@ def __init__(self, weight_decay: float = 1e-4, adam_betas: Tuple[float, float] = (0.9, 0.99), verbose: bool = False, - slide_dataset: Dataset = None, tile_size: int = 224, - level: int = 1, - ) -> None: + level: int = 1) -> None: """ :param label_column: Label key for input batch dictionary. :param n_classes: Number of output classes for MIL prediction. @@ -68,9 +67,9 @@ def __init__(self, :param weight_decay: Weight decay parameter for L2 regularisation. :param adam_betas: Beta parameters for Adam optimiser. :param verbose: if True statements about memory usage are output at each step - :param slide_dataset: Slide dataset object, if available (default=None). + :param slide_dataset: Slide dataset object, if available. :param tile_size: The size of each tile (default=224). - :param level: The magnification at which tiles are available (default=1). + :param level: The downsampling level (e.g. 0, 1, 2) of the tiles if available (default=1). 
""" super().__init__() @@ -88,12 +87,14 @@ def __init__(self, self.l_rate = l_rate self.weight_decay = weight_decay self.adam_betas = adam_betas - + + # Slide specific attributes self.slide_dataset = slide_dataset self.tile_size = tile_size self.level = level self.save_hyperparameters() + self.verbose = verbose self.aggregation_fn, self.num_pooling = self.get_pooling() @@ -315,8 +316,7 @@ def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: # type: ignore fig = plot_slide_noxy(slide, score, paths, bottom_attn, key + '_bottom', ncols=4) figpath = Path(output_path, f'{slide}_bottom.png') fig.savefig(figpath, bbox_inches='tight') - print("length=", len(self.slide_dataset)) - if self.slide_dataset is not None: + if len(self.slide_dataset) > 0: slide_dict = list(filter(lambda entry: entry[SlideKey.SLIDE_ID] == slide, self.slide_dataset))[0] # type: ignore load_image_dict(slide_dict, level=self.level, margin=0) slide_image = slide_dict[SlideKey.IMAGE] diff --git a/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py b/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py index 29ce29da8..87bc57eea 100644 --- a/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py +++ b/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py @@ -9,11 +9,12 @@ """ import os from pathlib import Path -from typing import Type +from typing import Type, Optional import param from torch import nn from torchvision.models.resnet import resnet18 +from monai.data.dataset import Dataset from health_azure.utils import CheckpointDownloader, get_workspace from health_ml.networks.layers.attention_layers import AttentionLayer, GatedAttentionLayer @@ -24,6 +25,7 @@ from InnerEye.ML.Histopathology.models.encoders import (HistoSSLEncoder, IdentityEncoder, ImageNetEncoder, ImageNetSimCLREncoder, InnerEyeSSLEncoder, TileEncoder) +from InnerEye.ML.Histopathology.datasets.panda_dataset import PandaDataset class BaseMIL(LightningContainer): @@ -50,6 +52,11 @@ class BaseMIL(LightningContainer): "dataset upfront and save it to disk.") # local_dataset (used as data module root_path) is declared in DatasetParams superclass + # slide dataset parameters: + slide_datatype: Optional[str] = param.String(doc="Name of the slide dataset class if available.") + slide_path: Optional[Path] = param.ClassSelector(class_=Path, default=None, allow_None=True, doc="Path of the slide dataset if available.") + level: Optional[int] = param.Integer(doc="Downsampling level (e.g. 
0, 1, 2) of the tiles if available.") + @property def cache_dir(self) -> Path: raise NotImplementedError @@ -95,6 +102,12 @@ def get_pooling_layer(self) -> Type[nn.Module]: else: raise ValueError(f"Unsupported pooling type: {self.pooling_type}") + def get_slide_dataset(self) -> Dataset: + if self.slide_datatype == PandaDataset.__name__: + return Dataset(PandaDataset(root=self.slide_path)) + else: + raise ValueError(f"Unsupported slide datatype: {self.slide_datatype}") + def create_model(self) -> DeepMILModule: self.data_module = self.get_data_module() # Encoding is done in the datamodule, so here we provide instead a dummy @@ -106,7 +119,10 @@ def create_model(self) -> DeepMILModule: class_weights=self.data_module.class_weights, l_rate=self.l_rate, weight_decay=self.weight_decay, - adam_betas=self.adam_betas) + adam_betas=self.adam_betas, + slide_dataset=self.get_slide_dataset(), + tile_size=self.tile_size, + level=self.level) def get_data_module(self) -> TilesDataModule: raise NotImplementedError diff --git a/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py b/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py index 9068fc380..666c437bf 100644 --- a/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py +++ b/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py @@ -29,9 +29,7 @@ ) from InnerEye.ML.configs.histo_configs.classification.BaseMIL import BaseMIL from InnerEye.ML.configs.histo_configs.run_ids import innereye_ssl_checkpoint -from InnerEye.ML.Histopathology.models.encoders import IdentityEncoder from InnerEye.ML.Histopathology.datasets.panda_dataset import PandaDataset -from monai.data.dataset import Dataset class DeepSMILEPanda(BaseMIL): @@ -39,9 +37,14 @@ def __init__(self, **kwargs: Any) -> None: default_kwargs = dict( # declared in BaseMIL: pooling_type=GatedAttentionLayer.__name__, + slide_datatype=PandaDataset.__name__, + slide_path=Path("/tmp/datasets/PANDA"), + level=1, # declared in DatasetParams: local_dataset=Path("/tmp/datasets/PANDA_tiles"), azure_dataset_id="PANDA_tiles", + extra_azure_dataset_ids=["PANDA"], + extra_local_dataset_paths=[Path("/tmp/datasets/PANDA")], # To mount the dataset instead of downloading in AML, pass --use_dataset_mount in the CLI # declared in TrainerParams: num_epochs=1, @@ -53,10 +56,7 @@ def __init__(self, **kwargs: Any) -> None: # declared in OptimizerParams: l_rate=5e-4, weight_decay=1e-4, - adam_betas=(0.9, 0.99), - extra_azure_dataset_ids=["PANDA"], - extra_local_dataset_paths=[Path("/tmp/datasets/PANDA")] - ) + adam_betas=(0.9, 0.99)) default_kwargs.update(kwargs) super().__init__(**default_kwargs) super().__init__(**default_kwargs) @@ -73,11 +73,6 @@ def __init__(self, **kwargs: Any) -> None: mode="max", ) self.callbacks = best_checkpoint_callback - self.slide_dir = str(self.extra_local_dataset_paths[0]) - print("slide_dir=", self.slide_dir) - self.magnification_level = 1 - self.panda_slide_dataset = Dataset(PandaDataset(root=self.slide_dir)) - print("length=", len(self.panda_slide_dataset)) @property def cache_dir(self) -> Path: @@ -119,23 +114,6 @@ def get_data_module(self) -> PandaTilesDataModule: cross_validation_split_index=self.cross_validation_split_index, ) - def create_model(self) -> DeepMILModule: - self.data_module = self.get_data_module() - print("in create model") - # Encoding is done in the datamodule, so here we provide instead a dummy - # no-op IdentityEncoder to be used inside the model - return 
DeepMILModule(encoder=IdentityEncoder(input_dim=(self.encoder.num_encoding,)), - label_column=self.data_module.train_dataset.LABEL_COLUMN, - n_classes=self.data_module.train_dataset.N_CLASSES, - pooling_layer=self.get_pooling_layer(), - class_weights=self.data_module.class_weights, - l_rate=self.l_rate, - weight_decay=self.weight_decay, - adam_betas=self.adam_betas, - slide_dataset=self.panda_slide_dataset, - tile_size=self.tile_size, - level=self.magnification_level) - def get_trainer_arguments(self) -> Dict[str, Any]: # These arguments will be passed through to the Lightning trainer. kw_args = super().get_trainer_arguments() From 625db8a127fe86aa1acc92fe0b35fcce4aabaeca Mon Sep 17 00:00:00 2001 From: t-hsharma Date: Thu, 13 Jan 2022 11:14:59 +0000 Subject: [PATCH 07/23] dataset paths for AML --- .../histo_configs/classification/BaseMIL.py | 7 +++-- .../classification/DeepSMILEPanda.py | 28 ++++++++++++++----- 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py b/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py index 87bc57eea..9500db47e 100644 --- a/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py +++ b/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py @@ -53,8 +53,7 @@ class BaseMIL(LightningContainer): # local_dataset (used as data module root_path) is declared in DatasetParams superclass # slide dataset parameters: - slide_datatype: Optional[str] = param.String(doc="Name of the slide dataset class if available.") - slide_path: Optional[Path] = param.ClassSelector(class_=Path, default=None, allow_None=True, doc="Path of the slide dataset if available.") + slide_datatype: Optional[str] = param.String(default=None, allow_None=True, doc="Name of the slide dataset class if available.") level: Optional[int] = param.Integer(doc="Downsampling level (e.g. 0, 1, 2) of the tiles if available.") @property @@ -104,7 +103,9 @@ def get_pooling_layer(self) -> Type[nn.Module]: def get_slide_dataset(self) -> Dataset: if self.slide_datatype == PandaDataset.__name__: - return Dataset(PandaDataset(root=self.slide_path)) + return Dataset(PandaDataset(root=self.extra_local_dataset_paths[0])) + elif self.slide_datatype is None: + return Dataset(data=[]) else: raise ValueError(f"Unsupported slide datatype: {self.slide_datatype}") diff --git a/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py b/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py index 666c437bf..d34970fb6 100644 --- a/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py +++ b/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py @@ -3,7 +3,7 @@ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 
# ------------------------------------------------------------------------------------------ -from typing import Any, Dict +from typing import Any, Dict, Optional from pathlib import Path import os from monai.transforms import Compose @@ -31,6 +31,21 @@ from InnerEye.ML.configs.histo_configs.run_ids import innereye_ssl_checkpoint from InnerEye.ML.Histopathology.datasets.panda_dataset import PandaDataset +local_mode = False +path_local_data: Optional[Path] +if local_mode: + path_local_data = Path("/tmp/datasets/PANDA_tiles") + azure_dataset_id = "Dummy" + extra_local_dataset_paths = [Path("/tmp/datasets/PANDA")] + extra_azure_dataset_ids = ["Dummy"] + num_epochs = 1 +else: + path_local_data = None + azure_dataset_id = "PANDA_tiles" + extra_local_dataset_paths = [] + extra_azure_dataset_ids = ["PANDA"] + num_epochs = 1 + class DeepSMILEPanda(BaseMIL): def __init__(self, **kwargs: Any) -> None: @@ -38,16 +53,15 @@ def __init__(self, **kwargs: Any) -> None: # declared in BaseMIL: pooling_type=GatedAttentionLayer.__name__, slide_datatype=PandaDataset.__name__, - slide_path=Path("/tmp/datasets/PANDA"), level=1, # declared in DatasetParams: - local_dataset=Path("/tmp/datasets/PANDA_tiles"), - azure_dataset_id="PANDA_tiles", - extra_azure_dataset_ids=["PANDA"], - extra_local_dataset_paths=[Path("/tmp/datasets/PANDA")], + local_dataset=path_local_data, + azure_dataset_id=azure_dataset_id, + extra_azure_dataset_ids=extra_azure_dataset_ids, + extra_local_dataset_paths=extra_local_dataset_paths, # To mount the dataset instead of downloading in AML, pass --use_dataset_mount in the CLI # declared in TrainerParams: - num_epochs=1, + num_epochs=num_epochs, recovery_checkpoint_save_interval=10, recovery_checkpoints_save_last_k=-1, # declared in WorkflowParams: From 1deb0b9e2b50c0e75fca83feec7747149e7fd29a Mon Sep 17 00:00:00 2001 From: t-hsharma Date: Thu, 13 Jan 2022 15:42:04 +0000 Subject: [PATCH 08/23] heatmap utils selected --- .../ML/Histopathology/utils/heatmap_utils.py | 133 +----------------- .../classification/DeepSMILEPanda.py | 2 +- 2 files changed, 3 insertions(+), 132 deletions(-) diff --git a/InnerEye/ML/Histopathology/utils/heatmap_utils.py b/InnerEye/ML/Histopathology/utils/heatmap_utils.py index 03a457ce9..83763fe9f 100644 --- a/InnerEye/ML/Histopathology/utils/heatmap_utils.py +++ b/InnerEye/ML/Histopathology/utils/heatmap_utils.py @@ -1,142 +1,13 @@ -import io -from typing import Any, Optional, Sequence, Tuple, List - +from typing import Optional, List import matplotlib.pyplot as plt import numpy as np -import PIL.Image + from matplotlib.axes import Axes -from matplotlib.figure import Figure from matplotlib.image import AxesImage import matplotlib.patches as patches import matplotlib.collections as collection -def assemble_heatmap(tile_coords: np.ndarray, tile_values: np.ndarray, tile_size: int, level: int, - fill_value: float = np.nan, pad: int = 0) -> Tuple[np.ndarray, np.ndarray]: - """Assembles a 2D heatmap from sequences of tile coordinates and values. - - :param tile_coords: XY tile coordinates, assumed to be spaced by multiples of `tile_size` (shape: [N, 2]). - :param tile_values: Scalar values of the tiles (shape: [N]). - :param tile_size: Size of each tile; must be >0. - :param level: Magnification at which tiles are available (e.g. PANDA levels are 0 for original, 1 for 4x downsampled, 2 for 16x downsampled). - :param fill_value: Value to assign to empty elements (default: `NaN`). - :param pad: If positive, pad the heatmap by `pad` elements on all sides (default: no padding). 
- :return: A tuple containing: - - `heatmap`: The 2D heatmap with the smallest dimensions to contain all given tiles, with - optional padding. - - `origin`: The lowest XY coordinates in the space of `tile_coords`. If `pad > 0`, this is - offset to match the padded margin. - """ - if tile_coords.shape[0] != tile_values.shape[0]: - raise ValueError(f"Tile coordinates and values must have the same length, " - f"got {tile_coords.shape[0]} and {tile_values.shape[0]}") - - level_dict = {"0": 1, "1": 4, "2": 16} - factor = level_dict[str(level)] - tile_coords_scaled = tile_coords//factor - tile_xs, tile_ys = tile_coords_scaled.T - - tile_xs = tile_xs - tile_size # top-left corner from top-right corner - x_min, x_max = min(tile_xs), max(tile_xs) - y_min, y_max = min(tile_ys), max(tile_ys) - - n_tiles_x = (x_max - x_min) // tile_size + 1 - n_tiles_y = (y_max - y_min) // tile_size + 1 - heatmap = np.full((n_tiles_y, n_tiles_x), fill_value) - - tile_js = (tile_xs - x_min) // tile_size - tile_is = (tile_ys - y_min) // tile_size - heatmap[tile_is, tile_js] = tile_values - origin = np.array([x_min, y_min]) - - if pad > 0: - heatmap = np.pad(heatmap, pad, mode='constant', constant_values=fill_value) - origin -= tile_size * pad # offset the origin to match the padded margin - - return heatmap, origin - - -def plot_heatmap(heatmap: np.ndarray, tile_size: int, origin: Sequence[int], ax: Optional[Axes] = None, **imshow_kwargs: Any) -> AxesImage: - """Plot a 2D heatmap to overlay on the slide. - - :param heatmap: The 2D scalar heatmap. - :param tile_size: Size of each tile. - :param origin: XY coordinates of the heatmap's top-left corner. - :param ax: Axes onto which to plot the heatmap (default: current axes). - :param imshow_kwargs: Kwargs for `plt.imshow()` (e.g. `alpha`, `cmap`, `interpolation`). - :return: The output of `plt.imshow()` to allow e.g. plotting a colorbar. - """ - if ax is None: - ax = plt.gca() - heatmap_width = tile_size * heatmap.shape[1] - heatmap_height = tile_size * heatmap.shape[0] - offset = tile_size * 0.5 - extent = ( - origin[0] - offset, # left - origin[0] + heatmap_width - offset, # right - origin[1] + heatmap_height - offset, # bottom - origin[1] - offset # top - ) - h = ax.imshow(heatmap, extent=extent, **imshow_kwargs) - cb = plt.colorbar(h, ax=ax) - cb.set_alpha(1) - cb.draw_all() - return ax - - -def figure_to_image(fig: Optional[Figure] = None, **savefig_kwargs: Any) -> PIL.Image: - """Converts a Matplotlib figure into an image for logging and display. - - :param fig: Input Matplotlib figure object (default: current figure). - :param savefig_kwargs: Kwargs for `fig.savefig()` (e.g. `dpi`, `bbox_inches`). - :return: Rasterised PIL image in RGBA format. - """ - if fig is None: - fig = plt.gcf() - buffer = io.BytesIO() - fig.savefig(buffer, format='png', **savefig_kwargs) - buffer.seek(0) - image = PIL.Image.open(buffer).convert('RGBA') - buffer.close() - return image - - -def plot_slide_from_tiles(tiles: np.array, - tile_coords: np.array, - level: int, tile_size: int, - width: int, height: int, - ax: Optional[Axes] = None, - **imshow_kwargs: Any) -> AxesImage: - """Reconstructs a slide given the tiles at a certain magnification. - - :param tiles: Tiles of the slide (shape: [N, H, W, C]). - :param tile_coords: Top-right coordinates of tiles (shape: [N, 2]). - :param level: Magnification at which tiles are available (e.g. PANDA levels are 0 for original, 1 for 4x downsampled, 2 for 16x downsampled). - :param tile_size: Size of each tile. - :param width: Width of slide. 
- :param height: Height of slide. - :param ax: Axes onto which to plot the heatmap (default: current axes). - """ - level_dict = {"0": 1, "1": 4, "2": 16} - factor = level_dict[str(level)] - tile_coords_scaled = tile_coords//factor - xs, ys = tile_coords_scaled.T - x_min = min(xs) - tile_size # top-left corner from top-right corner - y_min = min(ys) - offset = tile_size * 0.5 - if ax is None: - ax = plt.gca() - for i in range(tiles.shape[0]): - x, y = tile_coords_scaled[i] - x = x - tile_size # top-left corner from top-right corner - x = x - x_min - y = y - y_min - ax.imshow(tiles[i], extent=(x-offset, x+tile_size-offset, y+tile_size-offset, y-offset), **imshow_kwargs) - ax.set_xlim(0, width) - ax.set_ylim(height, 0) - return ax - - def plot_heatmap_selected_tiles(tile_coords: np.array, tile_values: np.ndarray, location_bbox: List[int], diff --git a/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py b/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py index 977f17123..8630582ff 100644 --- a/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py +++ b/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py @@ -45,7 +45,7 @@ azure_dataset_id = "PANDA_tiles" extra_local_dataset_paths = [] extra_azure_dataset_ids = ["PANDA"] - num_epochs = 1 + num_epochs = 100 class DeepSMILEPanda(BaseMIL): From 5879ff832f8017a0eda5f0313933605138e34d1d Mon Sep 17 00:00:00 2001 From: t-hsharma Date: Thu, 13 Jan 2022 17:31:30 +0000 Subject: [PATCH 09/23] fix mypy and flake8 errors --- CHANGELOG.md | 1 + InnerEye/ML/Histopathology/models/deepmil.py | 4 ++-- InnerEye/ML/configs/histo_configs/classification/BaseMIL.py | 2 +- .../ML/configs/histo_configs/classification/DeepSMILEPanda.py | 1 - 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a28df2c5a..be9ca5c28 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,6 +39,7 @@ jobs that run in AzureML. - ([#613](https://github.com/microsoft/InnerEye-DeepLearning/pull/613)) Add additional tests for histopathology datasets - ([#616](https://github.com/microsoft/InnerEye-DeepLearning/pull/616)) Add more histopathology configs and tests - ([#621](https://github.com/microsoft/InnerEye-DeepLearning/pull/621)) Add WSI preprocessing functions and enable tiling more generic slide datasets +- ([#634](https://github.com/microsoft/InnerEye-DeepLearning/pull/634)) Add WSI heatmaps and thumbnails to standard test outputs ### Changed - ([#588](https://github.com/microsoft/InnerEye-DeepLearning/pull/588)) Replace SciPy with PIL.PngImagePlugin.PngImageFile to load png files. diff --git a/InnerEye/ML/Histopathology/models/deepmil.py b/InnerEye/ML/Histopathology/models/deepmil.py index a805e6684..72967b446 100644 --- a/InnerEye/ML/Histopathology/models/deepmil.py +++ b/InnerEye/ML/Histopathology/models/deepmil.py @@ -42,7 +42,6 @@ def __init__(self, label_column: str, n_classes: int, encoder: TileEncoder, - slide_dataset: Dataset, pooling_layer: Callable[[int, int, int], nn.Module], pool_hidden_dim: int = 128, pool_out_dim: int = 1, @@ -51,8 +50,9 @@ def __init__(self, weight_decay: float = 1e-4, adam_betas: Tuple[float, float] = (0.9, 0.99), verbose: bool = False, + slide_dataset: Dataset = Dataset(data=[]), tile_size: int = 224, - level: int = 1) -> None: + level: Optional[int] = 1) -> None: """ :param label_column: Label key for input batch dictionary. :param n_classes: Number of output classes for MIL prediction. 
diff --git a/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py b/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py index 7a5e91c61..df44cd010 100644 --- a/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py +++ b/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py @@ -93,7 +93,7 @@ def get_pooling_layer(self) -> Type[nn.Module]: def get_slide_dataset(self) -> Dataset: if self.slide_datatype == PandaDataset.__name__: - return Dataset(PandaDataset(root=self.extra_local_dataset_paths[0])) + return Dataset(PandaDataset(root=self.extra_local_dataset_paths[0])) # type: ignore elif self.slide_datatype is None: return Dataset(data=[]) else: diff --git a/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py b/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py index 8630582ff..8ad831315 100644 --- a/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py +++ b/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py @@ -28,7 +28,6 @@ InnerEyeSSLEncoder, ) from InnerEye.ML.configs.histo_configs.classification.BaseMIL import BaseMIL -from InnerEye.ML.configs.histo_configs.run_ids import innereye_ssl_checkpoint from InnerEye.ML.Histopathology.datasets.panda_dataset import PandaDataset From ec20c300f54fc18906c51e5f629098c1363e9568 Mon Sep 17 00:00:00 2001 From: t-hsharma Date: Thu, 13 Jan 2022 18:20:33 +0000 Subject: [PATCH 10/23] mypy error resolved --- InnerEye/ML/Histopathology/models/deepmil.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/InnerEye/ML/Histopathology/models/deepmil.py b/InnerEye/ML/Histopathology/models/deepmil.py index 72967b446..0366dee34 100644 --- a/InnerEye/ML/Histopathology/models/deepmil.py +++ b/InnerEye/ML/Histopathology/models/deepmil.py @@ -324,7 +324,7 @@ def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: # type: ignore fig.savefig(figpath, bbox_inches='tight') if len(self.slide_dataset) > 0: slide_dict = list(filter(lambda entry: entry[SlideKey.SLIDE_ID] == slide, self.slide_dataset))[0] # type: ignore - load_image_dict(slide_dict, level=self.level, margin=0) + load_image_dict(slide_dict, level=self.level, margin=0) # type: ignore slide_image = slide_dict[SlideKey.IMAGE] location_bbox = slide_dict['location'] From 974bc6985af2a1cfebf3ab7288d9ffda0db1b857 Mon Sep 17 00:00:00 2001 From: t-hsharma Date: Fri, 14 Jan 2022 18:01:08 +0000 Subject: [PATCH 11/23] address PR comments --- InnerEye/ML/Histopathology/models/deepmil.py | 23 ++++++----- .../ML/Histopathology/utils/heatmap_utils.py | 7 ++-- .../ML/Histopathology/utils/metrics_utils.py | 6 +-- InnerEye/ML/Histopathology/utils/naming.py | 1 + .../histo_configs/classification/BaseMIL.py | 23 +++-------- .../classification/DeepSMILECrck.py | 10 ++--- .../classification/DeepSMILEPanda.py | 40 ++++++++++++++----- 7 files changed, 58 insertions(+), 52 deletions(-) diff --git a/InnerEye/ML/Histopathology/models/deepmil.py b/InnerEye/ML/Histopathology/models/deepmil.py index 0366dee34..5ac5be528 100644 --- a/InnerEye/ML/Histopathology/models/deepmil.py +++ b/InnerEye/ML/Histopathology/models/deepmil.py @@ -9,6 +9,7 @@ import numpy as np from typing import Any, Callable, Dict, Optional, Tuple, List import torch +import matplotlib.pyplot as plt from pytorch_lightning import LightningModule from torch import Tensor, argmax, mode, nn, no_grad, optim, round @@ -315,31 +316,31 @@ def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: # type: ignore for i in range(nslides): slide, score, paths, top_attn = 
report_cases[key][0][i] fig = plot_slide_noxy(slide, score, paths, top_attn, key + '_top', ncols=4) - figpath = Path(key_folder_path, f'{slide}_top.png') - fig.savefig(figpath, bbox_inches='tight') + self.save_figure(fig=fig, figpath=Path(key_folder_path, f'{slide}_top.png')) slide, score, paths, bottom_attn = report_cases[key][1][i] fig = plot_slide_noxy(slide, score, paths, bottom_attn, key + '_bottom', ncols=4) - figpath = Path(key_folder_path, f'{slide}_bottom.png') - fig.savefig(figpath, bbox_inches='tight') + self.save_figure(fig=fig, figpath=Path(key_folder_path, f'{slide}_bottom.png')) + if len(self.slide_dataset) > 0: slide_dict = list(filter(lambda entry: entry[SlideKey.SLIDE_ID] == slide, self.slide_dataset))[0] # type: ignore load_image_dict(slide_dict, level=self.level, margin=0) # type: ignore slide_image = slide_dict[SlideKey.IMAGE] - location_bbox = slide_dict['location'] + location_bbox = slide_dict[SlideKey.LOCATION] fig = plot_slide(slide_image=slide_image, scale=1.0) - figpath = Path(key_folder_path, f'{slide}_thumbnail.png') - fig.savefig(figpath, bbox_inches='tight') - + self.save_figure(fig=fig, figpath=Path(key_folder_path, f'{slide}_thumbnail.png')) fig = plot_heatmap_slide(slide=slide, slide_image=slide_image, results=results, location_bbox=location_bbox, tile_size=self.tile_size, level=slide_dict['level']) - figpath = Path(key_folder_path, f'{slide}_heatmap.png') - fig.savefig(figpath, bbox_inches='tight') + self.save_figure(fig=fig, figpath=Path(key_folder_path, f'{slide}_heatmap.png')) print("Plotting histogram ...") fig = plot_scores_hist(results) - fig.savefig(outputs_fig_path / 'hist_scores.png', bbox_inches='tight') + self.save_figure(fig=fig, figpath=outputs_fig_path / 'hist_scores.png') + + @staticmethod + def save_figure(fig: plt.figure, figpath: Path) -> None: + fig.savefig(figpath, bbox_inches='tight') @staticmethod def normalize_dict_for_df(dict_old: Dict[str, Any], use_gpu: bool) -> Dict: diff --git a/InnerEye/ML/Histopathology/utils/heatmap_utils.py b/InnerEye/ML/Histopathology/utils/heatmap_utils.py index 83763fe9f..74f16ef84 100644 --- a/InnerEye/ML/Histopathology/utils/heatmap_utils.py +++ b/InnerEye/ML/Histopathology/utils/heatmap_utils.py @@ -8,13 +8,13 @@ import matplotlib.collections as collection -def plot_heatmap_selected_tiles(tile_coords: np.array, +def plot_heatmap_selected_tiles(tile_coords: np.ndarray, tile_values: np.ndarray, location_bbox: List[int], tile_size: int, level: int, - ax: Optional[Axes] = None) -> AxesImage: - """Plots a 2D heatmap for selected tiles to overlay on the slide. + ax: Optional[Axes] = None) -> None: + """Plots a 2D heatmap for selected tiles (e.g. tiles in a bag) to overlay on the slide. :param tile_coords: XY tile coordinates, assumed to be spaced by multiples of `tile_size` (shape: [N, 2]). :param tile_values: Scalar values of the tiles (shape: [N]). :param location_bbox: Location of the bounding box of the slide. 
@@ -46,4 +46,3 @@ def plot_heatmap_selected_tiles(tile_coords: np.array, pc.set_clim([0, 1]) ax.add_collection(pc) plt.colorbar(pc, ax=ax) - return ax diff --git a/InnerEye/ML/Histopathology/utils/metrics_utils.py b/InnerEye/ML/Histopathology/utils/metrics_utils.py index 709560a9c..ccecd0825 100644 --- a/InnerEye/ML/Histopathology/utils/metrics_utils.py +++ b/InnerEye/ML/Histopathology/utils/metrics_utils.py @@ -101,7 +101,7 @@ def plot_slide_noxy(slide: str, score: float, paths: List, attn: List, case: str return fig -def plot_slide(slide_image: np.array, scale: float) -> plt.figure: +def plot_slide(slide_image: np.ndarray, scale: float) -> plt.figure: """Plots a slide thumbnail from a given slide image and scale. :param slide_image: Numpy array of the slide image (shape: [3, H, W]). :return: matplotlib figure of the slide thumbnail. @@ -115,11 +115,11 @@ def plot_slide(slide_image: np.array, scale: float) -> plt.figure: return fig -def plot_heatmap_slide(slide: str, slide_image: np.array, results: Dict, location_bbox: List[int], tile_size: int = 224, level: int = 1) -> plt.figure: +def plot_heatmap_slide(slide: str, slide_image: np.ndarray, results: Dict[str, List[Any]], location_bbox: List[int], tile_size: int = 224, level: int = 1) -> plt.figure: """Plots heatmap of selected tiles overlay on a slide. :param slide: slide identifier. :param slide_image: Numpy array of the slide image (shape: [3, H, W]). - :param results: List that contains slide_level dicts. + :param results: Dict containing ResultsKey keys (e.g. slide id) and values as lists of output slides. :param tile_size: Size of each tile. Default 224. :param level: Magnification at which tiles are available (e.g. PANDA levels are 0 for original, 1 for 4x downsampled, 2 for 16x downsampled). Default 1. :param location_bbox: Location of the bounding box of the slide. diff --git a/InnerEye/ML/Histopathology/utils/naming.py b/InnerEye/ML/Histopathology/utils/naming.py index 9f1b237df..5b7273ad9 100644 --- a/InnerEye/ML/Histopathology/utils/naming.py +++ b/InnerEye/ML/Histopathology/utils/naming.py @@ -18,6 +18,7 @@ class SlideKey(str, Enum): ORIGIN = 'origin' FOREGROUND_THRESHOLD = 'foreground_threshold' METADATA = 'metadata' + LOCATION = 'location' class TileKey(str, Enum): diff --git a/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py b/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py index df44cd010..25dee58f4 100644 --- a/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py +++ b/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py @@ -8,7 +8,7 @@ their datamodules and configure experiment-specific parameters. """ from pathlib import Path -from typing import Type, Optional +from typing import Type import param from torch import nn @@ -22,7 +22,6 @@ from InnerEye.ML.Histopathology.models.encoders import (HistoSSLEncoder, IdentityEncoder, ImageNetEncoder, ImageNetSimCLREncoder, InnerEyeSSLEncoder, TileEncoder) -from InnerEye.ML.Histopathology.datasets.panda_dataset import PandaDataset class BaseMIL(LightningContainer): @@ -49,10 +48,6 @@ class BaseMIL(LightningContainer): "dataset upfront and save it to disk.") # local_dataset (used as data module root_path) is declared in DatasetParams superclass - # slide dataset parameters: - slide_datatype: Optional[str] = param.String(default=None, allow_None=True, doc="Name of the slide dataset class if available.") - level: Optional[int] = param.Integer(doc="Downsampling level (e.g. 
0, 1, 2) of the tiles if available.") - @property def cache_dir(self) -> Path: raise NotImplementedError @@ -91,14 +86,6 @@ def get_pooling_layer(self) -> Type[nn.Module]: else: raise ValueError(f"Unsupported pooling type: {self.pooling_type}") - def get_slide_dataset(self) -> Dataset: - if self.slide_datatype == PandaDataset.__name__: - return Dataset(PandaDataset(root=self.extra_local_dataset_paths[0])) # type: ignore - elif self.slide_datatype is None: - return Dataset(data=[]) - else: - raise ValueError(f"Unsupported slide datatype: {self.slide_datatype}") - def create_model(self) -> DeepMILModule: self.data_module = self.get_data_module() # Encoding is done in the datamodule, so here we provide instead a dummy @@ -110,10 +97,10 @@ def create_model(self) -> DeepMILModule: class_weights=self.data_module.class_weights, l_rate=self.l_rate, weight_decay=self.weight_decay, - adam_betas=self.adam_betas, - slide_dataset=self.get_slide_dataset(), - tile_size=self.tile_size, - level=self.level) + adam_betas=self.adam_betas) def get_data_module(self) -> TilesDataModule: raise NotImplementedError + + def get_slide_dataset(self) -> Dataset: + raise NotImplementedError diff --git a/InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py b/InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py index 7ea1818d4..30b60b52c 100644 --- a/InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py +++ b/InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py @@ -14,11 +14,12 @@ damage response defect classification directly from H&E whole-slide images. arXiv:2107.09405 """ from pathlib import Path -from typing import Any, Dict +from typing import Any, List import os from monai.transforms import Compose from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +from pytorch_lightning.callbacks import Callback from health_ml.networks.layers.attention_layers import GatedAttentionLayer from health_azure.utils import get_workspace @@ -127,11 +128,8 @@ def get_data_module(self) -> TilesDataModule: cross_validation_split_index=self.cross_validation_split_index, ) - def get_trainer_arguments(self) -> Dict[str, Any]: - # These arguments will be passed through to the Lightning trainer. - kw_args = super().get_trainer_arguments() - kw_args["callbacks"] = self.callbacks - return kw_args + def get_callbacks(self) -> List[Callback]: + return super().get_callbacks() + [self.callbacks] def get_path_to_best_checkpoint(self) -> Path: """ diff --git a/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py b/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py index 8ad831315..2aa0d613d 100644 --- a/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py +++ b/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py @@ -3,11 +3,13 @@ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 
# ------------------------------------------------------------------------------------------ -from typing import Any, Dict, Optional +from typing import Any, Optional, List from pathlib import Path import os from monai.transforms import Compose from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +from pytorch_lightning.callbacks import Callback +from monai.data.dataset import Dataset from health_azure.utils import CheckpointDownloader from health_azure.utils import get_workspace @@ -26,12 +28,14 @@ ImageNetEncoder, ImageNetSimCLREncoder, InnerEyeSSLEncoder, + IdentityEncoder ) from InnerEye.ML.configs.histo_configs.classification.BaseMIL import BaseMIL from InnerEye.ML.Histopathology.datasets.panda_dataset import PandaDataset +from InnerEye.ML.Histopathology.models.deepmil import DeepMILModule -local_mode = False +local_mode = True path_local_data: Optional[Path] if local_mode: path_local_data = Path("/tmp/datasets/PANDA_tiles") @@ -52,8 +56,6 @@ def __init__(self, **kwargs: Any) -> None: default_kwargs = dict( # declared in BaseMIL: pooling_type=GatedAttentionLayer.__name__, - slide_datatype=PandaDataset.__name__, - level=1, # declared in DatasetParams: local_dataset=path_local_data, azure_dataset_id=azure_dataset_id, @@ -131,11 +133,29 @@ def get_data_module(self) -> PandaTilesDataModule: cross_validation_split_index=self.cross_validation_split_index, ) - def get_trainer_arguments(self) -> Dict[str, Any]: - # These arguments will be passed through to the Lightning trainer. - kw_args = super().get_trainer_arguments() - kw_args["callbacks"] = self.callbacks - return kw_args + def create_model(self) -> DeepMILModule: + self.data_module = self.get_data_module() + # Encoding is done in the datamodule, so here we provide instead a dummy + # no-op IdentityEncoder to be used inside the model + self.slide_dataset = self.get_slide_dataset() + self.level = 1 + return DeepMILModule(encoder=IdentityEncoder(input_dim=(self.encoder.num_encoding,)), + label_column=self.data_module.train_dataset.LABEL_COLUMN, + n_classes=self.data_module.train_dataset.N_CLASSES, + pooling_layer=self.get_pooling_layer(), + class_weights=self.data_module.class_weights, + l_rate=self.l_rate, + weight_decay=self.weight_decay, + adam_betas=self.adam_betas, + slide_dataset=self.get_slide_dataset(), + tile_size=self.tile_size, + level=self.level) + + def get_slide_dataset(self) -> Dataset: + return Dataset(PandaDataset(root=self.extra_local_dataset_paths[0])) # type: ignore + + def get_callbacks(self) -> List[Callback]: + return super().get_callbacks() + [self.callbacks] def get_path_to_best_checkpoint(self) -> Path: """ @@ -159,7 +179,7 @@ def get_path_to_best_checkpoint(self) -> Path: if checkpoint_path.is_file(): return checkpoint_path - raise ValueError("Path to best checkpoint not found") + raise ValueError("Path to best checkpoint not found") class PandaImageNetMIL(DeepSMILEPanda): From 4fb31f9d8213187a52d86039e8b808da2f149049 Mon Sep 17 00:00:00 2001 From: t-hsharma Date: Fri, 14 Jan 2022 18:27:50 +0000 Subject: [PATCH 12/23] flake8 errors resolved --- InnerEye/ML/Histopathology/utils/heatmap_utils.py | 3 +-- InnerEye/ML/Histopathology/utils/metrics_utils.py | 9 +++++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/InnerEye/ML/Histopathology/utils/heatmap_utils.py b/InnerEye/ML/Histopathology/utils/heatmap_utils.py index 74f16ef84..0e9e07b90 100644 --- a/InnerEye/ML/Histopathology/utils/heatmap_utils.py +++ b/InnerEye/ML/Histopathology/utils/heatmap_utils.py @@ -3,7 +3,6 @@ import 
numpy as np from matplotlib.axes import Axes -from matplotlib.image import AxesImage import matplotlib.patches as patches import matplotlib.collections as collection @@ -14,7 +13,7 @@ def plot_heatmap_selected_tiles(tile_coords: np.ndarray, tile_size: int, level: int, ax: Optional[Axes] = None) -> None: - """Plots a 2D heatmap for selected tiles (e.g. tiles in a bag) to overlay on the slide. + """Plots a 2D heatmap for selected tiles (e.g. tiles in a bag). :param tile_coords: XY tile coordinates, assumed to be spaced by multiples of `tile_size` (shape: [N, 2]). :param tile_values: Scalar values of the tiles (shape: [N]). :param location_bbox: Location of the bounding box of the slide. diff --git a/InnerEye/ML/Histopathology/utils/metrics_utils.py b/InnerEye/ML/Histopathology/utils/metrics_utils.py index ccecd0825..e10cd1d73 100644 --- a/InnerEye/ML/Histopathology/utils/metrics_utils.py +++ b/InnerEye/ML/Histopathology/utils/metrics_utils.py @@ -115,8 +115,13 @@ def plot_slide(slide_image: np.ndarray, scale: float) -> plt.figure: return fig -def plot_heatmap_slide(slide: str, slide_image: np.ndarray, results: Dict[str, List[Any]], location_bbox: List[int], tile_size: int = 224, level: int = 1) -> plt.figure: - """Plots heatmap of selected tiles overlay on a slide. +def plot_heatmap_slide(slide: str, + slide_image: np.ndarray, + results: Dict[str, List[Any]], + location_bbox: List[int], + tile_size: int = 224, + level: int = 1) -> plt.figure: + """Plots heatmap of selected tiles (e.g. tiles in a bag) overlay on the corresponding slide. :param slide: slide identifier. :param slide_image: Numpy array of the slide image (shape: [3, H, W]). :param results: Dict containing ResultsKey keys (e.g. slide id) and values as lists of output slides. From e6b5f5870cfb74f8cbe4ba0b88cdb804e71ff4d9 Mon Sep 17 00:00:00 2001 From: t-hsharma Date: Sun, 16 Jan 2022 14:53:59 +0000 Subject: [PATCH 13/23] PR comments addressed --- InnerEye/ML/Histopathology/models/deepmil.py | 5 +++-- .../configs/histo_configs/classification/DeepSMILEPanda.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/InnerEye/ML/Histopathology/models/deepmil.py b/InnerEye/ML/Histopathology/models/deepmil.py index 5ac5be528..b941c3451 100644 --- a/InnerEye/ML/Histopathology/models/deepmil.py +++ b/InnerEye/ML/Histopathology/models/deepmil.py @@ -10,6 +10,7 @@ from typing import Any, Callable, Dict, Optional, Tuple, List import torch import matplotlib.pyplot as plt +import more_itertools as mi from pytorch_lightning import LightningModule from torch import Tensor, argmax, mode, nn, no_grad, optim, round @@ -323,8 +324,8 @@ def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: # type: ignore self.save_figure(fig=fig, figpath=Path(key_folder_path, f'{slide}_bottom.png')) if len(self.slide_dataset) > 0: - slide_dict = list(filter(lambda entry: entry[SlideKey.SLIDE_ID] == slide, self.slide_dataset))[0] # type: ignore - load_image_dict(slide_dict, level=self.level, margin=0) # type: ignore + slide_dict = mi.first_true(self.slide_dataset, pred=lambda entry: entry[SlideKey.SLIDE_ID] == slide) # type: ignore + load_image_dict(slide_dict, level=self.level, margin=0) # type: ignore slide_image = slide_dict[SlideKey.IMAGE] location_bbox = slide_dict[SlideKey.LOCATION] diff --git a/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py b/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py index 2aa0d613d..9bb74e4a0 100644 --- a/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py +++ 
b/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py @@ -35,7 +35,7 @@ from InnerEye.ML.Histopathology.models.deepmil import DeepMILModule -local_mode = True +local_mode = False path_local_data: Optional[Path] if local_mode: path_local_data = Path("/tmp/datasets/PANDA_tiles") From 5d5bc769e1c77043326a93e4e4a2870c7d446349 Mon Sep 17 00:00:00 2001 From: t-hsharma Date: Sun, 16 Jan 2022 16:15:31 +0000 Subject: [PATCH 14/23] add test for plots --- .../utils/test_metrics_utils.py | 32 ++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/Tests/ML/histopathology/utils/test_metrics_utils.py b/Tests/ML/histopathology/utils/test_metrics_utils.py index a63884477..355690a3a 100644 --- a/Tests/ML/histopathology/utils/test_metrics_utils.py +++ b/Tests/ML/histopathology/utils/test_metrics_utils.py @@ -4,12 +4,14 @@ # ------------------------------------------------------------------------------------------ import math +import numpy as np from typing import List import matplotlib from torch.functional import Tensor +import pytest -from InnerEye.ML.Histopathology.utils.metrics_utils import plot_scores_hist, select_k_tiles +from InnerEye.ML.Histopathology.utils.metrics_utils import plot_scores_hist, select_k_tiles, plot_slide, plot_heatmap_slide from InnerEye.ML.Histopathology.utils.naming import ResultsKey @@ -60,3 +62,31 @@ def test_select_k_tiles() -> None: def test_plot_scores_hist() -> None: fig = plot_scores_hist(test_dict) assert isinstance(fig, matplotlib.figure.Figure) + +@pytest.mark.parametrize("scale", [0.1, 1.2, 2.4, 3.6]) +def test_plot_slide(scale: int) -> None: + slide_image = np.random.rand(3, 1000, 2000) + fig = plot_slide(slide_image=slide_image, scale=scale) + assert isinstance(fig, matplotlib.figure.Figure) + +@pytest.mark.parametrize("level", [0, 1, 2]) +def test_plot_heatmap_slide(level: int) -> None: + slide_image = np.random.rand(3, 1000, 2000) + location_bbox = [100, 100] + slide = 1 + fig = plot_heatmap_slide(slide, slide_image, test_dict, location_bbox) + assert isinstance(fig, matplotlib.figure.Figure) + + tile_coords = np.array([[100, 100], [200, 100], [200, 200]]) + level_dict = {"0": 1, "1": 4, "2": 16} + factor = level_dict[str(level)] + x_tr, y_tr = location_bbox + tile_xs, tile_ys = tile_coords.T + tile_xs = tile_xs - x_tr + tile_ys = tile_ys - y_tr + tile_xs = tile_xs//factor + tile_ys = tile_ys//factor + assert min(tile_xs) >= 0 + assert max(tile_xs) <= slide_image.shape[1]//factor + assert min(tile_ys) >= 0 + assert max(tile_ys) <= slide_image.shape[2]//factor From fd0a1f2e60066a736966beab75b1495e0995a820 Mon Sep 17 00:00:00 2001 From: t-hsharma Date: Mon, 17 Jan 2022 12:19:08 +0000 Subject: [PATCH 15/23] PR comments more --- InnerEye/ML/Histopathology/models/deepmil.py | 4 +-- .../ML/Histopathology/utils/heatmap_utils.py | 10 +++--- .../ML/Histopathology/utils/metrics_utils.py | 13 ++++---- .../classification/DeepSMILEPanda.py | 32 ++++++------------- 4 files changed, 23 insertions(+), 36 deletions(-) diff --git a/InnerEye/ML/Histopathology/models/deepmil.py b/InnerEye/ML/Histopathology/models/deepmil.py index b941c3451..18669f09b 100644 --- a/InnerEye/ML/Histopathology/models/deepmil.py +++ b/InnerEye/ML/Histopathology/models/deepmil.py @@ -19,7 +19,7 @@ from InnerEye.Common import fixed_paths from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset from InnerEye.ML.Histopathology.models.encoders import TileEncoder -from InnerEye.ML.Histopathology.utils.metrics_utils import select_k_tiles, plot_slide_noxy, 
plot_scores_hist, plot_heatmap_slide, plot_slide +from InnerEye.ML.Histopathology.utils.metrics_utils import select_k_tiles, plot_slide_noxy, plot_scores_hist, plot_heatmap_overlay, plot_slide from InnerEye.ML.Histopathology.utils.naming import ResultsKey from monai.data.dataset import Dataset @@ -331,7 +331,7 @@ def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: # type: ignore fig = plot_slide(slide_image=slide_image, scale=1.0) self.save_figure(fig=fig, figpath=Path(key_folder_path, f'{slide}_thumbnail.png')) - fig = plot_heatmap_slide(slide=slide, slide_image=slide_image, results=results, + fig = plot_heatmap_overlay(slide=slide, slide_image=slide_image, results=results, location_bbox=location_bbox, tile_size=self.tile_size, level=slide_dict['level']) self.save_figure(fig=fig, figpath=Path(key_folder_path, f'{slide}_heatmap.png')) diff --git a/InnerEye/ML/Histopathology/utils/heatmap_utils.py b/InnerEye/ML/Histopathology/utils/heatmap_utils.py index 0e9e07b90..a0a22f6de 100644 --- a/InnerEye/ML/Histopathology/utils/heatmap_utils.py +++ b/InnerEye/ML/Histopathology/utils/heatmap_utils.py @@ -12,20 +12,21 @@ def plot_heatmap_selected_tiles(tile_coords: np.ndarray, location_bbox: List[int], tile_size: int, level: int, - ax: Optional[Axes] = None) -> None: + ax: Optional[Axes] = None) -> np.ndarray: """Plots a 2D heatmap for selected tiles (e.g. tiles in a bag). :param tile_coords: XY tile coordinates, assumed to be spaced by multiples of `tile_size` (shape: [N, 2]). :param tile_values: Scalar values of the tiles (shape: [N]). :param location_bbox: Location of the bounding box of the slide. - :param level: Magnification at which tiles are available (e.g. PANDA levels are 0 for original, 1 for 4x downsampled, 2 for 16x downsampled). + :param level: The downsampling level (e.g. 0, 1, 2) of the tiles if available + (e.g. PANDA levels are 0 for original, 1 for 4x downsampled, 2 for 16x downsampled). :param tile_size: Size of each tile. :param ax: Axes onto which to plot the heatmap (default: current axes). 
""" if ax is None: ax = plt.gca() - level_dict = {"0": 1, "1": 4, "2": 16} - factor = level_dict[str(level)] + level_dict = {0: 1, 1: 4, 2: 16} + factor = level_dict[level] x_tr, y_tr = location_bbox tile_xs, tile_ys = tile_coords.T tile_xs = tile_xs - x_tr @@ -45,3 +46,4 @@ def plot_heatmap_selected_tiles(tile_coords: np.ndarray, pc.set_clim([0, 1]) ax.add_collection(pc) plt.colorbar(pc, ax=ax) + return sel_coords diff --git a/InnerEye/ML/Histopathology/utils/metrics_utils.py b/InnerEye/ML/Histopathology/utils/metrics_utils.py index e10cd1d73..b9f3808f7 100644 --- a/InnerEye/ML/Histopathology/utils/metrics_utils.py +++ b/InnerEye/ML/Histopathology/utils/metrics_utils.py @@ -115,7 +115,7 @@ def plot_slide(slide_image: np.ndarray, scale: float) -> plt.figure: return fig -def plot_heatmap_slide(slide: str, +def plot_heatmap_overlay(slide: str, slide_image: np.ndarray, results: Dict[str, List[Any]], location_bbox: List[int], @@ -132,8 +132,11 @@ def plot_heatmap_slide(slide: str, """ fig, ax = plt.subplots() slide_image = slide_image.transpose(1, 2, 0) - coords = [] + ax.imshow(slide_image) + ax.set_xlim(0, slide_image.shape[1]) + ax.set_ylim(slide_image.shape[0], 0) + coords = [] slide_ids = [item[0] for item in results[ResultsKey.SLIDE_ID]] slide_idx = slide_ids.index(slide) attentions = results[ResultsKey.BAG_ATTN][slide_idx] @@ -146,9 +149,5 @@ def plot_heatmap_slide(slide: str, coords = np.array(coords) attentions = np.array(attentions.cpu()).reshape(-1) - ax.imshow(slide_image) - ax.set_xlim(0, slide_image.shape[1]) - ax.set_ylim(slide_image.shape[0], 0) - - plot_heatmap_selected_tiles(tile_coords=coords, tile_values=attentions, location_bbox=location_bbox, tile_size=tile_size, level=level) + _ = plot_heatmap_selected_tiles(tile_coords=coords, tile_values=attentions, location_bbox=location_bbox, tile_size=tile_size, level=level) return fig diff --git a/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py b/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py index 9bb74e4a0..74b3a1534 100644 --- a/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py +++ b/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py @@ -3,7 +3,7 @@ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 
# ------------------------------------------------------------------------------------------ -from typing import Any, Optional, List +from typing import Any, List from pathlib import Path import os from monai.transforms import Compose @@ -12,7 +12,7 @@ from monai.data.dataset import Dataset from health_azure.utils import CheckpointDownloader -from health_azure.utils import get_workspace +from health_azure.utils import get_workspace, is_running_in_azure_ml from health_ml.networks.layers.attention_layers import GatedAttentionLayer from InnerEye.Common import fixed_paths from InnerEye.ML.Histopathology.datamodules.panda_module import PandaTilesDataModule @@ -35,35 +35,19 @@ from InnerEye.ML.Histopathology.models.deepmil import DeepMILModule -local_mode = False -path_local_data: Optional[Path] -if local_mode: - path_local_data = Path("/tmp/datasets/PANDA_tiles") - azure_dataset_id = "Dummy" - extra_local_dataset_paths = [Path("/tmp/datasets/PANDA")] - extra_azure_dataset_ids = ["Dummy"] - num_epochs = 1 -else: - path_local_data = None - azure_dataset_id = "PANDA_tiles" - extra_local_dataset_paths = [] - extra_azure_dataset_ids = ["PANDA"] - num_epochs = 100 - - class DeepSMILEPanda(BaseMIL): def __init__(self, **kwargs: Any) -> None: default_kwargs = dict( # declared in BaseMIL: pooling_type=GatedAttentionLayer.__name__, # declared in DatasetParams: - local_dataset=path_local_data, - azure_dataset_id=azure_dataset_id, - extra_azure_dataset_ids=extra_azure_dataset_ids, - extra_local_dataset_paths=extra_local_dataset_paths, + local_dataset=Path("/tmp/datasets/PANDA_tiles"), + azure_dataset_id="PANDA_tiles", + extra_azure_dataset_ids=["PANDA"], + extra_local_dataset_paths=[Path("/tmp/datasets/PANDA")], # To mount the dataset instead of downloading in AML, pass --use_dataset_mount in the CLI # declared in TrainerParams: - num_epochs=num_epochs, + num_epochs=100, recovery_checkpoint_save_interval=10, recovery_checkpoints_save_last_k=-1, # use_mixed_precision = True, @@ -77,6 +61,8 @@ def __init__(self, **kwargs: Any) -> None: default_kwargs.update(kwargs) super().__init__(**default_kwargs) super().__init__(**default_kwargs) + if not is_running_in_azure_ml(): + self.num_epochs = 1 self.best_checkpoint_filename = "checkpoint_max_val_auroc" self.best_checkpoint_filename_with_suffix = ( self.best_checkpoint_filename + ".ckpt" From a60265d7da975d57f74e22aec6ff90283ece4905 Mon Sep 17 00:00:00 2001 From: t-hsharma Date: Mon, 17 Jan 2022 14:29:23 +0000 Subject: [PATCH 16/23] test for heatmap --- .../utils/test_metrics_utils.py | 67 ++++++++++++++----- 1 file changed, 51 insertions(+), 16 deletions(-) diff --git a/Tests/ML/histopathology/utils/test_metrics_utils.py b/Tests/ML/histopathology/utils/test_metrics_utils.py index 355690a3a..8d57f27d1 100644 --- a/Tests/ML/histopathology/utils/test_metrics_utils.py +++ b/Tests/ML/histopathology/utils/test_metrics_utils.py @@ -11,9 +11,9 @@ from torch.functional import Tensor import pytest -from InnerEye.ML.Histopathology.utils.metrics_utils import plot_scores_hist, select_k_tiles, plot_slide, plot_heatmap_slide +from InnerEye.ML.Histopathology.utils.metrics_utils import plot_scores_hist, select_k_tiles, plot_slide, plot_heatmap_overlay from InnerEye.ML.Histopathology.utils.naming import ResultsKey - +from InnerEye.ML.Histopathology.utils.heatmap_utils import plot_heatmap_selected_tiles def assert_equal_lists(pred: List, expected: List) -> None: assert len(pred) == len(expected) @@ -39,7 +39,17 @@ def assert_equal_lists(pred: List, expected: List) -> None: 
[Tensor([[0.1, 0.0, 0.2, 0.15]]), Tensor([[0.10, 0.18, 0.15, 0.13]]), Tensor([[0.25, 0.23, 0.20, 0.21]]), - Tensor([[0.33, 0.31, 0.37, 0.35]])] + Tensor([[0.33, 0.31, 0.37, 0.35]])], + ResultsKey.TILE_X: + [Tensor([200, 200, 424, 424]), + Tensor([200, 200, 424, 424]), + Tensor([200, 200, 424, 424]), + Tensor([200, 200, 424, 424])], + ResultsKey.TILE_Y: + [Tensor([200, 424, 200, 424]), + Tensor([200, 200, 424, 424]), + Tensor([200, 200, 424, 424]), + Tensor([200, 200, 424, 424])] } def test_select_k_tiles() -> None: @@ -63,29 +73,54 @@ def test_plot_scores_hist() -> None: fig = plot_scores_hist(test_dict) assert isinstance(fig, matplotlib.figure.Figure) + @pytest.mark.parametrize("scale", [0.1, 1.2, 2.4, 3.6]) def test_plot_slide(scale: int) -> None: slide_image = np.random.rand(3, 1000, 2000) fig = plot_slide(slide_image=slide_image, scale=scale) assert isinstance(fig, matplotlib.figure.Figure) -@pytest.mark.parametrize("level", [0, 1, 2]) -def test_plot_heatmap_slide(level: int) -> None: + +def test_plot_heatmap_overlay() -> None: slide_image = np.random.rand(3, 1000, 2000) location_bbox = [100, 100] - slide = 1 - fig = plot_heatmap_slide(slide, slide_image, test_dict, location_bbox) + slide = 1 + tile_size = 224 + level = 1 + fig = plot_heatmap_overlay(slide=slide, # type: ignore + slide_image=slide_image, + results=test_dict, # type: ignore + location_bbox=location_bbox, + tile_size=tile_size, + level=level) assert isinstance(fig, matplotlib.figure.Figure) - tile_coords = np.array([[100, 100], [200, 100], [200, 200]]) - level_dict = {"0": 1, "1": 4, "2": 16} - factor = level_dict[str(level)] - x_tr, y_tr = location_bbox - tile_xs, tile_ys = tile_coords.T - tile_xs = tile_xs - x_tr - tile_ys = tile_ys - y_tr - tile_xs = tile_xs//factor - tile_ys = tile_ys//factor +@pytest.mark.parametrize("level", [0, 1, 2]) +def test_plot_heatmap_selected_tiles(level: int) -> None: + slide = 1 + tile_size = 224 + location_bbox = [100, 100] + slide_image = np.random.rand(3, 1000, 2000) + + coords = [] + slide_ids = [item[0] for item in test_dict[ResultsKey.SLIDE_ID]] # type: ignore + slide_idx = slide_ids.index(slide) + attentions = test_dict[ResultsKey.BAG_ATTN][slide_idx] # type: ignore + for tile_idx in range(len(test_dict[ResultsKey.IMAGE_PATH][slide_idx])): # type: ignore + tile_coords = np.transpose(np.array([test_dict[ResultsKey.TILE_X][slide_idx][tile_idx].cpu().numpy(), # type: ignore + test_dict[ResultsKey.TILE_Y][slide_idx][tile_idx].cpu().numpy()])) # type: ignore + coords.append(tile_coords) + + coords = np.array(coords) + attentions = np.array(attentions.cpu()).reshape(-1) + tile_coords_transformed = plot_heatmap_selected_tiles(tile_coords=coords, + tile_values=attentions, + location_bbox=location_bbox, + tile_size=tile_size, + level=level) + tile_xs, tile_ys = tile_coords_transformed.T + level_dict = {0: 1, 1: 4, 2: 16} + factor = level_dict[level] assert min(tile_xs) >= 0 assert max(tile_xs) <= slide_image.shape[1]//factor assert min(tile_ys) >= 0 From 46e3a3ad9c05d039cc7f600122a9b7e6bd0f1f11 Mon Sep 17 00:00:00 2001 From: t-hsharma Date: Mon, 17 Jan 2022 14:40:11 +0000 Subject: [PATCH 17/23] tests for heatmap --- Tests/ML/histopathology/utils/test_metrics_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Tests/ML/histopathology/utils/test_metrics_utils.py b/Tests/ML/histopathology/utils/test_metrics_utils.py index 8d57f27d1..e33de33df 100644 --- a/Tests/ML/histopathology/utils/test_metrics_utils.py +++ b/Tests/ML/histopathology/utils/test_metrics_utils.py @@ 
-122,6 +122,6 @@ def test_plot_heatmap_selected_tiles(level: int) -> None: level_dict = {0: 1, 1: 4, 2: 16} factor = level_dict[level] assert min(tile_xs) >= 0 - assert max(tile_xs) <= slide_image.shape[1]//factor + assert max(tile_xs) <= slide_image.shape[2]//factor assert min(tile_ys) >= 0 - assert max(tile_ys) <= slide_image.shape[2]//factor + assert max(tile_ys) <= slide_image.shape[1]//factor From e25b41e7e152d0557e8287975350326c62ccf6a3 Mon Sep 17 00:00:00 2001 From: t-hsharma Date: Mon, 17 Jan 2022 16:13:33 +0000 Subject: [PATCH 18/23] add file comparison plot tests --- .../ML/Histopathology/utils/heatmap_utils.py | 41 ++++---------- .../ML/Histopathology/utils/metrics_utils.py | 18 ++++++- .../utils/test_metrics_utils.py | 53 ++++++++++++++----- .../histo_heatmaps/heatmap_overlay.png | 3 ++ .../test_data/histo_heatmaps/score_hist.png | 3 ++ .../ML/test_data/histo_heatmaps/slide_0.1.png | 3 ++ .../ML/test_data/histo_heatmaps/slide_1.2.png | 3 ++ .../ML/test_data/histo_heatmaps/slide_2.4.png | 3 ++ .../ML/test_data/histo_heatmaps/slide_3.6.png | 3 ++ 9 files changed, 83 insertions(+), 47 deletions(-) create mode 100644 Tests/ML/test_data/histo_heatmaps/heatmap_overlay.png create mode 100644 Tests/ML/test_data/histo_heatmaps/score_hist.png create mode 100644 Tests/ML/test_data/histo_heatmaps/slide_0.1.png create mode 100644 Tests/ML/test_data/histo_heatmaps/slide_1.2.png create mode 100644 Tests/ML/test_data/histo_heatmaps/slide_2.4.png create mode 100644 Tests/ML/test_data/histo_heatmaps/slide_3.6.png diff --git a/InnerEye/ML/Histopathology/utils/heatmap_utils.py b/InnerEye/ML/Histopathology/utils/heatmap_utils.py index 5c331ec0c..bbd74f75d 100644 --- a/InnerEye/ML/Histopathology/utils/heatmap_utils.py +++ b/InnerEye/ML/Histopathology/utils/heatmap_utils.py @@ -3,35 +3,22 @@ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------ -from typing import Optional, List -import matplotlib.pyplot as plt +from typing import List import numpy as np -from matplotlib.axes import Axes -import matplotlib.patches as patches -import matplotlib.collections as collection - -def plot_heatmap_selected_tiles(tile_coords: np.ndarray, - tile_values: np.ndarray, - location_bbox: List[int], - tile_size: int, - level: int, - ax: Optional[Axes] = None) -> np.ndarray: - """Plots a 2D heatmap for selected tiles (e.g. tiles in a bag). - :param tile_coords: XY tile coordinates, assumed to be spaced by multiples of `tile_size` (shape: [N, 2]). - :param tile_values: Scalar values of the tiles (shape: [N]). - :param location_bbox: Location of the bounding box of the slide. - :param level: The downsampling level (e.g. 0, 1, 2) of the tiles if available +def location_selected_tiles(tile_coords: np.ndarray, + location_bbox: List[int], + level: int) -> np.ndarray: + """ Return the scaled and shifted tile co-ordinates for selected tiles in the slide. + :param tile_coords: XY tile coordinates, assumed to be spaced by multiples of `tile_size` (shape: [N, 2]) in original resolution. + :param location_bbox: Location of the bounding box on the slide in original resolution. + :param level: The downsampling level (e.g. 0, 1, 2) of the tiles if available. (e.g. PANDA levels are 0 for original, 1 for 4x downsampled, 2 for 16x downsampled). - :param tile_size: Size of each tile. - :param ax: Axes onto which to plot the heatmap (default: current axes). 
""" - if ax is None: - ax = plt.gca() - level_dict = {0: 1, 1: 4, 2: 16} factor = level_dict[level] + x_tr, y_tr = location_bbox tile_xs, tile_ys = tile_coords.T tile_xs = tile_xs - x_tr @@ -39,16 +26,6 @@ def plot_heatmap_selected_tiles(tile_coords: np.ndarray, tile_xs = tile_xs//factor tile_ys = tile_ys//factor - cmap = plt.cm.get_cmap('jet') sel_coords = np.transpose([tile_xs.tolist(), tile_ys.tolist()]) - rects = [] - for i in range(sel_coords.shape[0]): - rect = patches.Rectangle((sel_coords[i][0], sel_coords[i][1]), tile_size, tile_size) - rects.append(rect) - pc = collection.PatchCollection(rects, match_original=True, cmap=cmap, alpha=.5, edgecolor=None) - pc.set_array(np.array(tile_values)) - pc.set_clim([0, 1]) - ax.add_collection(pc) - plt.colorbar(pc, ax=ax) return sel_coords diff --git a/InnerEye/ML/Histopathology/utils/metrics_utils.py b/InnerEye/ML/Histopathology/utils/metrics_utils.py index b9f3808f7..9e3702a43 100644 --- a/InnerEye/ML/Histopathology/utils/metrics_utils.py +++ b/InnerEye/ML/Histopathology/utils/metrics_utils.py @@ -8,10 +8,12 @@ import matplotlib.pyplot as plt from math import ceil import numpy as np +import matplotlib.patches as patches +import matplotlib.collections as collection from InnerEye.ML.Histopathology.models.transforms import load_pil_image from InnerEye.ML.Histopathology.utils.naming import ResultsKey -from InnerEye.ML.Histopathology.utils.heatmap_utils import plot_heatmap_selected_tiles +from InnerEye.ML.Histopathology.utils.heatmap_utils import location_selected_tiles def select_k_tiles(results: Dict, n_tiles: int = 5, n_slides: int = 5, label: int = 1, @@ -149,5 +151,17 @@ def plot_heatmap_overlay(slide: str, coords = np.array(coords) attentions = np.array(attentions.cpu()).reshape(-1) - _ = plot_heatmap_selected_tiles(tile_coords=coords, tile_values=attentions, location_bbox=location_bbox, tile_size=tile_size, level=level) + + sel_coords = location_selected_tiles(tile_coords=coords, location_bbox=location_bbox, level=level) + cmap = plt.cm.get_cmap('jet') + rects = [] + for i in range(sel_coords.shape[0]): + rect = patches.Rectangle((sel_coords[i][0], sel_coords[i][1]), tile_size, tile_size) + rects.append(rect) + + pc = collection.PatchCollection(rects, match_original=True, cmap=cmap, alpha=.5, edgecolor=None) + pc.set_array(np.array(attentions)) + pc.set_clim([0, 1]) + ax.add_collection(pc) + plt.colorbar(pc, ax=ax) return fig diff --git a/Tests/ML/histopathology/utils/test_metrics_utils.py b/Tests/ML/histopathology/utils/test_metrics_utils.py index e33de33df..fe4ee4d2c 100644 --- a/Tests/ML/histopathology/utils/test_metrics_utils.py +++ b/Tests/ML/histopathology/utils/test_metrics_utils.py @@ -3,6 +3,8 @@ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 
# ------------------------------------------------------------------------------------------ +from pathlib import Path + import math import numpy as np from typing import List @@ -13,7 +15,13 @@ from InnerEye.ML.Histopathology.utils.metrics_utils import plot_scores_hist, select_k_tiles, plot_slide, plot_heatmap_overlay from InnerEye.ML.Histopathology.utils.naming import ResultsKey -from InnerEye.ML.Histopathology.utils.heatmap_utils import plot_heatmap_selected_tiles +from InnerEye.ML.Histopathology.utils.heatmap_utils import location_selected_tiles +from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path +from InnerEye.Common.output_directories import OutputFolderForTests +from InnerEye.ML.plotting import resize_and_save +from InnerEye.ML.utils.ml_util import set_random_seed +from Tests.ML.util import assert_binary_files_match + def assert_equal_lists(pred: List, expected: List) -> None: assert len(pred) == len(expected) @@ -36,7 +44,7 @@ def assert_equal_lists(pred: List, expected: List) -> None: ResultsKey.PROB: [Tensor([0.5]), Tensor([0.7]), Tensor([0.4]), Tensor([1.0])], ResultsKey.TRUE_LABEL: [0, 1, 1, 1], ResultsKey.BAG_ATTN: - [Tensor([[0.1, 0.0, 0.2, 0.15]]), + [Tensor([[0.1, 0.3, 0.5, 0.8]]), Tensor([[0.10, 0.18, 0.15, 0.13]]), Tensor([[0.25, 0.23, 0.20, 0.21]]), Tensor([[0.33, 0.31, 0.37, 0.35]])], @@ -69,24 +77,40 @@ def test_select_k_tiles() -> None: assert_equal_lists(bottom_tp, [(4, 1.0, [2, 1], [Tensor([0.31]), Tensor([0.33])]), (2, 0.7, [1, 4], [Tensor([0.10]), Tensor([0.13])])]) -def test_plot_scores_hist() -> None: +def test_plot_scores_hist(test_output_dirs: OutputFolderForTests) -> None: fig = plot_scores_hist(test_dict) assert isinstance(fig, matplotlib.figure.Figure) + file = Path(test_output_dirs.root_dir) / "plot_score_hist.png" + resize_and_save(5, 5, file) + assert file.exists() + expected = full_ml_test_data_path("histo_heatmaps") / "score_hist.png" + # To update the stored results, uncomment this line: + # expected.write_bytes(file.read_bytes()) + assert_binary_files_match(file, expected) @pytest.mark.parametrize("scale", [0.1, 1.2, 2.4, 3.6]) -def test_plot_slide(scale: int) -> None: +def test_plot_slide(test_output_dirs: OutputFolderForTests, scale: int) -> None: + set_random_seed(0) slide_image = np.random.rand(3, 1000, 2000) fig = plot_slide(slide_image=slide_image, scale=scale) assert isinstance(fig, matplotlib.figure.Figure) + file = Path(test_output_dirs.root_dir) / "plot_slide.png" + resize_and_save(5, 5, file) + assert file.exists() + expected = full_ml_test_data_path("histo_heatmaps") / f"slide_{scale}.png" + # To update the stored results, uncomment this line: + # expected.write_bytes(file.read_bytes()) + assert_binary_files_match(file, expected) -def test_plot_heatmap_overlay() -> None: +def test_plot_heatmap_overlay(test_output_dirs: OutputFolderForTests) -> None: + set_random_seed(0) slide_image = np.random.rand(3, 1000, 2000) location_bbox = [100, 100] slide = 1 tile_size = 224 - level = 1 + level = 0 fig = plot_heatmap_overlay(slide=slide, # type: ignore slide_image=slide_image, results=test_dict, # type: ignore @@ -94,29 +118,32 @@ def test_plot_heatmap_overlay() -> None: tile_size=tile_size, level=level) assert isinstance(fig, matplotlib.figure.Figure) + file = Path(test_output_dirs.root_dir) / "plot_heatmap_overlay.png" + resize_and_save(5, 5, file) + assert file.exists() + expected = full_ml_test_data_path("histo_heatmaps") / "heatmap_overlay.png" + # To update the stored results, uncomment this line: + # 
+    assert_binary_files_match(file, expected)
 
 
 @pytest.mark.parametrize("level", [0, 1, 2])
-def test_plot_heatmap_selected_tiles(level: int) -> None:
+def test_location_selected_tiles(level: int) -> None:
+    set_random_seed(0)
     slide = 1
-    tile_size = 224
     location_bbox = [100, 100]
     slide_image = np.random.rand(3, 1000, 2000)
     coords = []
     slide_ids = [item[0] for item in test_dict[ResultsKey.SLIDE_ID]]  # type: ignore
     slide_idx = slide_ids.index(slide)
-    attentions = test_dict[ResultsKey.BAG_ATTN][slide_idx]  # type: ignore
     for tile_idx in range(len(test_dict[ResultsKey.IMAGE_PATH][slide_idx])):  # type: ignore
         tile_coords = np.transpose(np.array([test_dict[ResultsKey.TILE_X][slide_idx][tile_idx].cpu().numpy(),  # type: ignore
                                              test_dict[ResultsKey.TILE_Y][slide_idx][tile_idx].cpu().numpy()]))  # type: ignore
         coords.append(tile_coords)
     coords = np.array(coords)
-    attentions = np.array(attentions.cpu()).reshape(-1)
-    tile_coords_transformed = plot_heatmap_selected_tiles(tile_coords=coords,
-                                                          tile_values=attentions,
+    tile_coords_transformed = location_selected_tiles(tile_coords=coords,
                                                       location_bbox=location_bbox,
-                                                      tile_size=tile_size,
                                                       level=level)
     tile_xs, tile_ys = tile_coords_transformed.T
     level_dict = {0: 1, 1: 4, 2: 16}
diff --git a/Tests/ML/test_data/histo_heatmaps/heatmap_overlay.png b/Tests/ML/test_data/histo_heatmaps/heatmap_overlay.png
new file mode 100644
index 000000000..74c17e304
--- /dev/null
+++ b/Tests/ML/test_data/histo_heatmaps/heatmap_overlay.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ba2c7ff3cd82656aa876e853a009f3157ef121a63d838feaf18c7e932a7a363c
+size 314745
diff --git a/Tests/ML/test_data/histo_heatmaps/score_hist.png b/Tests/ML/test_data/histo_heatmaps/score_hist.png
new file mode 100644
index 000000000..bced47d8b
--- /dev/null
+++ b/Tests/ML/test_data/histo_heatmaps/score_hist.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ca95c0017d0a51d75d118e54f21c1e907b3d90dcca822b23622e369267907198
+size 17057
diff --git a/Tests/ML/test_data/histo_heatmaps/slide_0.1.png b/Tests/ML/test_data/histo_heatmaps/slide_0.1.png
new file mode 100644
index 000000000..964f81df6
--- /dev/null
+++ b/Tests/ML/test_data/histo_heatmaps/slide_0.1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e6867218d67f06f482f1ce4a5c102fee662e445b34c7d09121b2aba8380eb54a
+size 474111
diff --git a/Tests/ML/test_data/histo_heatmaps/slide_1.2.png b/Tests/ML/test_data/histo_heatmaps/slide_1.2.png
new file mode 100644
index 000000000..964f81df6
--- /dev/null
+++ b/Tests/ML/test_data/histo_heatmaps/slide_1.2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e6867218d67f06f482f1ce4a5c102fee662e445b34c7d09121b2aba8380eb54a
+size 474111
diff --git a/Tests/ML/test_data/histo_heatmaps/slide_2.4.png b/Tests/ML/test_data/histo_heatmaps/slide_2.4.png
new file mode 100644
index 000000000..964f81df6
--- /dev/null
+++ b/Tests/ML/test_data/histo_heatmaps/slide_2.4.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e6867218d67f06f482f1ce4a5c102fee662e445b34c7d09121b2aba8380eb54a
+size 474111
diff --git a/Tests/ML/test_data/histo_heatmaps/slide_3.6.png b/Tests/ML/test_data/histo_heatmaps/slide_3.6.png
new file mode 100644
index 000000000..964f81df6
--- /dev/null
+++ b/Tests/ML/test_data/histo_heatmaps/slide_3.6.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e6867218d67f06f482f1ce4a5c102fee662e445b34c7d09121b2aba8380eb54a
+size 474111
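
The following is a minimal, self-contained sketch (not part of the patch series) of the coordinate transform that the refactored location_selected_tiles helper performs, as read off the hunks above: a shift by the bounding-box origin followed by integer division by the per-level downscale factor {0: 1, 1: 4, 2: 16}. The "_sketch" suffix, the example coordinates and the printed result are illustrative only.

# Sketch of the tile-coordinate transform shown in the heatmap_utils hunks above.
import numpy as np


def location_selected_tiles_sketch(tile_coords: np.ndarray, location_bbox: list, level: int) -> np.ndarray:
    level_dict = {0: 1, 1: 4, 2: 16}   # downscale factor per magnification level, as in the patch
    factor = level_dict[level]
    x_tr, y_tr = location_bbox          # bounding-box origin in level-0 coordinates
    tile_xs, tile_ys = tile_coords.T
    tile_xs = (tile_xs - x_tr) // factor
    tile_ys = (tile_ys - y_tr) // factor
    return np.transpose([tile_xs.tolist(), tile_ys.tolist()])


# Hypothetical level-0 tile corners and bounding box, for illustration only:
coords = np.array([[324, 324], [548, 324]])
print(location_selected_tiles_sketch(coords, location_bbox=[100, 100], level=1))
# [[ 56  56]
#  [112  56]]
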
From f33cd657e138136a18854a344ef936505602ea95 Mon Sep 17 00:00:00 2001
From: t-hsharma
Date: Tue, 18 Jan 2022 09:13:54 +0000
Subject: [PATCH 19/23] reverting test_dict values for other test with hardcoded values

---
 Tests/ML/histopathology/utils/test_metrics_utils.py   | 2 +-
 Tests/ML/test_data/histo_heatmaps/heatmap_overlay.png | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/Tests/ML/histopathology/utils/test_metrics_utils.py b/Tests/ML/histopathology/utils/test_metrics_utils.py
index fe4ee4d2c..c41c70280 100644
--- a/Tests/ML/histopathology/utils/test_metrics_utils.py
+++ b/Tests/ML/histopathology/utils/test_metrics_utils.py
@@ -44,7 +44,7 @@ def assert_equal_lists(pred: List, expected: List) -> None:
              ResultsKey.PROB: [Tensor([0.5]), Tensor([0.7]), Tensor([0.4]), Tensor([1.0])],
              ResultsKey.TRUE_LABEL: [0, 1, 1, 1],
              ResultsKey.BAG_ATTN:
-             [Tensor([[0.1, 0.3, 0.5, 0.8]]),
+             [Tensor([[0.1, 0.0, 0.2, 0.15]]),
              Tensor([[0.10, 0.18, 0.15, 0.13]]),
              Tensor([[0.25, 0.23, 0.20, 0.21]]),
              Tensor([[0.33, 0.31, 0.37, 0.35]])],
diff --git a/Tests/ML/test_data/histo_heatmaps/heatmap_overlay.png b/Tests/ML/test_data/histo_heatmaps/heatmap_overlay.png
index 74c17e304..d1177cb8d 100644
--- a/Tests/ML/test_data/histo_heatmaps/heatmap_overlay.png
+++ b/Tests/ML/test_data/histo_heatmaps/heatmap_overlay.png
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:ba2c7ff3cd82656aa876e853a009f3157ef121a63d838feaf18c7e932a7a363c
-size 314745
+oid sha256:9ada934edff6f48d5a4ef6b863e4250dbc0f4dcabdf212873b59584b179647fc
+size 314721

From 6899c912b036af0d48d7807e7d720db66c6af257 Mon Sep 17 00:00:00 2001
From: t-hsharma
Date: Tue, 18 Jan 2022 10:23:27 +0000
Subject: [PATCH 20/23] PR comments Valentina

---
 InnerEye/ML/Histopathology/models/deepmil.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/InnerEye/ML/Histopathology/models/deepmil.py b/InnerEye/ML/Histopathology/models/deepmil.py
index 18669f09b..ba65c9084 100644
--- a/InnerEye/ML/Histopathology/models/deepmil.py
+++ b/InnerEye/ML/Histopathology/models/deepmil.py
@@ -325,14 +325,14 @@ def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None:  # type: ignore
 
                     if len(self.slide_dataset) > 0:
                         slide_dict = mi.first_true(self.slide_dataset, pred=lambda entry: entry[SlideKey.SLIDE_ID] == slide)  # type: ignore
-                        load_image_dict(slide_dict, level=self.level, margin=0)  # type: ignore
+                        _ = load_image_dict(slide_dict, level=self.level, margin=0)
                         slide_image = slide_dict[SlideKey.IMAGE]
                         location_bbox = slide_dict[SlideKey.LOCATION]
 
                         fig = plot_slide(slide_image=slide_image, scale=1.0)
                         self.save_figure(fig=fig, figpath=Path(key_folder_path, f'{slide}_thumbnail.png'))
                         fig = plot_heatmap_overlay(slide=slide, slide_image=slide_image, results=results,
-                                                   location_bbox=location_bbox, tile_size=self.tile_size, level=slide_dict['level'])
+                                                   location_bbox=location_bbox, tile_size=self.tile_size, level=self.level)
                         self.save_figure(fig=fig, figpath=Path(key_folder_path, f'{slide}_heatmap.png'))
 
         print("Plotting histogram ...")

From 89de3041430cd29d9928bc3daa7014baa74b19d2 Mon Sep 17 00:00:00 2001
From: t-hsharma
Date: Tue, 18 Jan 2022 12:56:41 +0000
Subject: [PATCH 21/23] PR comments

---
 InnerEye/ML/Histopathology/models/deepmil.py  | 17 ++++++++---------
 .../ML/Histopathology/utils/metrics_utils.py  |  6 ++++--
 .../histo_configs/classification/BaseMIL.py   |  4 ++--
 .../classification/DeepSMILEPanda.py          |  5 ++---
 4 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/InnerEye/ML/Histopathology/models/deepmil.py b/InnerEye/ML/Histopathology/models/deepmil.py
index ba65c9084..827d23a03 100644
--- a/InnerEye/ML/Histopathology/models/deepmil.py
+++ b/InnerEye/ML/Histopathology/models/deepmil.py
@@ -17,12 +17,11 @@
 from torchmetrics import AUROC, F1, Accuracy, Precision, Recall
 
 from InnerEye.Common import fixed_paths
-from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset
+from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset, SlidesDataset
 from InnerEye.ML.Histopathology.models.encoders import TileEncoder
-from InnerEye.ML.Histopathology.utils.metrics_utils import select_k_tiles, plot_slide_noxy, plot_scores_hist, plot_heatmap_overlay, plot_slide
+from InnerEye.ML.Histopathology.utils.metrics_utils import select_k_tiles, plot_attention_tiles, plot_scores_hist, plot_heatmap_overlay, plot_slide
 from InnerEye.ML.Histopathology.utils.naming import ResultsKey
 
-from monai.data.dataset import Dataset
 from InnerEye.ML.Histopathology.utils.viz_utils import load_image_dict
 from InnerEye.ML.Histopathology.utils.naming import SlideKey
 
@@ -52,9 +51,9 @@ def __init__(self,
                  weight_decay: float = 1e-4,
                  adam_betas: Tuple[float, float] = (0.9, 0.99),
                  verbose: bool = False,
-                 slide_dataset: Dataset = Dataset(data=[]),
+                 slide_dataset: SlidesDataset = None,
                  tile_size: int = 224,
-                 level: Optional[int] = 1) -> None:
+                 level: int = 1) -> None:
         """
         :param label_column: Label key for input batch dictionary.
        :param n_classes: Number of output classes for MIL prediction.
@@ -316,16 +315,16 @@ def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None:  # type: ignore
                 nslides = len(report_cases[key][0])
                 for i in range(nslides):
                     slide, score, paths, top_attn = report_cases[key][0][i]
-                    fig = plot_slide_noxy(slide, score, paths, top_attn, key + '_top', ncols=4)
+                    fig = plot_attention_tiles(slide, score, paths, top_attn, key + '_top', ncols=4)
                     self.save_figure(fig=fig, figpath=Path(key_folder_path, f'{slide}_top.png'))
                     slide, score, paths, bottom_attn = report_cases[key][1][i]
-                    fig = plot_slide_noxy(slide, score, paths, bottom_attn, key + '_bottom', ncols=4)
+                    fig = plot_attention_tiles(slide, score, paths, bottom_attn, key + '_bottom', ncols=4)
                     self.save_figure(fig=fig, figpath=Path(key_folder_path, f'{slide}_bottom.png'))
 
-                    if len(self.slide_dataset) > 0:
+                    if self.slide_dataset is not None:
                         slide_dict = mi.first_true(self.slide_dataset, pred=lambda entry: entry[SlideKey.SLIDE_ID] == slide)  # type: ignore
-                        _ = load_image_dict(slide_dict, level=self.level, margin=0)
+                        _ = load_image_dict(slide_dict, level=self.level, margin=0)  # type: ignore
                         slide_image = slide_dict[SlideKey.IMAGE]
                         location_bbox = slide_dict[SlideKey.LOCATION]
 
diff --git a/InnerEye/ML/Histopathology/utils/metrics_utils.py b/InnerEye/ML/Histopathology/utils/metrics_utils.py
index 9e3702a43..4389863ed 100644
--- a/InnerEye/ML/Histopathology/utils/metrics_utils.py
+++ b/InnerEye/ML/Histopathology/utils/metrics_utils.py
@@ -79,7 +79,7 @@ def plot_scores_hist(results: Dict, prob_col: str = ResultsKey.PROB,
     return fig
 
 
-def plot_slide_noxy(slide: str, score: float, paths: List, attn: List, case: str, ncols: int = 5,
+def plot_attention_tiles(slide: str, score: float, paths: List, attn: List, case: str, ncols: int = 5,
                     size: Tuple = (10, 10)) -> plt.figure:
     """
     :param slide: slide identifier
@@ -154,7 +154,9 @@ def plot_heatmap_overlay(slide: str,
 
     sel_coords = location_selected_tiles(tile_coords=coords, location_bbox=location_bbox, level=level)
     cmap = plt.cm.get_cmap('jet')
-    rects = []
+
+    tile_xs, tile_ys = sel_coords.T
+    rects = [patches.Rectangle(xy, tile_size, tile_size) for xy in zip(tile_xs, tile_ys)]
     for i in range(sel_coords.shape[0]):
         rect = patches.Rectangle((sel_coords[i][0], sel_coords[i][1]), tile_size, tile_size)
         rects.append(rect)
diff --git a/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py b/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py
index 25dee58f4..521711b17 100644
--- a/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py
+++ b/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py
@@ -13,10 +13,10 @@
 import param
 from torch import nn
 from torchvision.models.resnet import resnet18
-from monai.data.dataset import Dataset
 
 from health_ml.networks.layers.attention_layers import AttentionLayer, GatedAttentionLayer
 from InnerEye.ML.lightning_container import LightningContainer
+from InnerEye.ML.Histopathology.datasets.base_dataset import SlidesDataset
 from InnerEye.ML.Histopathology.datamodules.base_module import CacheMode, TilesDataModule
 from InnerEye.ML.Histopathology.models.deepmil import DeepMILModule
 from InnerEye.ML.Histopathology.models.encoders import (HistoSSLEncoder, IdentityEncoder,
@@ -102,5 +102,5 @@ def create_model(self) -> DeepMILModule:
     def get_data_module(self) -> TilesDataModule:
         raise NotImplementedError
 
-    def get_slide_dataset(self) -> Dataset:
+    def get_slide_dataset(self) -> SlidesDataset:
         raise NotImplementedError
diff --git a/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py b/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py
index 4e5b1a130..59a617b74 100644
--- a/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py
+++ b/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py
@@ -9,7 +9,6 @@
 from monai.transforms import Compose
 from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
 from pytorch_lightning.callbacks import Callback
-from monai.data.dataset import Dataset
 
 from health_azure.utils import CheckpointDownloader
 from health_azure.utils import get_workspace, is_running_in_azure_ml
@@ -135,8 +134,8 @@ def create_model(self) -> DeepMILModule:
             tile_size=self.tile_size,
             level=self.level)
 
-    def get_slide_dataset(self) -> Dataset:
-        return Dataset(PandaDataset(root=self.extra_local_dataset_paths[0]))  # type: ignore
+    def get_slide_dataset(self) -> PandaDataset:
+        return PandaDataset(root=self.extra_local_dataset_paths[0])  # type: ignore
 
     def get_callbacks(self) -> List[Callback]:
         return super().get_callbacks() + [self.callbacks]

From 4de59b12306897b24aa1641ffff7e985d3c6b29d Mon Sep 17 00:00:00 2001
From: t-hsharma
Date: Tue, 18 Jan 2022 14:18:18 +0000
Subject: [PATCH 22/23] change colormap to reds

---
 InnerEye/ML/Histopathology/utils/metrics_utils.py     | 2 +-
 Tests/ML/test_data/histo_heatmaps/heatmap_overlay.png | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/InnerEye/ML/Histopathology/utils/metrics_utils.py b/InnerEye/ML/Histopathology/utils/metrics_utils.py
index 4389863ed..c3d91d3fc 100644
--- a/InnerEye/ML/Histopathology/utils/metrics_utils.py
+++ b/InnerEye/ML/Histopathology/utils/metrics_utils.py
@@ -153,7 +153,7 @@ def plot_heatmap_overlay(slide: str,
     attentions = np.array(attentions.cpu()).reshape(-1)
 
     sel_coords = location_selected_tiles(tile_coords=coords, location_bbox=location_bbox, level=level)
-    cmap = plt.cm.get_cmap('jet')
+    cmap = plt.cm.get_cmap('Reds')
 
     tile_xs, tile_ys = sel_coords.T
     rects = [patches.Rectangle(xy, tile_size, tile_size) for xy in zip(tile_xs, tile_ys)]
diff --git a/Tests/ML/test_data/histo_heatmaps/heatmap_overlay.png b/Tests/ML/test_data/histo_heatmaps/heatmap_overlay.png
index d1177cb8d..920af44c8 100644
--- a/Tests/ML/test_data/histo_heatmaps/heatmap_overlay.png
+++ b/Tests/ML/test_data/histo_heatmaps/heatmap_overlay.png
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:9ada934edff6f48d5a4ef6b863e4250dbc0f4dcabdf212873b59584b179647fc
-size 314721
+oid sha256:6b62a4a702209d7e865d393674ddb8d3c032484fc6138d68b6fbd837e084c5af
+size 312150

From 7b4094c713e5c250de3df993555fa8329761ccf6 Mon Sep 17 00:00:00 2001
From: t-hsharma
Date: Tue, 18 Jan 2022 14:45:20 +0000
Subject: [PATCH 23/23] remove redundant for loop

---
 InnerEye/ML/Histopathology/utils/metrics_utils.py     | 3 ---
 Tests/ML/test_data/histo_heatmaps/heatmap_overlay.png | 4 ++--
 2 files changed, 2 insertions(+), 5 deletions(-)

diff --git a/InnerEye/ML/Histopathology/utils/metrics_utils.py b/InnerEye/ML/Histopathology/utils/metrics_utils.py
index c3d91d3fc..b6f49f081 100644
--- a/InnerEye/ML/Histopathology/utils/metrics_utils.py
+++ b/InnerEye/ML/Histopathology/utils/metrics_utils.py
@@ -157,9 +157,6 @@ def plot_heatmap_overlay(slide: str,
 
     tile_xs, tile_ys = sel_coords.T
     rects = [patches.Rectangle(xy, tile_size, tile_size) for xy in zip(tile_xs, tile_ys)]
-    for i in range(sel_coords.shape[0]):
-        rect = patches.Rectangle((sel_coords[i][0], sel_coords[i][1]), tile_size, tile_size)
-        rects.append(rect)
 
     pc = collection.PatchCollection(rects, match_original=True, cmap=cmap, alpha=.5, edgecolor=None)
     pc.set_array(np.array(attentions))
diff --git a/Tests/ML/test_data/histo_heatmaps/heatmap_overlay.png b/Tests/ML/test_data/histo_heatmaps/heatmap_overlay.png
index 920af44c8..64cf40656 100644
--- a/Tests/ML/test_data/histo_heatmaps/heatmap_overlay.png
+++ b/Tests/ML/test_data/histo_heatmaps/heatmap_overlay.png
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:6b62a4a702209d7e865d393674ddb8d3c032484fc6138d68b6fbd837e084c5af
-size 312150
+oid sha256:3ab9e04b3b6ff098ff0269f8c3ffbf1cad45c8c6635c68e7732a59ea9568e82f
+size 314086
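
To summarise where patches 21-23 leave plot_heatmap_overlay, here is a minimal sketch (not part of the patch series) of the final overlay construction: one Rectangle per selected tile, coloured by its attention weight through a 'Reds' PatchCollection with the colour scale clipped to [0, 1]. The thumbnail, coordinates, attention values and output filename below are hypothetical stand-ins.

# Sketch of the attention-heatmap overlay as built by the final metrics_utils hunks.
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.collections as collection

tile_size = 224
sel_coords = np.array([[56, 56], [112, 56]])   # hypothetical transformed tile corners
attentions = np.array([0.2, 0.8])              # hypothetical attention weights

fig, ax = plt.subplots()
ax.imshow(np.random.rand(500, 500), cmap='gray')   # stand-in for the slide thumbnail

# One rectangle per tile, coloured by its attention value via the collection's colormap.
tile_xs, tile_ys = sel_coords.T
rects = [patches.Rectangle(xy, tile_size, tile_size) for xy in zip(tile_xs, tile_ys)]
pc = collection.PatchCollection(rects, match_original=True, cmap=plt.cm.get_cmap('Reds'),
                                alpha=.5, edgecolor=None)
pc.set_array(attentions)
pc.set_clim([0, 1])
ax.add_collection(pc)
plt.colorbar(pc, ax=ax)
fig.savefig("heatmap_overlay_sketch.png")      # illustrative output path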