diff --git a/CHANGELOG.md b/CHANGELOG.md index 6c0d3520f..08d862099 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -51,6 +51,7 @@ jobs that run in AzureML. - ([#653](https://github.com/microsoft/InnerEye-DeepLearning/pull/653)) Add dropout to DeepMIL and fix feature extractor setup. - ([#650](https://github.com/microsoft/InnerEye-DeepLearning/pull/650)) Enable fine-tuning in DeepMIL using PANDA as the classification task. - ([#656](https://github.com/microsoft/InnerEye-DeepLearning/pull/656)) Add subsampling transform and support for MIL mean pooling. +- ([#679](https://github.com/microsoft/InnerEye-DeepLearning/pull/679)) Add FP and TN slides/tiles to DeepMIL outputs and extend outputs to multi-class problems. ### Changed - ([#659](https://github.com/microsoft/InnerEye-DeepLearning/pull/659)) Update cudatoolkit version from 11.1 to 11.3. diff --git a/InnerEye/ML/Histopathology/models/deepmil.py b/InnerEye/ML/Histopathology/models/deepmil.py index 13647dbe2..692033bc9 100644 --- a/InnerEye/ML/Histopathology/models/deepmil.py +++ b/InnerEye/ML/Histopathology/models/deepmil.py @@ -27,7 +27,7 @@ from health_ml.utils import log_on_epoch RESULTS_COLS = [ResultsKey.SLIDE_ID, ResultsKey.TILE_ID, ResultsKey.IMAGE_PATH, ResultsKey.PROB, - ResultsKey.PRED_LABEL, ResultsKey.TRUE_LABEL, ResultsKey.BAG_ATTN] + ResultsKey.CLASS_PROBS, ResultsKey.PRED_LABEL, ResultsKey.TRUE_LABEL, ResultsKey.BAG_ATTN] def _format_cuda_memory_stats() -> str: @@ -242,21 +242,28 @@ def _shared_step(self, batch: Dict, batch_idx: int, stage: str) -> Dict[ResultsK predicted_probs = self.activation_fn(bag_logits) if self.n_classes > 1: predicted_labels = argmax(predicted_probs, dim=1) + probs_perclass = predicted_probs else: predicted_labels = round(predicted_probs) + probs_perclass = Tensor([[1.0 - predicted_probs[i][0].item(), predicted_probs[i][0].item()] for i in range(len(predicted_probs))]) loss = loss.view(-1, 1) predicted_labels = predicted_labels.view(-1, 1) - predicted_probs = predicted_probs.view(-1, 1) + if self.n_classes == 1: + predicted_probs = predicted_probs.view(-1, 1) bag_labels = bag_labels.view(-1, 1) results = dict() for metric_object in self.get_metrics_dict(stage).values(): - metric_object.update(predicted_probs, bag_labels) + if self.n_classes > 1: + metric_object.update(predicted_probs, bag_labels.squeeze()) + else: + metric_object.update(predicted_probs, bag_labels) results.update({ResultsKey.SLIDE_ID: batch[TilesDataset.SLIDE_ID_COLUMN], ResultsKey.TILE_ID: batch[TilesDataset.TILE_ID_COLUMN], ResultsKey.IMAGE_PATH: batch[TilesDataset.PATH_COLUMN], ResultsKey.LOSS: loss, - ResultsKey.PROB: predicted_probs, ResultsKey.PRED_LABEL: predicted_labels, + ResultsKey.PROB: predicted_probs, ResultsKey.CLASS_PROBS: probs_perclass, + ResultsKey.PRED_LABEL: predicted_labels, ResultsKey.TRUE_LABEL: bag_labels, ResultsKey.BAG_ATTN: bag_attn_list, ResultsKey.IMAGE: batch[TilesDataset.IMAGE_COLUMN]}) @@ -339,11 +346,21 @@ def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: # type: ignore torch.save(features_list, encoded_features_filename) print("Selecting tiles ...") - fn_top_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('lowest_pred', 'highest_att')) - fn_bottom_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('lowest_pred', 'lowest_att')) - tp_top_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('highest_pred', 'highest_att')) - tp_bottom_tiles = select_k_tiles(results, n_slides=10, label=1, n_tiles=10, select=('highest_pred', 'lowest_att')) - report_cases = {'TP': [tp_top_tiles, tp_bottom_tiles], 'FN': [fn_top_tiles, fn_bottom_tiles]} + # Class 0 + tn_top_tiles = select_k_tiles(results, n_slides=10, label=0, n_tiles=10, select=('highest_pred', 'highest_att')) + tn_bottom_tiles = select_k_tiles(results, n_slides=10, label=0, n_tiles=10, select=('highest_pred', 'lowest_att')) + fp_top_tiles = select_k_tiles(results, n_slides=10, label=0, n_tiles=10, select=('lowest_pred', 'highest_att')) + fp_bottom_tiles = select_k_tiles(results, n_slides=10, label=0, n_tiles=10, select=('lowest_pred', 'lowest_att')) + report_cases = {'TN': [tn_top_tiles, tn_bottom_tiles], 'FP': [fp_top_tiles, fp_bottom_tiles]} + + # Class 1 to n_classes-1 + n_classes_to_select = self.n_classes if self.n_classes > 1 else 2 + for i in range(1, n_classes_to_select): + fn_top_tiles = select_k_tiles(results, n_slides=10, label=i, n_tiles=10, select=('lowest_pred', 'highest_att')) + fn_bottom_tiles = select_k_tiles(results, n_slides=10, label=i, n_tiles=10, select=('lowest_pred', 'lowest_att')) + tp_top_tiles = select_k_tiles(results, n_slides=10, label=i, n_tiles=10, select=('highest_pred', 'highest_att')) + tp_bottom_tiles = select_k_tiles(results, n_slides=10, label=i, n_tiles=10, select=('highest_pred', 'lowest_att')) + report_cases.update({'TP_'+str(i): [tp_top_tiles, tp_bottom_tiles], 'FN_'+str(i): [fn_top_tiles, fn_bottom_tiles]}) for key in report_cases.keys(): print(f"Plotting {key} (tiles, thumbnails, attention heatmaps)...") @@ -397,13 +414,19 @@ def normalize_dict_for_df(dict_old: Dict[str, Any], use_gpu: bool) -> Dict: # these steps are required to convert the dictionary to pandas dataframe. device = 'cuda' if use_gpu else 'cpu' dict_new = dict() + bag_size = len(dict_old[ResultsKey.SLIDE_ID]) for key, value in dict_old.items(): - if isinstance(value, Tensor): - value = value.squeeze(0).to(device).numpy() - if value.ndim == 0: - bag_size = len(dict_old[ResultsKey.SLIDE_ID]) - value = np.full(bag_size, fill_value=value) - dict_new[key] = value + if key not in [ResultsKey.CLASS_PROBS, ResultsKey.PROB]: + if isinstance(value, Tensor): + value = value.squeeze(0).to(device).numpy() + if value.ndim == 0: + value = np.full(bag_size, fill_value=value) + dict_new[key] = value + elif key == ResultsKey.CLASS_PROBS: + if isinstance(value, Tensor): + value = value.squeeze(0).to(device).numpy() + for i in range(len(value)): + dict_new[key+str(i)] = np.repeat(value[i], bag_size) return dict_new @staticmethod diff --git a/InnerEye/ML/Histopathology/utils/metrics_utils.py b/InnerEye/ML/Histopathology/utils/metrics_utils.py index c0d47a664..a966e4cf5 100644 --- a/InnerEye/ML/Histopathology/utils/metrics_utils.py +++ b/InnerEye/ML/Histopathology/utils/metrics_utils.py @@ -20,7 +20,7 @@ def select_k_tiles(results: Dict, n_tiles: int = 5, n_slides: int = 5, label: int = 1, select: Tuple = ('lowest_pred', 'highest_att'), slide_col: str = ResultsKey.SLIDE_ID, gt_col: str = ResultsKey.TRUE_LABEL, - attn_col: str = ResultsKey.BAG_ATTN, prob_col: str = ResultsKey.PROB, + attn_col: str = ResultsKey.BAG_ATTN, prob_col: str = ResultsKey.CLASS_PROBS, return_col: str = ResultsKey.IMAGE_PATH) -> List[Tuple[Any, Any, List[Any], List[Any]]]: """ :param results: List that contains slide_level dicts @@ -35,7 +35,7 @@ def select_k_tiles(results: Dict, n_tiles: int = 5, n_slides: int = 5, label: in :param return_col: column name of the values we want to return for each tile :return: tuple containing the slides id, the slide score, the tile ids, the tiles scores """ - tmp_s = [(results[prob_col][i], i) for i, gt in enumerate(results[gt_col]) if gt == label] # type ignore + tmp_s = [(results[prob_col][i][label], i) for i, gt in enumerate(results[gt_col]) if gt == label] # type ignore if select[0] == 'lowest_pred': tmp_s.sort(reverse=False) elif select[0] == 'highest_pred': @@ -58,12 +58,12 @@ def select_k_tiles(results: Dict, n_tiles: int = 5, n_slides: int = 5, label: in scores.append(results[attn_col][slide_idx][0][t_idx]) # slide_ids are duplicated k_idx.append((results[slide_col][slide_idx][0], - results[prob_col][slide_idx].item(), + results[prob_col][slide_idx], k_tiles, scores)) return k_idx -def plot_scores_hist(results: Dict, prob_col: str = ResultsKey.PROB, +def plot_scores_hist(results: Dict, prob_col: str = ResultsKey.CLASS_PROBS, gt_col: str = ResultsKey.TRUE_LABEL) -> plt.figure: """ :param results: List that contains slide_level dicts @@ -71,20 +71,23 @@ def plot_scores_hist(results: Dict, prob_col: str = ResultsKey.PROB, :param gt_col: column name that contains the true label :return: matplotlib figure of the scores histogram by class """ - pos_scores = [results[prob_col][i][0].cpu().item() for i, gt in enumerate(results[gt_col]) if gt == 1] - neg_scores = [results[prob_col][i][0].cpu().item() for i, gt in enumerate(results[gt_col]) if gt == 0] - fig, ax = plt.subplots() - ax.hist([pos_scores, neg_scores], label=['1', '0'], alpha=0.5) + n_classes = len(results[prob_col][0]) + scores_class = [] + for j in range(n_classes): + scores = [results[prob_col][i][j].cpu().item() for i, gt in enumerate(results[gt_col]) if gt == j] + scores_class.append(scores) + fig, ax = plt.subplots() + ax.hist(scores_class, label=[str(i) for i in range(n_classes)], alpha=0.5) ax.set_xlabel("Predicted Score") ax.legend() return fig -def plot_attention_tiles(slide: str, score: float, paths: List, attn: List, case: str, ncols: int = 5, +def plot_attention_tiles(slide: str, scores: List[float], paths: List, attn: List, case: str, ncols: int = 5, size: Tuple = (10, 10)) -> plt.figure: """ :param slide: slide identifier - :param score: predicted score for the slide + :param scores: predicted scores of each class for the slide :param paths: list of paths to tiles belonging to the slide :param attn: list of scores belonging to the tiles in paths. paths and attn are expected to have the same shape :param case: string used to define the title of the plot e.g. TP @@ -94,7 +97,7 @@ def plot_attention_tiles(slide: str, score: float, paths: List, attn: List, case """ nrows = int(ceil(len(paths) / ncols)) fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=size) - fig.suptitle(f"{case}: {slide} P=%.2f" % score) + fig.suptitle(f"{case}: {slide} P=%.2f" % max(scores)) for i in range(len(paths)): img = load_pil_image(paths[i]) axs.ravel()[i].imshow(img, clim=(0, 255), cmap='gray') diff --git a/InnerEye/ML/Histopathology/utils/naming.py b/InnerEye/ML/Histopathology/utils/naming.py index f499b7b51..04dc01dc9 100644 --- a/InnerEye/ML/Histopathology/utils/naming.py +++ b/InnerEye/ML/Histopathology/utils/naming.py @@ -48,6 +48,7 @@ class ResultsKey(str, Enum): IMAGE_PATH = 'image_path' LOSS = 'loss' PROB = 'prob' + CLASS_PROBS = 'prob_class' PRED_LABEL = 'pred_label' TRUE_LABEL = 'true_label' BAG_ATTN = 'bag_attn' diff --git a/Tests/ML/histopathology/utils/test_metrics_utils.py b/Tests/ML/histopathology/utils/test_metrics_utils.py index b1cc4c1ef..9db87a914 100644 --- a/Tests/ML/histopathology/utils/test_metrics_utils.py +++ b/Tests/ML/histopathology/utils/test_metrics_utils.py @@ -31,6 +31,9 @@ def assert_equal_lists(pred: List, expected: List) -> None: for j, value in enumerate(slide): if type(value) in [int, float]: assert math.isclose(value, expected[i][j], rel_tol=1e-06) + elif (type(value) == Tensor) and (value.ndim >= 1): + for k, idx in enumerate(value): + assert math.isclose(idx, expected[i][j][k], rel_tol=1e-06) elif isinstance(value, List): for k, idx in enumerate(value): if type(idx) in [int, float]: @@ -41,15 +44,20 @@ def assert_equal_lists(pred: List, expected: List) -> None: raise TypeError("Unexpected list composition") -test_dict = {ResultsKey.SLIDE_ID: [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]], - ResultsKey.IMAGE_PATH: [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]], - ResultsKey.PROB: [Tensor([0.5]), Tensor([0.7]), Tensor([0.4]), Tensor([1.0])], - ResultsKey.TRUE_LABEL: [0, 1, 1, 1], +test_dict = {ResultsKey.SLIDE_ID: [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], + ResultsKey.IMAGE_PATH: [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]], + ResultsKey.CLASS_PROBS: [Tensor([0.6, 0.4]), Tensor([0.3, 0.7]), Tensor([0.6, 0.4]), Tensor([0.0, 1.0]), + Tensor([0.7, 0.3]), Tensor([0.8, 0.2]), Tensor([0.1, 0.9]), Tensor([0.01, 0.99])], + ResultsKey.TRUE_LABEL: [0, 1, 1, 1, 1, 0, 0, 0], ResultsKey.BAG_ATTN: - [Tensor([[0.1, 0.0, 0.2, 0.15]]), + [Tensor([[0.10, 0.00, 0.20, 0.15]]), Tensor([[0.10, 0.18, 0.15, 0.13]]), Tensor([[0.25, 0.23, 0.20, 0.21]]), - Tensor([[0.33, 0.31, 0.37, 0.35]])], + Tensor([[0.33, 0.31, 0.37, 0.35]]), + Tensor([[0.43, 0.01, 0.07, 0.25]]), + Tensor([[0.53, 0.11, 0.17, 0.55]]), + Tensor([[0.63, 0.21, 0.27, 0.05]]), + Tensor([[0.73, 0.31, 0.37, 0.15]])], ResultsKey.TILE_X: [Tensor([200, 200, 424, 424]), Tensor([200, 200, 424, 424]), @@ -64,27 +72,40 @@ def assert_equal_lists(pred: List, expected: List) -> None: def test_select_k_tiles() -> None: - top_tn = select_k_tiles(test_dict, n_slides=1, label=0, n_tiles=2, select=('lowest_pred', 'highest_att')) - assert_equal_lists(top_tn, [(1, 0.5, [3, 4], [Tensor([0.2]), Tensor([0.15])])]) - nslides = 2 ntiles = 2 + # TP + top_tp = select_k_tiles(test_dict, n_slides=nslides, label=1, n_tiles=ntiles, select=('highest_pred', 'highest_att')) + bottom_tp = select_k_tiles(test_dict, n_slides=nslides, label=1, n_tiles=ntiles, select=('highest_pred', 'lowest_att')) + print(top_tp) + assert_equal_lists(top_tp, [(4, Tensor([0.0, 1.0]), [3, 4], [Tensor([0.37]), Tensor([0.35])]), + (2, Tensor([0.3, 0.7]), [2, 3], [Tensor([0.18]), Tensor([0.15])])]) + assert_equal_lists(bottom_tp, [(4, Tensor([0.0, 1.0]), [2, 1], [Tensor([0.31]), Tensor([0.33])]), + (2, Tensor([0.3, 0.7]), [1, 4], [Tensor([0.10]), Tensor([0.13])])]) + + # FN top_fn = select_k_tiles(test_dict, n_slides=nslides, label=1, n_tiles=ntiles, select=('lowest_pred', 'highest_att')) - bottom_fn = select_k_tiles(test_dict, n_slides=nslides, label=1, n_tiles=ntiles, - select=('lowest_pred', 'lowest_att')) - assert_equal_lists(top_fn, [(3, 0.4, [1, 2], [Tensor([0.25]), Tensor([0.23])]), - (2, 0.7, [2, 3], [Tensor([0.18]), Tensor([0.15])])]) - assert_equal_lists(bottom_fn, [(3, 0.4, [3, 4], [Tensor([0.20]), Tensor([0.21])]), - (2, 0.7, [1, 4], [Tensor([0.10]), Tensor([0.13])])]) - - top_tp = select_k_tiles(test_dict, n_slides=nslides, label=1, n_tiles=ntiles, - select=('highest_pred', 'highest_att')) - bottom_tp = select_k_tiles(test_dict, n_slides=nslides, label=1, n_tiles=ntiles, - select=('highest_pred', 'lowest_att')) - assert_equal_lists(top_tp, [(4, 1.0, [3, 4], [Tensor([0.37]), Tensor([0.35])]), - (2, 0.7, [2, 3], [Tensor([0.18]), Tensor([0.15])])]) - assert_equal_lists(bottom_tp, [(4, 1.0, [2, 1], [Tensor([0.31]), Tensor([0.33])]), - (2, 0.7, [1, 4], [Tensor([0.10]), Tensor([0.13])])]) + bottom_fn = select_k_tiles(test_dict, n_slides=nslides, label=1, n_tiles=ntiles, select=('lowest_pred', 'lowest_att')) + assert_equal_lists(top_fn, [(5, Tensor([0.7, 0.3]), [1, 4], [Tensor([0.43]), Tensor([0.25])]), + (3, Tensor([0.6, 0.4]), [1, 2], [Tensor([0.25]), Tensor([0.23])])]) + assert_equal_lists(bottom_fn, [(5, Tensor([0.7, 0.3]), [2, 3], [Tensor([0.01]), Tensor([0.07])]), + (3, Tensor([0.6, 0.4]), [3, 4], [Tensor([0.20]), Tensor([0.21])])]) + + # TN + top_tn = select_k_tiles(test_dict, n_slides=nslides, label=0, n_tiles=ntiles, select=('highest_pred', 'highest_att')) + bottom_tn = select_k_tiles(test_dict, n_slides=nslides, label=0, n_tiles=ntiles, select=('highest_pred', 'lowest_att')) + assert_equal_lists(top_tn, [(6, Tensor([0.8, 0.2]), [4, 1], [Tensor([0.55]), Tensor([0.53])]), + (1, Tensor([0.6, 0.4]), [3, 4], [Tensor([0.2]), Tensor([0.15])])]) + assert_equal_lists(bottom_tn, [(6, Tensor([0.8, 0.2]), [2, 3], [Tensor([0.11]), Tensor([0.17])]), + (1, Tensor([0.6, 0.4]), [2, 1], [Tensor([0.00]), Tensor([0.10])])]) + + # FP + top_fp = select_k_tiles(test_dict, n_slides=nslides, label=0, n_tiles=ntiles, select=('lowest_pred', 'highest_att')) + bottom_fp = select_k_tiles(test_dict, n_slides=nslides, label=0, n_tiles=ntiles, select=('lowest_pred', 'lowest_att')) + assert_equal_lists(top_fp, [(8, Tensor([0.01, 0.99]), [1, 3], [Tensor([0.73]), Tensor([0.37])]), + (7, Tensor([0.1, 0.9]), [1, 3], [Tensor([0.63]), Tensor([0.27])])]) + assert_equal_lists(bottom_fp, [(8, Tensor([0.01, 0.99]), [4, 2], [Tensor([0.15]), Tensor([0.31])]), + (7, Tensor([0.1, 0.9]), [4, 2], [Tensor([0.05]), Tensor([0.21])])]) @pytest.mark.skipif(is_windows(), reason="Rendering is different on Windows") diff --git a/Tests/ML/test_data/histo_heatmaps/score_hist.png b/Tests/ML/test_data/histo_heatmaps/score_hist.png index bced47d8b..bdc6d5b59 100644 --- a/Tests/ML/test_data/histo_heatmaps/score_hist.png +++ b/Tests/ML/test_data/histo_heatmaps/score_hist.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:ca95c0017d0a51d75d118e54f21c1e907b3d90dcca822b23622e369267907198 -size 17057 +oid sha256:6ddc430ffcade51a072e9452833143840b1e5726148fd850ad3f370f1315bb32 +size 20452 diff --git a/hi-ml b/hi-ml index 30854eae4..0250715c5 160000 --- a/hi-ml +++ b/hi-ml @@ -1 +1 @@ -Subproject commit 30854eae4fd27776be9f0105099ddba663ef3eb5 +Subproject commit 0250715c5ac1ef09227b51388df44b568a496f65