From 4cfbfe38f2ef879012371862fc42eaa949214566 Mon Sep 17 00:00:00 2001 From: Harshita Sharma Date: Thu, 27 Jan 2022 10:40:20 +0000 Subject: [PATCH 1/9] confusion matrix at test end --- InnerEye/ML/Histopathology/models/deepmil.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/InnerEye/ML/Histopathology/models/deepmil.py b/InnerEye/ML/Histopathology/models/deepmil.py index 14ca78c00..6aadeef34 100644 --- a/InnerEye/ML/Histopathology/models/deepmil.py +++ b/InnerEye/ML/Histopathology/models/deepmil.py @@ -14,7 +14,7 @@ from pytorch_lightning import LightningModule from torch import Tensor, argmax, mode, nn, no_grad, optim, round -from torchmetrics import AUROC, F1, Accuracy, Precision, Recall +from torchmetrics import AUROC, F1, Accuracy, Precision, Recall, ConfusionMatrix from InnerEye.Common import fixed_paths from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset, SlidesDataset @@ -148,13 +148,15 @@ def get_metrics(self) -> nn.ModuleDict: if self.n_classes > 1: return nn.ModuleDict({'accuracy': Accuracy(num_classes=self.n_classes, average='micro'), 'macro_accuracy': Accuracy(num_classes=self.n_classes, average='macro'), - 'weighted_accuracy': Accuracy(num_classes=self.n_classes, average='weighted')}) + 'weighted_accuracy': Accuracy(num_classes=self.n_classes, average='weighted'), + 'confusion_matrix': ConfusionMatrix(num_classes=self.n_classes)}) else: return nn.ModuleDict({'accuracy': Accuracy(), 'auroc': AUROC(num_classes=self.n_classes), 'precision': Precision(), 'recall': Recall(), - 'f1score': F1()}) + 'f1score': F1(), + 'confusion_matrix': ConfusionMatrix(num_classes=self.n_classes)}) def log_metrics(self, stage: str) -> None: @@ -162,7 +164,10 @@ def log_metrics(self, if stage not in valid_stages: raise Exception(f"Invalid stage. 
Chose one of {valid_stages}") for metric_name, metric_object in self.get_metrics_dict(stage).items(): - self.log(f'{stage}/{metric_name}', metric_object, on_epoch=True, on_step=False, logger=True, sync_dist=True) + if not metric_name == "confusion_matrix": + self.log(f'{stage}/{metric_name}', metric_object, on_epoch=True, on_step=False, logger=True, sync_dist=True) + # else: + # metric_object.compute() def forward(self, images: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore with no_grad(): @@ -338,6 +343,13 @@ def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: # type: ignore fig = plot_scores_hist(results) self.save_figure(fig=fig, figpath=outputs_fig_path / 'hist_scores.png') + metrics_dict = self.get_metrics_dict('test') + print(metrics_dict) + metric_value = metrics_dict["confusion_matrix"].compute() + # We can't log tensors in the normal way - just print it to console + print('test/confusion matrix:') + print(np.array(metric_value.cpu())) + @staticmethod def save_figure(fig: plt.figure, figpath: Path) -> None: fig.savefig(figpath, bbox_inches='tight') From a399b5e97a6892d729fcaaac025863d75f01acab Mon Sep 17 00:00:00 2001 From: Harshita Sharma Date: Thu, 27 Jan 2022 12:30:31 +0000 Subject: [PATCH 2/9] cm figure and print output --- InnerEye/ML/Histopathology/models/deepmil.py | 25 +++++++++++++------ .../ML/Histopathology/utils/metrics_utils.py | 15 +++++++++++ .../classification/DeepSMILEPanda.py | 4 ++- 3 files changed, 36 insertions(+), 8 deletions(-) diff --git a/InnerEye/ML/Histopathology/models/deepmil.py b/InnerEye/ML/Histopathology/models/deepmil.py index 6aadeef34..b6413e14d 100644 --- a/InnerEye/ML/Histopathology/models/deepmil.py +++ b/InnerEye/ML/Histopathology/models/deepmil.py @@ -19,7 +19,9 @@ from InnerEye.Common import fixed_paths from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset, SlidesDataset from InnerEye.ML.Histopathology.models.encoders import TileEncoder -from InnerEye.ML.Histopathology.utils.metrics_utils import select_k_tiles, plot_attention_tiles, plot_scores_hist, plot_heatmap_overlay, plot_slide +from InnerEye.ML.Histopathology.utils.metrics_utils import (select_k_tiles, plot_attention_tiles, + plot_scores_hist, plot_heatmap_overlay, + plot_slide, plot_normalized_confusion_matrix) from InnerEye.ML.Histopathology.utils.naming import ResultsKey from InnerEye.ML.Histopathology.utils.viz_utils import load_image_dict @@ -53,7 +55,8 @@ def __init__(self, verbose: bool = False, slide_dataset: SlidesDataset = None, tile_size: int = 224, - level: int = 1) -> None: + level: int = 1, + class_names: List[str] = None) -> None: """ :param label_column: Label key for input batch dictionary. :param n_classes: Number of output classes for MIL prediction. 
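The confusion matrix introduced in the first patch follows the standard torchmetrics pattern: the metric object accumulates counts across test batches via update() and is reduced once with compute() at the end of the test epoch, which is why it is excluded from self.log() and printed instead ("We can't log tensors in the normal way"). A minimal, self-contained sketch of that accumulate-then-reduce pattern, using the same ConfusionMatrix constructor as the diff; the random predictions and targets are purely illustrative, and newer torchmetrics releases may additionally require a task argument:

    import torch
    from torchmetrics import ConfusionMatrix

    n_classes = 3
    conf_matrix = ConfusionMatrix(num_classes=n_classes)

    # Each update() call adds the batch counts to the metric's internal state.
    for _ in range(4):
        preds = torch.randint(n_classes, (8,))    # predicted class indices
        target = torch.randint(n_classes, (8,))   # true class indices
        conf_matrix.update(preds, target)

    # compute() reduces the accumulated state into an [n_classes, n_classes]
    # tensor whose rows are true labels and whose columns are predicted labels.
    cm = conf_matrix.compute()
    print(cm)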
@@ -84,6 +87,11 @@ def __init__(self, self.encoder = encoder self.num_encoding = self.encoder.num_encoding + if class_names is not None: + self.class_names = class_names + else: + self.class_names = [str(i) for i in range(self.n_classes)] + # Optimiser hyperparameters self.l_rate = l_rate self.weight_decay = weight_decay @@ -166,8 +174,6 @@ def log_metrics(self, for metric_name, metric_object in self.get_metrics_dict(stage).items(): if not metric_name == "confusion_matrix": self.log(f'{stage}/{metric_name}', metric_object, on_epoch=True, on_step=False, logger=True, sync_dist=True) - # else: - # metric_object.compute() def forward(self, images: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore with no_grad(): @@ -343,12 +349,17 @@ def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: # type: ignore fig = plot_scores_hist(results) self.save_figure(fig=fig, figpath=outputs_fig_path / 'hist_scores.png') + print("Computing and saving confusion matrix...") metrics_dict = self.get_metrics_dict('test') - print(metrics_dict) - metric_value = metrics_dict["confusion_matrix"].compute() + cf_matrix = metrics_dict["confusion_matrix"].compute() + cf_matrix = np.array(cf_matrix.cpu()) # We can't log tensors in the normal way - just print it to console print('test/confusion matrix:') - print(np.array(metric_value.cpu())) + print(cf_matrix) + # Save the normalized confusion matrix as a figure in outputs + cf_matrix_n = cf_matrix/cf_matrix.sum(axis=1) + fig = plot_normalized_confusion_matrix(cf_matrix_n, self.class_names) + self.save_figure(fig=fig, figpath=outputs_fig_path / 'normalized_confusion_matrix.png') @staticmethod def save_figure(fig: plt.figure, figpath: Path) -> None: diff --git a/InnerEye/ML/Histopathology/utils/metrics_utils.py b/InnerEye/ML/Histopathology/utils/metrics_utils.py index 93f9a4a22..5f2e04c5c 100644 --- a/InnerEye/ML/Histopathology/utils/metrics_utils.py +++ b/InnerEye/ML/Histopathology/utils/metrics_utils.py @@ -10,6 +10,7 @@ import numpy as np import matplotlib.patches as patches import matplotlib.collections as collection +import seaborn as sns from InnerEye.ML.Histopathology.models.transforms import load_pil_image from InnerEye.ML.Histopathology.utils.naming import ResultsKey @@ -164,3 +165,17 @@ def plot_heatmap_overlay(slide: str, ax.add_collection(pc) plt.colorbar(pc, ax=ax) return fig + + +def plot_normalized_confusion_matrix(cm: np.ndarray, class_names: List[str]) -> plt.figure: + """Plots a normalized confusion matrix and returns the figure. + param cm: Normalized confusion matrix to be plotted. + param class_names: List of class names. 
+ """ + fig, ax = plt.subplots() + ax = sns.heatmap(cm, annot=True, cmap='Blues', fmt=".2%") + ax.set_xlabel('Predicted') + ax.set_ylabel('True') + ax.xaxis.set_ticklabels(class_names) + ax.yaxis.set_ticklabels(class_names) + return fig \ No newline at end of file diff --git a/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py b/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py index c72dc3f66..68d952cf5 100644 --- a/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py +++ b/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py @@ -127,6 +127,7 @@ def create_model(self) -> DeepMILModule: # no-op IdentityEncoder to be used inside the model self.slide_dataset = self.get_slide_dataset() self.level = 1 + self.class_names = ["ISUP 0", "ISUP 1", "ISUP 2", "ISUP 3", "ISUP 4", "ISUP 5"] return DeepMILModule(encoder=IdentityEncoder(input_dim=(self.encoder.num_encoding,)), label_column=self.data_module.train_dataset.LABEL_COLUMN, n_classes=self.data_module.train_dataset.N_CLASSES, @@ -137,7 +138,8 @@ def create_model(self) -> DeepMILModule: adam_betas=self.adam_betas, slide_dataset=self.get_slide_dataset(), tile_size=self.tile_size, - level=self.level) + level=self.level, + class_names=self.class_names) def get_slide_dataset(self) -> PandaDataset: return PandaDataset(root=self.extra_local_dataset_paths[0]) # type: ignore From d66083af5735ca6fe8cad8f0acb0bb54268aad9f Mon Sep 17 00:00:00 2001 From: Harshita Sharma Date: Thu, 27 Jan 2022 15:03:55 +0000 Subject: [PATCH 3/9] add test for plot_normalized_confusion_matrix --- InnerEye/ML/Histopathology/models/deepmil.py | 7 +++++-- .../ML/Histopathology/utils/metrics_utils.py | 2 +- .../utils/test_metrics_utils.py | 19 ++++++++++++++++++- .../histo_heatmaps/confusion_matrix.png | 3 +++ 4 files changed, 27 insertions(+), 4 deletions(-) create mode 100644 Tests/ML/test_data/histo_heatmaps/confusion_matrix.png diff --git a/InnerEye/ML/Histopathology/models/deepmil.py b/InnerEye/ML/Histopathology/models/deepmil.py index b6413e14d..4304e0f19 100644 --- a/InnerEye/ML/Histopathology/models/deepmil.py +++ b/InnerEye/ML/Histopathology/models/deepmil.py @@ -90,7 +90,10 @@ def __init__(self, if class_names is not None: self.class_names = class_names else: - self.class_names = [str(i) for i in range(self.n_classes)] + if self.n_classes > 1: + self.class_names = [str(i) for i in range(self.n_classes)] + else: + self.class_names = ['0', '1'] # Optimiser hyperparameters self.l_rate = l_rate @@ -358,7 +361,7 @@ def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: # type: ignore print(cf_matrix) # Save the normalized confusion matrix as a figure in outputs cf_matrix_n = cf_matrix/cf_matrix.sum(axis=1) - fig = plot_normalized_confusion_matrix(cf_matrix_n, self.class_names) + fig = plot_normalized_confusion_matrix(cm=cf_matrix_n, class_names=self.class_names) self.save_figure(fig=fig, figpath=outputs_fig_path / 'normalized_confusion_matrix.png') @staticmethod diff --git a/InnerEye/ML/Histopathology/utils/metrics_utils.py b/InnerEye/ML/Histopathology/utils/metrics_utils.py index 5f2e04c5c..c0d47a664 100644 --- a/InnerEye/ML/Histopathology/utils/metrics_utils.py +++ b/InnerEye/ML/Histopathology/utils/metrics_utils.py @@ -178,4 +178,4 @@ def plot_normalized_confusion_matrix(cm: np.ndarray, class_names: List[str]) -> ax.set_ylabel('True') ax.xaxis.set_ticklabels(class_names) ax.yaxis.set_ticklabels(class_names) - return fig \ No newline at end of file + return fig diff --git 
a/Tests/ML/histopathology/utils/test_metrics_utils.py b/Tests/ML/histopathology/utils/test_metrics_utils.py index c41c70280..1e57e9cc3 100644 --- a/Tests/ML/histopathology/utils/test_metrics_utils.py +++ b/Tests/ML/histopathology/utils/test_metrics_utils.py @@ -13,7 +13,7 @@ from torch.functional import Tensor import pytest -from InnerEye.ML.Histopathology.utils.metrics_utils import plot_scores_hist, select_k_tiles, plot_slide, plot_heatmap_overlay +from InnerEye.ML.Histopathology.utils.metrics_utils import plot_scores_hist, select_k_tiles, plot_slide, plot_heatmap_overlay, plot_normalized_confusion_matrix from InnerEye.ML.Histopathology.utils.naming import ResultsKey from InnerEye.ML.Histopathology.utils.heatmap_utils import location_selected_tiles from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path @@ -126,6 +126,23 @@ def test_plot_heatmap_overlay(test_output_dirs: OutputFolderForTests) -> None: # expected.write_bytes(file.read_bytes()) assert_binary_files_match(file, expected) + +def test_plot_normalized_confusion_matrix(test_output_dirs: OutputFolderForTests) -> None: + set_random_seed(0) + n_classes = 3 + cm = np.random.rand(n_classes, n_classes) + class_names = [str(i) for i in range(n_classes)] + fig = plot_normalized_confusion_matrix(cm=cm, class_names=class_names) + assert isinstance(fig, matplotlib.figure.Figure) + file = Path(test_output_dirs.root_dir) / "plot_confusion_matrix.png" + resize_and_save(5, 5, file) + assert file.exists() + expected = full_ml_test_data_path("histo_heatmaps") / "confusion_matrix.png" + # To update the stored results, uncomment this line: + # expected.write_bytes(file.read_bytes()) + assert_binary_files_match(file, expected) + + @pytest.mark.parametrize("level", [0, 1, 2]) def test_location_selected_tiles(level: int) -> None: set_random_seed(0) diff --git a/Tests/ML/test_data/histo_heatmaps/confusion_matrix.png b/Tests/ML/test_data/histo_heatmaps/confusion_matrix.png new file mode 100644 index 000000000..efec3f9c1 --- /dev/null +++ b/Tests/ML/test_data/histo_heatmaps/confusion_matrix.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b8da4f5e8b876fd8d709850dd007e3f92beb923d8dd537c463c67768c7cbff59 +size 35031 From c60da3d1ab484e71375703cc78b5a407c0bdff5f Mon Sep 17 00:00:00 2001 From: Harshita Sharma Date: Thu, 27 Jan 2022 15:37:24 +0000 Subject: [PATCH 4/9] perclass accuracy log --- InnerEye/ML/Histopathology/models/deepmil.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/InnerEye/ML/Histopathology/models/deepmil.py b/InnerEye/ML/Histopathology/models/deepmil.py index 4304e0f19..676580ca7 100644 --- a/InnerEye/ML/Histopathology/models/deepmil.py +++ b/InnerEye/ML/Histopathology/models/deepmil.py @@ -163,11 +163,11 @@ def get_metrics(self) -> nn.ModuleDict: 'confusion_matrix': ConfusionMatrix(num_classes=self.n_classes)}) else: return nn.ModuleDict({'accuracy': Accuracy(), - 'auroc': AUROC(num_classes=self.n_classes), - 'precision': Precision(), - 'recall': Recall(), - 'f1score': F1(), - 'confusion_matrix': ConfusionMatrix(num_classes=self.n_classes)}) + 'auroc': AUROC(num_classes=self.n_classes), + 'precision': Precision(), + 'recall': Recall(), + 'f1score': F1(), + 'confusion_matrix': ConfusionMatrix(num_classes=self.n_classes+1)}) def log_metrics(self, stage: str) -> None: @@ -177,6 +177,11 @@ def log_metrics(self, for metric_name, metric_object in self.get_metrics_dict(stage).items(): if not metric_name == "confusion_matrix": self.log(f'{stage}/{metric_name}', 
metric_object, on_epoch=True, on_step=False, logger=True, sync_dist=True) + else: + metric_value = metric_object.compute() + metric_value_n = metric_value/metric_value.sum(axis=1) + for i in range(metric_value_n.shape[0]): + self.log(f'{stage}/{self.class_names[i]}', metric_value_n[i, i], on_epoch=True, on_step=False, logger=True, sync_dist=True) def forward(self, images: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore with no_grad(): From 186a72bb9f46e14bebc181ea5ac2fd6b295226db Mon Sep 17 00:00:00 2001 From: Harshita Sharma Date: Fri, 28 Jan 2022 08:28:48 +0000 Subject: [PATCH 5/9] test for normalized cm --- InnerEye/ML/Histopathology/models/deepmil.py | 4 +-- .../classification/DeepSMILECrck.py | 26 ++++++++++++++----- .../utils/test_metrics_utils.py | 23 ++++++++++------ .../histo_heatmaps/confusion_matrix.png | 3 --- .../histo_heatmaps/confusion_matrix_1.png | 3 +++ .../histo_heatmaps/confusion_matrix_3.png | 3 +++ 6 files changed, 42 insertions(+), 20 deletions(-) delete mode 100644 Tests/ML/test_data/histo_heatmaps/confusion_matrix.png create mode 100644 Tests/ML/test_data/histo_heatmaps/confusion_matrix_1.png create mode 100644 Tests/ML/test_data/histo_heatmaps/confusion_matrix_3.png diff --git a/InnerEye/ML/Histopathology/models/deepmil.py b/InnerEye/ML/Histopathology/models/deepmil.py index 676580ca7..a2f7998e1 100644 --- a/InnerEye/ML/Histopathology/models/deepmil.py +++ b/InnerEye/ML/Histopathology/models/deepmil.py @@ -179,7 +179,7 @@ def log_metrics(self, self.log(f'{stage}/{metric_name}', metric_object, on_epoch=True, on_step=False, logger=True, sync_dist=True) else: metric_value = metric_object.compute() - metric_value_n = metric_value/metric_value.sum(axis=1) + metric_value_n = metric_value/metric_value.sum(axis=1, keepdims=True) for i in range(metric_value_n.shape[0]): self.log(f'{stage}/{self.class_names[i]}', metric_value_n[i, i], on_epoch=True, on_step=False, logger=True, sync_dist=True) @@ -365,7 +365,7 @@ def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: # type: ignore print('test/confusion matrix:') print(cf_matrix) # Save the normalized confusion matrix as a figure in outputs - cf_matrix_n = cf_matrix/cf_matrix.sum(axis=1) + cf_matrix_n = cf_matrix/cf_matrix.sum(axis=1, keepdims=True) fig = plot_normalized_confusion_matrix(cm=cf_matrix_n, class_names=self.class_names) self.save_figure(fig=fig, figpath=outputs_fig_path / 'normalized_confusion_matrix.png') diff --git a/InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py b/InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py index 01384469b..d83221c80 100644 --- a/InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py +++ b/InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py @@ -25,6 +25,7 @@ from health_azure.utils import get_workspace from health_azure.utils import CheckpointDownloader from InnerEye.Common import fixed_paths +from InnerEye.ML.common import get_best_checkpoint_path from InnerEye.ML.configs.histo_configs.classification.BaseMIL import BaseMIL from InnerEye.ML.Histopathology.datamodules.base_module import TilesDataModule from InnerEye.ML.Histopathology.datamodules.tcga_crck_module import ( @@ -57,6 +58,7 @@ def __init__(self, **kwargs: Any) -> None: # To mount the dataset instead of downloading in AML, pass --use_dataset_mount in the CLI # declared in TrainerParams: num_epochs=16, + batch_size=8, # declared in WorkflowParams: number_of_cross_validation_splits=5, cross_validation_split_index=0, @@ -120,7 +122,6 @@ def 
get_data_module(self) -> TilesDataModule: batch_size=self.batch_size, transform=transform, cache_mode=self.cache_mode, - save_precache=self.save_precache, cache_dir=self.cache_dir, number_of_cross_validation_splits=self.number_of_cross_validation_splits, cross_validation_split_index=self.cross_validation_split_index, @@ -135,12 +136,23 @@ def get_path_to_best_checkpoint(self) -> Path: was applied there. """ # absolute path is required for registering the model. - return ( - fixed_paths.repository_root_directory() - / self.checkpoint_folder_path - / self.best_checkpoint_filename_with_suffix - ) - + absolute_checkpoint_path = Path(fixed_paths.repository_root_directory(), + self.checkpoint_folder_path, + self.best_checkpoint_filename_with_suffix) + if absolute_checkpoint_path.is_file(): + return absolute_checkpoint_path + + absolute_checkpoint_path_parent = Path(fixed_paths.repository_parent_directory(), + self.checkpoint_folder_path, + self.best_checkpoint_filename_with_suffix) + if absolute_checkpoint_path_parent.is_file(): + return absolute_checkpoint_path_parent + + checkpoint_path = get_best_checkpoint_path(Path(self.checkpoint_folder_path)) + if checkpoint_path.is_file(): + return checkpoint_path + + raise ValueError("Path to best checkpoint not found") class TcgaCrckImageNetMIL(DeepSMILECrck): def __init__(self, **kwargs: Any) -> None: diff --git a/Tests/ML/histopathology/utils/test_metrics_utils.py b/Tests/ML/histopathology/utils/test_metrics_utils.py index 1e57e9cc3..e63437b4d 100644 --- a/Tests/ML/histopathology/utils/test_metrics_utils.py +++ b/Tests/ML/histopathology/utils/test_metrics_utils.py @@ -127,19 +127,26 @@ def test_plot_heatmap_overlay(test_output_dirs: OutputFolderForTests) -> None: assert_binary_files_match(file, expected) -def test_plot_normalized_confusion_matrix(test_output_dirs: OutputFolderForTests) -> None: +@pytest.mark.parametrize("n_classes", [1, 3]) +def test_plot_normalized_confusion_matrix(test_output_dirs: OutputFolderForTests, n_classes: int) -> None: set_random_seed(0) - n_classes = 3 - cm = np.random.rand(n_classes, n_classes) - class_names = [str(i) for i in range(n_classes)] - fig = plot_normalized_confusion_matrix(cm=cm, class_names=class_names) + if n_classes > 1: + cm = np.random.randint(1, 1000, size=(n_classes, n_classes)) + class_names = [str(i) for i in range(n_classes)] + else: + cm = np.random.randint(1, 1000, size=(n_classes+1, n_classes+1)) + class_names = [str(i) for i in range(n_classes+1)] + cm_n = cm/cm.sum(axis=1, keepdims=True) + assert (cm_n <= 1).all() + + fig = plot_normalized_confusion_matrix(cm=cm_n, class_names=class_names) assert isinstance(fig, matplotlib.figure.Figure) - file = Path(test_output_dirs.root_dir) / "plot_confusion_matrix.png" + file = Path(test_output_dirs.root_dir) / f"plot_confusion_matrix_{n_classes}.png" resize_and_save(5, 5, file) assert file.exists() - expected = full_ml_test_data_path("histo_heatmaps") / "confusion_matrix.png" + expected = full_ml_test_data_path("histo_heatmaps") / f"confusion_matrix_{n_classes}.png" # To update the stored results, uncomment this line: - # expected.write_bytes(file.read_bytes()) + expected.write_bytes(file.read_bytes()) assert_binary_files_match(file, expected) diff --git a/Tests/ML/test_data/histo_heatmaps/confusion_matrix.png b/Tests/ML/test_data/histo_heatmaps/confusion_matrix.png deleted file mode 100644 index efec3f9c1..000000000 --- a/Tests/ML/test_data/histo_heatmaps/confusion_matrix.png +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid 
sha256:b8da4f5e8b876fd8d709850dd007e3f92beb923d8dd537c463c67768c7cbff59 -size 35031 diff --git a/Tests/ML/test_data/histo_heatmaps/confusion_matrix_1.png b/Tests/ML/test_data/histo_heatmaps/confusion_matrix_1.png new file mode 100644 index 000000000..5539764a9 --- /dev/null +++ b/Tests/ML/test_data/histo_heatmaps/confusion_matrix_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2c3da44ec1ac48549495229aca1d94e9e78eda6c2dc40ade44d68d90efa1dc6e +size 22623 diff --git a/Tests/ML/test_data/histo_heatmaps/confusion_matrix_3.png b/Tests/ML/test_data/histo_heatmaps/confusion_matrix_3.png new file mode 100644 index 000000000..4cb60e9ce --- /dev/null +++ b/Tests/ML/test_data/histo_heatmaps/confusion_matrix_3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2391edcf8e1fb4a2fdb17408d4526f43c7d9e836d366d2c83c08315ccd70ad8f +size 35291 From dd82dd3fb2caec1cc5c80d2d7804fe0254622e9f Mon Sep 17 00:00:00 2001 From: Harshita Sharma Date: Fri, 28 Jan 2022 11:06:25 +0000 Subject: [PATCH 6/9] changelog --- CHANGELOG.md | 1 + .../ML/configs/histo_configs/classification/DeepSMILECrck.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 68e1707a7..f11a127cb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -45,6 +45,7 @@ jobs that run in AzureML. - ([#621](https://github.com/microsoft/InnerEye-DeepLearning/pull/621)) Add WSI preprocessing functions and enable tiling more generic slide datasets - ([#634](https://github.com/microsoft/InnerEye-DeepLearning/pull/634)) Add WSI heatmaps and thumbnails to standard test outputs - ([#635](https://github.com/microsoft/InnerEye-DeepLearning/pull/635)) Add tile selection and binary label for online evaluation of PANDA SSL +- ([#647](https://github.com/microsoft/InnerEye-DeepLearning/pull/647)) Add class-wise accuracy logging and confusion matrix to DeepMIL ### Changed - ([#588](https://github.com/microsoft/InnerEye-DeepLearning/pull/588)) Replace SciPy with PIL.PngImagePlugin.PngImageFile to load png files. 
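The keepdims=True added to both normalisation sites in PATCH 5/9 above is the substance of that fix: summing an (n, n) matrix over axis 1 without keepdims yields a flat shape-(n,) vector, so the division broadcasts it along the last axis and entry (i, j) is divided by the total of row j instead of row i. A small NumPy illustration with arbitrary counts (unequal row totals, so the two variants actually differ):

    import numpy as np

    cm = np.array([[20, 2, 1],
                   [3, 10, 2],
                   [1, 2, 5]])

    wrong = cm / cm.sum(axis=1)                 # entry (i, j) divided by the sum of row j
    right = cm / cm.sum(axis=1, keepdims=True)  # entry (i, j) divided by the sum of row i

    print(wrong.sum(axis=1))  # rows no longer sum to 1
    print(right.sum(axis=1))  # [1. 1. 1.] -- each row is a distribution over predictions

With the corrected normalisation, the diagonal of the matrix is the per-class recall, which is what the per-class logging and the percentage annotations in plot_normalized_confusion_matrix rely on.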
diff --git a/InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py b/InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py index d83221c80..a9ca08268 100644 --- a/InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py +++ b/InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py @@ -58,7 +58,6 @@ def __init__(self, **kwargs: Any) -> None: # To mount the dataset instead of downloading in AML, pass --use_dataset_mount in the CLI # declared in TrainerParams: num_epochs=16, - batch_size=8, # declared in WorkflowParams: number_of_cross_validation_splits=5, cross_validation_split_index=0, From a60e62bbe9e8ec77ad8ee09a5d1dbf27ab8db8dc Mon Sep 17 00:00:00 2001 From: Harshita Sharma Date: Tue, 1 Feb 2022 08:16:07 +0000 Subject: [PATCH 7/9] PR comments addressed --- InnerEye/ML/Histopathology/models/deepmil.py | 46 ++--- InnerEye/ML/Histopathology/utils/naming.py | 10 + .../classification/DeepSMILECrck.py | 173 ------------------ .../classification/DeepSMILEPanda.py | 2 +- .../utils/test_metrics_utils.py | 2 +- 5 files changed, 36 insertions(+), 197 deletions(-) delete mode 100644 InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py diff --git a/InnerEye/ML/Histopathology/models/deepmil.py b/InnerEye/ML/Histopathology/models/deepmil.py index a2f7998e1..6fcc1cbc9 100644 --- a/InnerEye/ML/Histopathology/models/deepmil.py +++ b/InnerEye/ML/Histopathology/models/deepmil.py @@ -22,11 +22,8 @@ from InnerEye.ML.Histopathology.utils.metrics_utils import (select_k_tiles, plot_attention_tiles, plot_scores_hist, plot_heatmap_overlay, plot_slide, plot_normalized_confusion_matrix) -from InnerEye.ML.Histopathology.utils.naming import ResultsKey - +from InnerEye.ML.Histopathology.utils.naming import SlideKey, ResultsKey, MetricsKey from InnerEye.ML.Histopathology.utils.viz_utils import load_image_dict -from InnerEye.ML.Histopathology.utils.naming import SlideKey - RESULTS_COLS = [ResultsKey.SLIDE_ID, ResultsKey.TILE_ID, ResultsKey.IMAGE_PATH, ResultsKey.PROB, ResultsKey.PRED_LABEL, ResultsKey.TRUE_LABEL, ResultsKey.BAG_ATTN] @@ -56,10 +53,10 @@ def __init__(self, slide_dataset: SlidesDataset = None, tile_size: int = 224, level: int = 1, - class_names: List[str] = None) -> None: + class_names: Optional[List[str]] = None) -> None: """ :param label_column: Label key for input batch dictionary. - :param n_classes: Number of output classes for MIL prediction. + :param n_classes: Number of output classes for MIL prediction. For binary classification, n_classes should be set to 1. :param encoder: The tile encoder to use for feature extraction. If no encoding is needed, you should use `IdentityEncoder`. :param pooling_layer: Type of pooling to use in multi-instance aggregation. Should be a @@ -70,10 +67,11 @@ def __init__(self, :param l_rate: Optimiser learning rate. :param weight_decay: Weight decay parameter for L2 regularisation. :param adam_betas: Beta parameters for Adam optimiser. - :param verbose: if True statements about memory usage are output at each step + :param verbose: if True statements about memory usage are output at each step. :param slide_dataset: Slide dataset object, if available. :param tile_size: The size of each tile (default=224). :param level: The downsampling level (e.g. 0, 1, 2) of the tiles if available (default=1). + :param class_names: The names of the classes if available (default=None). 
""" super().__init__() @@ -93,7 +91,11 @@ def __init__(self, if self.n_classes > 1: self.class_names = [str(i) for i in range(self.n_classes)] else: - self.class_names = ['0', '1'] + self.class_names = ['0', '1'] + if self.n_classes > 1 and len(self.class_names) != self.n_classes: + raise ValueError(f"Mismatch in number of class names ({self.class_names}) and number of classes ({self.n_classes})") + if self.n_classes == 1 and len(self.class_names) != 2: + raise ValueError(f"Mismatch in number of class names ({self.class_names}) and number of classes ({self.n_classes+1})") # Optimiser hyperparameters self.l_rate = l_rate @@ -157,17 +159,17 @@ def get_bag_label(labels: Tensor) -> Tensor: def get_metrics(self) -> nn.ModuleDict: if self.n_classes > 1: - return nn.ModuleDict({'accuracy': Accuracy(num_classes=self.n_classes, average='micro'), - 'macro_accuracy': Accuracy(num_classes=self.n_classes, average='macro'), - 'weighted_accuracy': Accuracy(num_classes=self.n_classes, average='weighted'), - 'confusion_matrix': ConfusionMatrix(num_classes=self.n_classes)}) + return nn.ModuleDict({MetricsKey.ACC: Accuracy(num_classes=self.n_classes, average='micro'), + MetricsKey.ACC_MACRO: Accuracy(num_classes=self.n_classes, average='macro'), + MetricsKey.ACC_WEIGHTED: Accuracy(num_classes=self.n_classes, average='weighted'), + MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=self.n_classes)}) else: - return nn.ModuleDict({'accuracy': Accuracy(), - 'auroc': AUROC(num_classes=self.n_classes), - 'precision': Precision(), - 'recall': Recall(), - 'f1score': F1(), - 'confusion_matrix': ConfusionMatrix(num_classes=self.n_classes+1)}) + return nn.ModuleDict({MetricsKey.ACC: Accuracy(), + MetricsKey.AUROC: AUROC(num_classes=self.n_classes), + MetricsKey.PRECISION: Precision(), + MetricsKey.RECALL: Recall(), + MetricsKey.F1: F1(), + MetricsKey.CONF_MATRIX: ConfusionMatrix(num_classes=self.n_classes+1)}) def log_metrics(self, stage: str) -> None: @@ -175,13 +177,13 @@ def log_metrics(self, if stage not in valid_stages: raise Exception(f"Invalid stage. 
Chose one of {valid_stages}") for metric_name, metric_object in self.get_metrics_dict(stage).items(): - if not metric_name == "confusion_matrix": - self.log(f'{stage}/{metric_name}', metric_object, on_epoch=True, on_step=False, logger=True, sync_dist=True) - else: + if metric_name == MetricsKey.CONF_MATRIX: metric_value = metric_object.compute() metric_value_n = metric_value/metric_value.sum(axis=1, keepdims=True) for i in range(metric_value_n.shape[0]): self.log(f'{stage}/{self.class_names[i]}', metric_value_n[i, i], on_epoch=True, on_step=False, logger=True, sync_dist=True) + else: + self.log(f'{stage}/{metric_name}', metric_object, on_epoch=True, on_step=False, logger=True, sync_dist=True) def forward(self, images: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore with no_grad(): @@ -359,7 +361,7 @@ def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: # type: ignore print("Computing and saving confusion matrix...") metrics_dict = self.get_metrics_dict('test') - cf_matrix = metrics_dict["confusion_matrix"].compute() + cf_matrix = metrics_dict[MetricsKey.CONF_MATRIX].compute() cf_matrix = np.array(cf_matrix.cpu()) # We can't log tensors in the normal way - just print it to console print('test/confusion matrix:') diff --git a/InnerEye/ML/Histopathology/utils/naming.py b/InnerEye/ML/Histopathology/utils/naming.py index 5b7273ad9..f499b7b51 100644 --- a/InnerEye/ML/Histopathology/utils/naming.py +++ b/InnerEye/ML/Histopathology/utils/naming.py @@ -54,3 +54,13 @@ class ResultsKey(str, Enum): TILE_X = "x" TILE_Y = "y" + +class MetricsKey(str, Enum): + ACC = 'accuracy' + ACC_MACRO = 'macro_accuracy' + ACC_WEIGHTED = 'weighted_accuracy' + CONF_MATRIX = 'confusion_matrix' + AUROC = 'auroc' + PRECISION = 'precision' + RECALL = 'recall' + F1 = 'f1score' diff --git a/InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py b/InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py deleted file mode 100644 index a9ca08268..000000000 --- a/InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py +++ /dev/null @@ -1,173 +0,0 @@ -# ------------------------------------------------------------------------------------------ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. -# ------------------------------------------------------------------------------------------ - -"""DeepSMILECrck is the container for experiments relating to DeepSMILE using the TCGA-CRCk dataset. -Run using `python InnerEyePrivate/ML/runner.py --model=DeepSMILECrck --encoder_type=` - -For convenience, this module also defines encoder-specific containers that can be invoked without -additional arguments, e.g. `python InnerEyePrivate/ML/runner.py --model=TcgaCrckImageNetMIL` - -Reference: -- Schirris (2021). DeepSMILE: Self-supervised heterogeneity-aware multiple instance learning for DNA -damage response defect classification directly from H&E whole-slide images. 
arXiv:2107.09405 -""" -from pathlib import Path -from typing import Any, List -import os - -from monai.transforms import Compose -from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint -from pytorch_lightning.callbacks import Callback - -from health_ml.networks.layers.attention_layers import GatedAttentionLayer -from health_azure.utils import get_workspace -from health_azure.utils import CheckpointDownloader -from InnerEye.Common import fixed_paths -from InnerEye.ML.common import get_best_checkpoint_path -from InnerEye.ML.configs.histo_configs.classification.BaseMIL import BaseMIL -from InnerEye.ML.Histopathology.datamodules.base_module import TilesDataModule -from InnerEye.ML.Histopathology.datamodules.tcga_crck_module import ( - TcgaCrckTilesDataModule, -) -from InnerEye.ML.Histopathology.models.encoders import ( - HistoSSLEncoder, - ImageNetEncoder, - ImageNetSimCLREncoder, - InnerEyeSSLEncoder, -) -from InnerEye.ML.Histopathology.models.transforms import ( - EncodeTilesBatchd, - LoadTilesBatchd, -) -from InnerEye.ML.Histopathology.datasets.tcga_crck_tiles_dataset import ( - TcgaCrck_TilesDataset, -) - - -class DeepSMILECrck(BaseMIL): - def __init__(self, **kwargs: Any) -> None: - # Define dictionary with default params that can be overriden from subclasses or CLI - default_kwargs = dict( - # declared in BaseMIL: - pooling_type=GatedAttentionLayer.__name__, - # declared in DatasetParams: - local_dataset=Path("/tmp/datasets/TCGA-CRCk"), - azure_dataset_id="TCGA-CRCk", - # To mount the dataset instead of downloading in AML, pass --use_dataset_mount in the CLI - # declared in TrainerParams: - num_epochs=16, - # declared in WorkflowParams: - number_of_cross_validation_splits=5, - cross_validation_split_index=0, - # declared in OptimizerParams: - l_rate=5e-4, - weight_decay=1e-4, - adam_betas=(0.9, 0.99), - ) - default_kwargs.update(kwargs) - super().__init__(**default_kwargs) - - self.best_checkpoint_filename = "checkpoint_max_val_auroc" - self.best_checkpoint_filename_with_suffix = ( - self.best_checkpoint_filename + ".ckpt" - ) - self.checkpoint_folder_path = "outputs/checkpoints/" - - best_checkpoint_callback = ModelCheckpoint( - dirpath=self.checkpoint_folder_path, - monitor="val/auroc", - filename=self.best_checkpoint_filename, - auto_insert_metric_name=False, - mode="max", - ) - self.callbacks = best_checkpoint_callback - - @property - def cache_dir(self) -> Path: - return Path( - f"/tmp/innereye_cache1/{self.__class__.__name__}-{self.encoder_type}/" - ) - - def setup(self) -> None: - if self.encoder_type == InnerEyeSSLEncoder.__name__: - from InnerEye.ML.configs.histo_configs.run_ids import innereye_ssl_checkpoint_crck_4ws - self.downloader = CheckpointDownloader( - azure_config_json_path=get_workspace(), - run_id=innereye_ssl_checkpoint_crck_4ws, - checkpoint_filename="best_checkpoint.ckpt", - download_dir="outputs/", - remote_checkpoint_dir=Path("outputs/checkpoints") - ) - os.chdir(fixed_paths.repository_parent_directory()) - self.downloader.download_checkpoint_if_necessary() - - self.encoder = self.get_encoder() - self.encoder.cuda() - self.encoder.eval() - - def get_data_module(self) -> TilesDataModule: - image_key = TcgaCrck_TilesDataset.IMAGE_COLUMN - transform = Compose( - [ - LoadTilesBatchd(image_key, progress=True), - EncodeTilesBatchd(image_key, self.encoder), - ] - ) - return TcgaCrckTilesDataModule( - root_path=self.local_dataset, - max_bag_size=self.max_bag_size, - batch_size=self.batch_size, - transform=transform, - cache_mode=self.cache_mode, - 
cache_dir=self.cache_dir, - number_of_cross_validation_splits=self.number_of_cross_validation_splits, - cross_validation_split_index=self.cross_validation_split_index, - ) - - def get_callbacks(self) -> List[Callback]: - return super().get_callbacks() + [self.callbacks] - - def get_path_to_best_checkpoint(self) -> Path: - """ - Returns the full path to a checkpoint file that was found to be best during training, whatever criterion - was applied there. - """ - # absolute path is required for registering the model. - absolute_checkpoint_path = Path(fixed_paths.repository_root_directory(), - self.checkpoint_folder_path, - self.best_checkpoint_filename_with_suffix) - if absolute_checkpoint_path.is_file(): - return absolute_checkpoint_path - - absolute_checkpoint_path_parent = Path(fixed_paths.repository_parent_directory(), - self.checkpoint_folder_path, - self.best_checkpoint_filename_with_suffix) - if absolute_checkpoint_path_parent.is_file(): - return absolute_checkpoint_path_parent - - checkpoint_path = get_best_checkpoint_path(Path(self.checkpoint_folder_path)) - if checkpoint_path.is_file(): - return checkpoint_path - - raise ValueError("Path to best checkpoint not found") - -class TcgaCrckImageNetMIL(DeepSMILECrck): - def __init__(self, **kwargs: Any) -> None: - super().__init__(encoder_type=ImageNetEncoder.__name__, **kwargs) - - -class TcgaCrckImageNetSimCLRMIL(DeepSMILECrck): - def __init__(self, **kwargs: Any) -> None: - super().__init__(encoder_type=ImageNetSimCLREncoder.__name__, **kwargs) - - -class TcgaCrckInnerEyeSSLMIL(DeepSMILECrck): - def __init__(self, **kwargs: Any) -> None: - super().__init__(encoder_type=InnerEyeSSLEncoder.__name__, **kwargs) - - -class TcgaCrckHistoSSLMIL(DeepSMILECrck): - def __init__(self, **kwargs: Any) -> None: - super().__init__(encoder_type=HistoSSLEncoder.__name__, **kwargs) diff --git a/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py b/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py index 68d952cf5..6b75b9003 100644 --- a/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py +++ b/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py @@ -49,7 +49,7 @@ def __init__(self, **kwargs: Any) -> None: extra_local_dataset_paths=[Path("/tmp/datasets/PANDA")], # To mount the dataset instead of downloading in AML, pass --use_dataset_mount in the CLI # declared in TrainerParams: - num_epochs=200, + num_epochs=2, # use_mixed_precision = True, # declared in WorkflowParams: diff --git a/Tests/ML/histopathology/utils/test_metrics_utils.py b/Tests/ML/histopathology/utils/test_metrics_utils.py index e63437b4d..c226519d8 100644 --- a/Tests/ML/histopathology/utils/test_metrics_utils.py +++ b/Tests/ML/histopathology/utils/test_metrics_utils.py @@ -146,7 +146,7 @@ def test_plot_normalized_confusion_matrix(test_output_dirs: OutputFolderForTests assert file.exists() expected = full_ml_test_data_path("histo_heatmaps") / f"confusion_matrix_{n_classes}.png" # To update the stored results, uncomment this line: - expected.write_bytes(file.read_bytes()) + # expected.write_bytes(file.read_bytes()) assert_binary_files_match(file, expected) From 88e2f39af58687ef7a981928f19b53021c70cfd9 Mon Sep 17 00:00:00 2001 From: Harshita Sharma Date: Tue, 1 Feb 2022 08:27:02 +0000 Subject: [PATCH 8/9] restore deepsmilecrck --- .../classification/DeepSMILECrck.py | 162 ++++++++++++++++++ 1 file changed, 162 insertions(+) create mode 100644 InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py diff --git 
a/InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py b/InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py new file mode 100644 index 000000000..01384469b --- /dev/null +++ b/InnerEye/ML/configs/histo_configs/classification/DeepSMILECrck.py @@ -0,0 +1,162 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + +"""DeepSMILECrck is the container for experiments relating to DeepSMILE using the TCGA-CRCk dataset. +Run using `python InnerEyePrivate/ML/runner.py --model=DeepSMILECrck --encoder_type=` + +For convenience, this module also defines encoder-specific containers that can be invoked without +additional arguments, e.g. `python InnerEyePrivate/ML/runner.py --model=TcgaCrckImageNetMIL` + +Reference: +- Schirris (2021). DeepSMILE: Self-supervised heterogeneity-aware multiple instance learning for DNA +damage response defect classification directly from H&E whole-slide images. arXiv:2107.09405 +""" +from pathlib import Path +from typing import Any, List +import os + +from monai.transforms import Compose +from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +from pytorch_lightning.callbacks import Callback + +from health_ml.networks.layers.attention_layers import GatedAttentionLayer +from health_azure.utils import get_workspace +from health_azure.utils import CheckpointDownloader +from InnerEye.Common import fixed_paths +from InnerEye.ML.configs.histo_configs.classification.BaseMIL import BaseMIL +from InnerEye.ML.Histopathology.datamodules.base_module import TilesDataModule +from InnerEye.ML.Histopathology.datamodules.tcga_crck_module import ( + TcgaCrckTilesDataModule, +) +from InnerEye.ML.Histopathology.models.encoders import ( + HistoSSLEncoder, + ImageNetEncoder, + ImageNetSimCLREncoder, + InnerEyeSSLEncoder, +) +from InnerEye.ML.Histopathology.models.transforms import ( + EncodeTilesBatchd, + LoadTilesBatchd, +) +from InnerEye.ML.Histopathology.datasets.tcga_crck_tiles_dataset import ( + TcgaCrck_TilesDataset, +) + + +class DeepSMILECrck(BaseMIL): + def __init__(self, **kwargs: Any) -> None: + # Define dictionary with default params that can be overriden from subclasses or CLI + default_kwargs = dict( + # declared in BaseMIL: + pooling_type=GatedAttentionLayer.__name__, + # declared in DatasetParams: + local_dataset=Path("/tmp/datasets/TCGA-CRCk"), + azure_dataset_id="TCGA-CRCk", + # To mount the dataset instead of downloading in AML, pass --use_dataset_mount in the CLI + # declared in TrainerParams: + num_epochs=16, + # declared in WorkflowParams: + number_of_cross_validation_splits=5, + cross_validation_split_index=0, + # declared in OptimizerParams: + l_rate=5e-4, + weight_decay=1e-4, + adam_betas=(0.9, 0.99), + ) + default_kwargs.update(kwargs) + super().__init__(**default_kwargs) + + self.best_checkpoint_filename = "checkpoint_max_val_auroc" + self.best_checkpoint_filename_with_suffix = ( + self.best_checkpoint_filename + ".ckpt" + ) + self.checkpoint_folder_path = "outputs/checkpoints/" + + best_checkpoint_callback = ModelCheckpoint( + dirpath=self.checkpoint_folder_path, + monitor="val/auroc", + filename=self.best_checkpoint_filename, + auto_insert_metric_name=False, + mode="max", + ) + self.callbacks = best_checkpoint_callback + + @property + 
def cache_dir(self) -> Path: + return Path( + f"/tmp/innereye_cache1/{self.__class__.__name__}-{self.encoder_type}/" + ) + + def setup(self) -> None: + if self.encoder_type == InnerEyeSSLEncoder.__name__: + from InnerEye.ML.configs.histo_configs.run_ids import innereye_ssl_checkpoint_crck_4ws + self.downloader = CheckpointDownloader( + azure_config_json_path=get_workspace(), + run_id=innereye_ssl_checkpoint_crck_4ws, + checkpoint_filename="best_checkpoint.ckpt", + download_dir="outputs/", + remote_checkpoint_dir=Path("outputs/checkpoints") + ) + os.chdir(fixed_paths.repository_parent_directory()) + self.downloader.download_checkpoint_if_necessary() + + self.encoder = self.get_encoder() + self.encoder.cuda() + self.encoder.eval() + + def get_data_module(self) -> TilesDataModule: + image_key = TcgaCrck_TilesDataset.IMAGE_COLUMN + transform = Compose( + [ + LoadTilesBatchd(image_key, progress=True), + EncodeTilesBatchd(image_key, self.encoder), + ] + ) + return TcgaCrckTilesDataModule( + root_path=self.local_dataset, + max_bag_size=self.max_bag_size, + batch_size=self.batch_size, + transform=transform, + cache_mode=self.cache_mode, + save_precache=self.save_precache, + cache_dir=self.cache_dir, + number_of_cross_validation_splits=self.number_of_cross_validation_splits, + cross_validation_split_index=self.cross_validation_split_index, + ) + + def get_callbacks(self) -> List[Callback]: + return super().get_callbacks() + [self.callbacks] + + def get_path_to_best_checkpoint(self) -> Path: + """ + Returns the full path to a checkpoint file that was found to be best during training, whatever criterion + was applied there. + """ + # absolute path is required for registering the model. + return ( + fixed_paths.repository_root_directory() + / self.checkpoint_folder_path + / self.best_checkpoint_filename_with_suffix + ) + + +class TcgaCrckImageNetMIL(DeepSMILECrck): + def __init__(self, **kwargs: Any) -> None: + super().__init__(encoder_type=ImageNetEncoder.__name__, **kwargs) + + +class TcgaCrckImageNetSimCLRMIL(DeepSMILECrck): + def __init__(self, **kwargs: Any) -> None: + super().__init__(encoder_type=ImageNetSimCLREncoder.__name__, **kwargs) + + +class TcgaCrckInnerEyeSSLMIL(DeepSMILECrck): + def __init__(self, **kwargs: Any) -> None: + super().__init__(encoder_type=InnerEyeSSLEncoder.__name__, **kwargs) + + +class TcgaCrckHistoSSLMIL(DeepSMILECrck): + def __init__(self, **kwargs: Any) -> None: + super().__init__(encoder_type=HistoSSLEncoder.__name__, **kwargs) From d34d50c2531185819fd830294824909f3d2ba014 Mon Sep 17 00:00:00 2001 From: Harshita Sharma Date: Tue, 1 Feb 2022 10:27:34 +0000 Subject: [PATCH 9/9] replace self.log with log_on_epoch --- InnerEye/ML/Histopathology/models/deepmil.py | 5 +++-- .../configs/histo_configs/classification/DeepSMILEPanda.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/InnerEye/ML/Histopathology/models/deepmil.py b/InnerEye/ML/Histopathology/models/deepmil.py index 6fcc1cbc9..8e001e080 100644 --- a/InnerEye/ML/Histopathology/models/deepmil.py +++ b/InnerEye/ML/Histopathology/models/deepmil.py @@ -24,6 +24,7 @@ plot_slide, plot_normalized_confusion_matrix) from InnerEye.ML.Histopathology.utils.naming import SlideKey, ResultsKey, MetricsKey from InnerEye.ML.Histopathology.utils.viz_utils import load_image_dict +from health_ml.utils import log_on_epoch RESULTS_COLS = [ResultsKey.SLIDE_ID, ResultsKey.TILE_ID, ResultsKey.IMAGE_PATH, ResultsKey.PROB, ResultsKey.PRED_LABEL, ResultsKey.TRUE_LABEL, ResultsKey.BAG_ATTN] @@ -181,9 +182,9 @@ def 
log_metrics(self, metric_value = metric_object.compute() metric_value_n = metric_value/metric_value.sum(axis=1, keepdims=True) for i in range(metric_value_n.shape[0]): - self.log(f'{stage}/{self.class_names[i]}', metric_value_n[i, i], on_epoch=True, on_step=False, logger=True, sync_dist=True) + log_on_epoch(self, f'{stage}/{self.class_names[i]}', metric_value_n[i, i]) else: - self.log(f'{stage}/{metric_name}', metric_object, on_epoch=True, on_step=False, logger=True, sync_dist=True) + log_on_epoch(self, f'{stage}/{metric_name}', metric_object) def forward(self, images: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore with no_grad(): diff --git a/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py b/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py index 6b75b9003..68d952cf5 100644 --- a/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py +++ b/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py @@ -49,7 +49,7 @@ def __init__(self, **kwargs: Any) -> None: extra_local_dataset_paths=[Path("/tmp/datasets/PANDA")], # To mount the dataset instead of downloading in AML, pass --use_dataset_mount in the CLI # declared in TrainerParams: - num_epochs=2, + num_epochs=200, # use_mixed_precision = True, # declared in WorkflowParams:
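Taken together, the per-class values that log_metrics now emits via log_on_epoch are the diagonal entries of the row-normalised confusion matrix, i.e. the recall of each class under its configured name. A condensed sketch of that flow outside Lightning, mirroring the binary (n_classes == 1) conventions from the diffs above — the default class names ['0', '1'] and the ConfusionMatrix(num_classes=n_classes + 1) sizing — with made-up predictions; log_on_epoch itself lives in health_ml.utils and is only imitated here with a print:

    import torch
    from torchmetrics import ConfusionMatrix

    n_classes = 1                      # binary DeepMIL head
    class_names = ['0', '1']           # default assigned by DeepMILModule in the binary case
    conf_matrix = ConfusionMatrix(num_classes=n_classes + 1)   # 2x2 matrix for binary

    preds = torch.tensor([1, 0, 1, 1, 0, 0, 1, 0])   # illustrative bag-level predictions
    labels = torch.tensor([1, 0, 0, 1, 0, 1, 1, 0])  # illustrative true labels
    conf_matrix.update(preds, labels)

    cm = conf_matrix.compute()
    cm_normalised = cm / cm.sum(dim=1, keepdim=True)

    # Diagonal entries are per-class recall; one scalar is logged per class name.
    for i, name in enumerate(class_names):
        print(f"test/{name}: {cm_normalised[i, i].item():.3f}")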