Skip to content
This repository was archived by the owner on Mar 21, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ jobs that run in AzureML.
- ([#621](https://github.com/microsoft/InnerEye-DeepLearning/pull/621)) Add WSI preprocessing functions and enable tiling more generic slide datasets
- ([#634](https://github.com/microsoft/InnerEye-DeepLearning/pull/634)) Add WSI heatmaps and thumbnails to standard test outputs
- ([#635](https://github.com/microsoft/InnerEye-DeepLearning/pull/635)) Add tile selection and binary label for online evaluation of PANDA SSL
- ([#647](https://github.com/microsoft/InnerEye-DeepLearning/pull/647)) Add class-wise accuracy logging and confusion matrix to DeepMIL

### Changed
- ([#588](https://github.com/microsoft/InnerEye-DeepLearning/pull/588)) Replace SciPy with PIL.PngImagePlugin.PngImageFile to load png files.
Expand Down
70 changes: 52 additions & 18 deletions InnerEye/ML/Histopathology/models/deepmil.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,17 @@

from pytorch_lightning import LightningModule
from torch import Tensor, argmax, mode, nn, no_grad, optim, round
from torchmetrics import AUROC, F1, Accuracy, Precision, Recall
from torchmetrics import AUROC, F1, Accuracy, Precision, Recall, ConfusionMatrix

from InnerEye.Common import fixed_paths
from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset, SlidesDataset
from InnerEye.ML.Histopathology.models.encoders import TileEncoder
from InnerEye.ML.Histopathology.utils.metrics_utils import select_k_tiles, plot_attention_tiles, plot_scores_hist, plot_heatmap_overlay, plot_slide
from InnerEye.ML.Histopathology.utils.naming import ResultsKey

from InnerEye.ML.Histopathology.utils.metrics_utils import (select_k_tiles, plot_attention_tiles,
plot_scores_hist, plot_heatmap_overlay,
plot_slide, plot_normalized_confusion_matrix)
from InnerEye.ML.Histopathology.utils.naming import SlideKey, ResultsKey, MetricsKey
from InnerEye.ML.Histopathology.utils.viz_utils import load_image_dict
from InnerEye.ML.Histopathology.utils.naming import SlideKey

from health_ml.utils import log_on_epoch

RESULTS_COLS = [ResultsKey.SLIDE_ID, ResultsKey.TILE_ID, ResultsKey.IMAGE_PATH, ResultsKey.PROB,
ResultsKey.PRED_LABEL, ResultsKey.TRUE_LABEL, ResultsKey.BAG_ATTN]
Expand Down Expand Up @@ -53,10 +53,11 @@ def __init__(self,
verbose: bool = False,
slide_dataset: SlidesDataset = None,
tile_size: int = 224,
level: int = 1) -> None:
level: int = 1,
class_names: Optional[List[str]] = None) -> None:
"""
:param label_column: Label key for input batch dictionary.
:param n_classes: Number of output classes for MIL prediction.
:param n_classes: Number of output classes for MIL prediction. For binary classification, n_classes should be set to 1.
:param encoder: The tile encoder to use for feature extraction. If no encoding is needed,
you should use `IdentityEncoder`.
:param pooling_layer: Type of pooling to use in multi-instance aggregation. Should be a
Expand All @@ -67,10 +68,11 @@ def __init__(self,
:param l_rate: Optimiser learning rate.
:param weight_decay: Weight decay parameter for L2 regularisation.
:param adam_betas: Beta parameters for Adam optimiser.
:param verbose: if True statements about memory usage are output at each step
:param verbose: if True statements about memory usage are output at each step.
:param slide_dataset: Slide dataset object, if available.
:param tile_size: The size of each tile (default=224).
:param level: The downsampling level (e.g. 0, 1, 2) of the tiles if available (default=1).
:param class_names: The names of the classes if available (default=None).
"""
super().__init__()

Expand All @@ -84,6 +86,18 @@ def __init__(self,
self.encoder = encoder
self.num_encoding = self.encoder.num_encoding

if class_names is not None:
self.class_names = class_names
else:
if self.n_classes > 1:
self.class_names = [str(i) for i in range(self.n_classes)]
else:
self.class_names = ['0', '1']
if self.n_classes > 1 and len(self.class_names) != self.n_classes:
raise ValueError(f"Mismatch in number of class names ({self.class_names}) and number of classes ({self.n_classes})")
if self.n_classes == 1 and len(self.class_names) != 2:
raise ValueError(f"Mismatch in number of class names ({self.class_names}) and number of classes ({self.n_classes+1})")

# Optimiser hyperparameters
self.l_rate = l_rate
self.weight_decay = weight_decay
Expand Down Expand Up @@ -146,23 +160,31 @@ def get_bag_label(labels: Tensor) -> Tensor:

def get_metrics(self) -> nn.ModuleDict:
if self.n_classes > 1:
return nn.ModuleDict({'accuracy': Accuracy(num_classes=self.n_classes, average='micro'),
'macro_accuracy': Accuracy(num_classes=self.n_classes, average='macro'),
'weighted_accuracy': Accuracy(num_classes=self.n_classes, average='weighted')})
return nn.ModuleDict({MetricsKey.ACC: Accuracy(num_classes=self.n_classes, average='micro'),
MetricsKey.ACC_MACRO: Accuracy(num_classes=self.n_classes, average='macro'),
MetricsKey.ACC_WEIGHTED: Accuracy(num_classes=self.n_classes, average='weighted'),
MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=self.n_classes)})
else:
return nn.ModuleDict({'accuracy': Accuracy(),
'auroc': AUROC(num_classes=self.n_classes),
'precision': Precision(),
'recall': Recall(),
'f1score': F1()})
return nn.ModuleDict({MetricsKey.ACC: Accuracy(),
MetricsKey.AUROC: AUROC(num_classes=self.n_classes),
MetricsKey.PRECISION: Precision(),
MetricsKey.RECALL: Recall(),
MetricsKey.F1: F1(),
MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=self.n_classes+1)})

def log_metrics(self,
stage: str) -> None:
valid_stages = ['train', 'test', 'val']
if stage not in valid_stages:
raise Exception(f"Invalid stage. Chose one of {valid_stages}")
for metric_name, metric_object in self.get_metrics_dict(stage).items():
self.log(f'{stage}/{metric_name}', metric_object, on_epoch=True, on_step=False, logger=True, sync_dist=True)
if metric_name == MetricsKey.CONF_MATRIX:
metric_value = metric_object.compute()
metric_value_n = metric_value/metric_value.sum(axis=1, keepdims=True)
for i in range(metric_value_n.shape[0]):
log_on_epoch(self, f'{stage}/{self.class_names[i]}', metric_value_n[i, i])
else:
log_on_epoch(self, f'{stage}/{metric_name}', metric_object)

def forward(self, images: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore
with no_grad():
Expand Down Expand Up @@ -338,6 +360,18 @@ def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: # type: ignore
fig = plot_scores_hist(results)
self.save_figure(fig=fig, figpath=outputs_fig_path / 'hist_scores.png')

print("Computing and saving confusion matrix...")
metrics_dict = self.get_metrics_dict('test')
cf_matrix = metrics_dict[MetricsKey.CONF_MATRIX].compute()
cf_matrix = np.array(cf_matrix.cpu())
# We can't log tensors in the normal way - just print it to console
print('test/confusion matrix:')
print(cf_matrix)
# Save the normalized confusion matrix as a figure in outputs
cf_matrix_n = cf_matrix/cf_matrix.sum(axis=1, keepdims=True)
fig = plot_normalized_confusion_matrix(cm=cf_matrix_n, class_names=self.class_names)
self.save_figure(fig=fig, figpath=outputs_fig_path / 'normalized_confusion_matrix.png')

@staticmethod
def save_figure(fig: plt.figure, figpath: Path) -> None:
fig.savefig(figpath, bbox_inches='tight')
Expand Down
15 changes: 15 additions & 0 deletions InnerEye/ML/Histopathology/utils/metrics_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import numpy as np
import matplotlib.patches as patches
import matplotlib.collections as collection
import seaborn as sns

from InnerEye.ML.Histopathology.models.transforms import load_pil_image
from InnerEye.ML.Histopathology.utils.naming import ResultsKey
Expand Down Expand Up @@ -164,3 +165,17 @@ def plot_heatmap_overlay(slide: str,
ax.add_collection(pc)
plt.colorbar(pc, ax=ax)
return fig


def plot_normalized_confusion_matrix(cm: np.ndarray, class_names: List[str]) -> plt.figure:
"""Plots a normalized confusion matrix and returns the figure.
param cm: Normalized confusion matrix to be plotted.
param class_names: List of class names.
"""
fig, ax = plt.subplots()
ax = sns.heatmap(cm, annot=True, cmap='Blues', fmt=".2%")
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
ax.xaxis.set_ticklabels(class_names)
ax.yaxis.set_ticklabels(class_names)
return fig
10 changes: 10 additions & 0 deletions InnerEye/ML/Histopathology/utils/naming.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,13 @@ class ResultsKey(str, Enum):
TILE_X = "x"
TILE_Y = "y"


class MetricsKey(str, Enum):
ACC = 'accuracy'
ACC_MACRO = 'macro_accuracy'
ACC_WEIGHTED = 'weighted_accuracy'
CONF_MATRIX = 'confusion_matrix'
AUROC = 'auroc'
PRECISION = 'precision'
RECALL = 'recall'
F1 = 'f1score'
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def create_model(self) -> DeepMILModule:
# no-op IdentityEncoder to be used inside the model
self.slide_dataset = self.get_slide_dataset()
self.level = 1
self.class_names = ["ISUP 0", "ISUP 1", "ISUP 2", "ISUP 3", "ISUP 4", "ISUP 5"]
return DeepMILModule(encoder=IdentityEncoder(input_dim=(self.encoder.num_encoding,)),
label_column=self.data_module.train_dataset.LABEL_COLUMN,
n_classes=self.data_module.train_dataset.N_CLASSES,
Expand All @@ -137,7 +138,8 @@ def create_model(self) -> DeepMILModule:
adam_betas=self.adam_betas,
slide_dataset=self.get_slide_dataset(),
tile_size=self.tile_size,
level=self.level)
level=self.level,
class_names=self.class_names)

def get_slide_dataset(self) -> PandaDataset:
return PandaDataset(root=self.extra_local_dataset_paths[0]) # type: ignore
Expand Down
26 changes: 25 additions & 1 deletion Tests/ML/histopathology/utils/test_metrics_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from torch.functional import Tensor
import pytest

from InnerEye.ML.Histopathology.utils.metrics_utils import plot_scores_hist, select_k_tiles, plot_slide, plot_heatmap_overlay
from InnerEye.ML.Histopathology.utils.metrics_utils import plot_scores_hist, select_k_tiles, plot_slide, plot_heatmap_overlay, plot_normalized_confusion_matrix
from InnerEye.ML.Histopathology.utils.naming import ResultsKey
from InnerEye.ML.Histopathology.utils.heatmap_utils import location_selected_tiles
from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path
Expand Down Expand Up @@ -126,6 +126,30 @@ def test_plot_heatmap_overlay(test_output_dirs: OutputFolderForTests) -> None:
# expected.write_bytes(file.read_bytes())
assert_binary_files_match(file, expected)


@pytest.mark.parametrize("n_classes", [1, 3])
def test_plot_normalized_confusion_matrix(test_output_dirs: OutputFolderForTests, n_classes: int) -> None:
set_random_seed(0)
if n_classes > 1:
cm = np.random.randint(1, 1000, size=(n_classes, n_classes))
class_names = [str(i) for i in range(n_classes)]
else:
cm = np.random.randint(1, 1000, size=(n_classes+1, n_classes+1))
class_names = [str(i) for i in range(n_classes+1)]
cm_n = cm/cm.sum(axis=1, keepdims=True)
assert (cm_n <= 1).all()

fig = plot_normalized_confusion_matrix(cm=cm_n, class_names=class_names)
assert isinstance(fig, matplotlib.figure.Figure)
file = Path(test_output_dirs.root_dir) / f"plot_confusion_matrix_{n_classes}.png"
resize_and_save(5, 5, file)
assert file.exists()
expected = full_ml_test_data_path("histo_heatmaps") / f"confusion_matrix_{n_classes}.png"
# To update the stored results, uncomment this line:
# expected.write_bytes(file.read_bytes())
assert_binary_files_match(file, expected)


@pytest.mark.parametrize("level", [0, 1, 2])
def test_location_selected_tiles(level: int) -> None:
set_random_seed(0)
Expand Down
3 changes: 3 additions & 0 deletions Tests/ML/test_data/histo_heatmaps/confusion_matrix_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 3 additions & 0 deletions Tests/ML/test_data/histo_heatmaps/confusion_matrix_3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.