diff --git a/.gitignore b/.gitignore
index 120d80cf5..eeed8fae8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -42,6 +42,7 @@ wheels/
 .installed.cfg
 *.egg
 MANIFEST
+packages-microsoft-prod.deb

 # PyInstaller
 #  Usually these files are written by a python script from a template
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1d0bb11cc..8c5612acf 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,6 +13,8 @@ created.
 ## Upcoming

 ### Added
+- ([#465](https://github.com/microsoft/InnerEye-DeepLearning/pull/465/)) Added the ability to run the segmentation
+inference module on test data with missing or partial ground truth files.
 - ([#502](https://github.com/microsoft/InnerEye-DeepLearning/pull/502)) More flags for fine control of when to run inference.
 - ([#492](https://github.com/microsoft/InnerEye-DeepLearning/pull/492)) Adding capability for regression tests for test
 jobs that run in AzureML.
diff --git a/InnerEye/ML/config.py b/InnerEye/ML/config.py
index 5a8e25d44..8e3aab58c 100644
--- a/InnerEye/ML/config.py
+++ b/InnerEye/ML/config.py
@@ -474,17 +474,25 @@ class SegmentationModelBase(ModelConfigBase):
     is_plotting_enabled: bool = param.Boolean(True, doc="If true, various overview plots with results are generated "
                                                         "during model evaluation. Set to False if you see "
                                                         "non-deterministic pull request build failures.")
+
     show_patch_sampling: int = param.Integer(1, bounds=(0, None),
                                              doc="Number of patients from the training set for which the effect of "
                                                  "patch sampling will be shown. Nifti images and thumbnails for each "
                                                  "of the first N subjects in the training set will be "
                                                  "written to the outputs folder.")
+
     #: If true an error is raised in InnerEye.ML.utils.io_util.load_labels_from_dataset_source if the labels are not
     #: mutually exclusive. Some loss functions (e.g. SoftDice) may produce results on overlapping labels, but others (e.g.
     #: FocalLoss) will fail with a cryptic error message. Set to false if you are sure that you want to use labels that
     #: are not mutually exclusive.
     check_exclusive: bool = param.Boolean(True, doc="Raise an error if the segmentation labels are not mutually exclusive.")
+
+    allow_incomplete_labels: bool = param.Boolean(
+        default=False,
+        doc="If False (the default), all test patient data must include all of the ground truth labels. If True, "
+            "test patients with missing ground truth labels are allowed, and this is reflected in the patient "
+            "counts in the metrics and report.")
+
     def __init__(self,
                  center_size: Optional[TupleInt3] = None,
                  inference_stride_size: Optional[TupleInt3] = None,
                  min_l_rate: float = 0,
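An aside for readers of this patch: the new flag is opted into per model configuration. A minimal usage sketch, assuming the DummyModel test config shipped with the repo (any SegmentationModelBase subclass works the same way):

    from Tests.ML.configs.DummyModel import DummyModel

    config = DummyModel()
    # Tolerate test subjects that are missing some (or all) ground truth channels.
    config.allow_incomplete_labels = True
    # The tests in this PR also set check_exclusive=False, because their two synthetic
    # regions re-use the same label file and therefore overlap:
    config.check_exclusive = False
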
diff --git a/InnerEye/ML/dataset/full_image_dataset.py b/InnerEye/ML/dataset/full_image_dataset.py
index aa3bf2dff..63883c006 100644
--- a/InnerEye/ML/dataset/full_image_dataset.py
+++ b/InnerEye/ML/dataset/full_image_dataset.py
@@ -216,7 +216,7 @@ def __init__(self, args: SegmentationModelBase, data_frame: pd.DataFrame,
                  full_image_sample_transforms: Optional[Compose3D[Sample]] = None):
         super().__init__(args, data_frame)
         self.full_image_sample_transforms = full_image_sample_transforms
-
+        self.allow_incomplete_labels = args.allow_incomplete_labels
         # Check base_path
         assert self.args.local_dataset is not None
         if not self.args.local_dataset.is_dir():
@@ -250,7 +250,8 @@ def _extension_from_df_file_paths(file_paths: List[str]) -> str:
     def get_samples_at_index(self, index: int) -> List[Sample]:
         # load the channels into memory
         ds = self.dataset_sources[self.dataset_indices[index]]
-        samples = [io_util.load_images_from_dataset_source(dataset_source=ds, check_exclusive=self.args.check_exclusive)]  # type: ignore
+        samples = [io_util.load_images_from_dataset_source(dataset_source=ds,
+                                                           check_exclusive=self.args.check_exclusive)]  # type: ignore
         return [Compose3D.apply(self.full_image_sample_transforms, x) for x in samples]

     def _load_dataset_sources(self) -> Dict[str, PatientDatasetSource]:
@@ -259,34 +260,40 @@ def _load_dataset_sources(self) -> Dict[str, PatientDatasetSource]:
             local_dataset_root_folder=self.args.local_dataset,
             image_channels=self.args.image_channels,
             ground_truth_channels=self.args.ground_truth_ids,
-            mask_channel=self.args.mask_id
-        )
+            mask_channel=self.args.mask_id,
+            allow_incomplete_labels=self.allow_incomplete_labels)


 def convert_channels_to_file_paths(channels: List[str],
                                    rows: pd.DataFrame,
                                    local_dataset_root_folder: Path,
-                                   patient_id: str) -> Tuple[List[Path], str]:
+                                   patient_id: str,
+                                   allow_incomplete_labels: bool = False) -> Tuple[List[Optional[Path]], str]:
     """
-    Returns: 1) The full path for files specified in the training, validation and testing datasets, and
-    2) Missing channels or missing files.
+    Returns: 1) A list of file paths (Path objects, or None for channels that are allowed to be missing) for the
+    training, validation and testing datasets, and
+    2) a string describing any missing channels, missing files, or duplicate channel entries for the patient.

     :param channels: channel type defined in the configuration file
     :param rows: Input Pandas dataframe object containing subjectIds, path of local dataset, channel information
     :param local_dataset_root_folder: Root directory which points to the local dataset
     :param patient_id: string which contains subject identifier
+    :param allow_incomplete_labels: Boolean flag. If False, all ground truth files must be provided.
+        If True, ground truth files are optional.
     """
-    paths: List[Path] = []
-    failed_channel_info: str = ''
+    paths: List[Optional[Path]] = []
+    failed_channel_info = ''

     for channel_id in channels:
         row = rows.loc[rows[CSV_CHANNEL_HEADER] == channel_id]
-        if len(row) == 0:
+        if len(row) == 0 and not allow_incomplete_labels:
             failed_channel_info += f"Patient {patient_id} does not have channel '{channel_id}'" + os.linesep
+        elif len(row) == 0 and allow_incomplete_labels:
+            # Append None so that the position of the missing channel is preserved in the returned list.
+            paths.append(None)
         elif len(row) > 1:
             failed_channel_info += f"Patient {patient_id} has more than one entry for channel '{channel_id}'" + \
                                    os.linesep
-        else:
+        elif len(row) == 1:
             image_path = local_dataset_root_folder / row[CSV_PATH_HEADER].values[0]
             if not image_path.is_file():
                 failed_channel_info += f"Patient {patient_id}, file {image_path} does not exist" + os.linesep
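To make the new return contract concrete: missing channels yield None placeholders at their original positions, so downstream positional indexing still lines up. A toy sketch with invented file names:

    from pathlib import Path
    from typing import List, Optional

    channels = ["region", "region_1"]
    # With allow_incomplete_labels=True and no "region" row for this patient,
    # convert_channels_to_file_paths returns something shaped like:
    paths: List[Optional[Path]] = [None, Path("train_and_test_data/id3_region_1.nii.gz")]
    assert len(paths) == len(channels)  # paths[i] still corresponds to channels[i]
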
@@ -300,7 +307,8 @@ def load_dataset_sources(dataframe: pd.DataFrame,
                          local_dataset_root_folder: Path,
                          image_channels: List[str],
                          ground_truth_channels: List[str],
-                         mask_channel: Optional[str]) -> Dict[str, PatientDatasetSource]:
+                         mask_channel: Optional[str],
+                         allow_incomplete_labels: bool = False) -> Dict[str, PatientDatasetSource]:
     """
     Prepares a patient-to-images mapping from a dataframe read directly from a dataset CSV file.
     The dataframe contains per-patient per-channel image information, relative to a root directory.
@@ -311,6 +319,8 @@ def load_dataset_sources(dataframe: pd.DataFrame,
     :param image_channels: The names of the image channels that should be used in the result.
     :param ground_truth_channels: The names of the ground truth channels that should be used in the result.
     :param mask_channel: The name of the mask channel that should be used in the result. This can be None.
+    :param allow_incomplete_labels: Boolean flag. If False, all ground truth files must be provided. If True,
+        ground truth files are optional. Defaults to False.
     :return: A dictionary mapping from an integer subject ID to a PatientDatasetSource.
""" expected_headers = {CSV_SUBJECT_HEADER, CSV_PATH_HEADER, CSV_CHANNEL_HEADER} @@ -328,16 +338,19 @@ def load_dataset_sources(dataframe: pd.DataFrame, def get_mask_channel_or_default() -> Optional[Path]: if mask_channel is None: return None + paths = get_paths_for_channel_ids(channels=[mask_channel], allow_incomplete_labels_flag=allow_incomplete_labels) + if len(paths) == 0: + return None else: - return get_paths_for_channel_ids(channels=[mask_channel])[0] + return paths[0] - def get_paths_for_channel_ids(channels: List[str]) -> List[Path]: + def get_paths_for_channel_ids(channels: List[str], allow_incomplete_labels_flag: bool) -> List[Optional[Path]]: if len(set(channels)) < len(channels): raise ValueError(f"ids have duplicated entries: {channels}") rows = dataframe.loc[dataframe[CSV_SUBJECT_HEADER] == patient_id] # converts channels to paths and makes second sanity check for channel data paths, failed_channel_info = convert_channels_to_file_paths(channels, rows, local_dataset_root_folder, - patient_id) + patient_id, allow_incomplete_labels_flag) if failed_channel_info: raise ValueError(failed_channel_info) @@ -349,9 +362,11 @@ def get_paths_for_channel_ids(channels: List[str]) -> List[Path]: metadata = PatientMetadata.from_dataframe(dataframe, patient_id) dataset_sources[patient_id] = PatientDatasetSource( metadata=metadata, - image_channels=get_paths_for_channel_ids(channels=image_channels), # type: ignore + image_channels=get_paths_for_channel_ids(channels=image_channels, # type: ignore + allow_incomplete_labels_flag=False), mask_channel=get_mask_channel_or_default(), - ground_truth_channels=get_paths_for_channel_ids(channels=ground_truth_channels) # type: ignore - ) + ground_truth_channels=get_paths_for_channel_ids(channels=ground_truth_channels, # type: ignore + allow_incomplete_labels_flag=allow_incomplete_labels), + allow_incomplete_labels=allow_incomplete_labels) return dataset_sources diff --git a/InnerEye/ML/dataset/sample.py b/InnerEye/ML/dataset/sample.py index b556c215e..a233e01ec 100644 --- a/InnerEye/ML/dataset/sample.py +++ b/InnerEye/ML/dataset/sample.py @@ -129,9 +129,10 @@ class PatientDatasetSource(SampleBase): Dataset source locations for channels associated with a given patient in a particular dataset. 
""" image_channels: List[PathOrString] - ground_truth_channels: List[PathOrString] + ground_truth_channels: List[Optional[PathOrString]] mask_channel: Optional[PathOrString] metadata: PatientMetadata + allow_incomplete_labels: Optional[bool] = False def __post_init__(self) -> None: # make sure all properties are populated @@ -139,9 +140,13 @@ def __post_init__(self) -> None: if not self.image_channels: raise ValueError("image_channels cannot be empty") + if not self.ground_truth_channels: raise ValueError("ground_truth_channels cannot be empty") + if self.ground_truth_channels.count(None) > 0 and not self.allow_incomplete_labels: + raise ValueError("all ground_truth_channels must be provided") + @dataclass(frozen=True) class Sample(SampleBase): diff --git a/InnerEye/ML/lightning_base.py b/InnerEye/ML/lightning_base.py index 105dc6260..417e02d1b 100644 --- a/InnerEye/ML/lightning_base.py +++ b/InnerEye/ML/lightning_base.py @@ -161,11 +161,13 @@ def setup(self) -> None: unique_ids = set(split_data[CSV_SUBJECT_HEADER]) for patient_id in unique_ids: rows = split_data.loc[split_data[CSV_SUBJECT_HEADER] == patient_id] + allow_incomplete_labels = self.config.allow_incomplete_labels # type: ignore # Converts channels from data frame to file paths and gets errors if any __, failed_channel_info = convert_channels_to_file_paths(all_channels, - rows, - local_dataset_root_folder, - patient_id) + rows, + local_dataset_root_folder, + patient_id, + allow_incomplete_labels) full_failed_channel_info += failed_channel_info if full_failed_channel_info: diff --git a/InnerEye/ML/metrics.py b/InnerEye/ML/metrics.py index 3c7471274..a87dde368 100644 --- a/InnerEye/ML/metrics.py +++ b/InnerEye/ML/metrics.py @@ -12,6 +12,7 @@ import SimpleITK as sitk import numpy as np +from numpy.core.numeric import NaN import torch import torch.nn.functional as F from azureml.core import Run @@ -21,12 +22,13 @@ from InnerEye.Common.type_annotations import DictStrFloat, TupleFloat3 from InnerEye.ML.common import ModelExecutionMode from InnerEye.ML.config import BACKGROUND_CLASS_NAME -from InnerEye.ML.metrics_dict import DataframeLogger, INTERNAL_TO_LOGGING_COLUMN_NAMES, MetricsDict, \ - ScalarMetricsDict +from InnerEye.ML.metrics_dict import (DataframeLogger, INTERNAL_TO_LOGGING_COLUMN_NAMES, MetricsDict, + ScalarMetricsDict) from InnerEye.ML.scalar_config import ScalarLoss from InnerEye.ML.utils.image_util import binaries_from_multi_label_array, is_binary_array from InnerEye.ML.utils.io_util import reverse_tuple_float3 -from InnerEye.ML.utils.metrics_util import binary_classification_accuracy, mean_absolute_error, r2_score +from InnerEye.ML.utils.metrics_util import (binary_classification_accuracy, mean_absolute_error, + r2_score, is_missing_ground_truth) from InnerEye.ML.utils.ml_util import check_size_matches from InnerEye.ML.utils.sequence_utils import get_masked_model_outputs_and_labels @@ -56,7 +58,7 @@ class InferenceMetricsForSegmentation(InferenceMetrics): """ Stores metrics for segmentation models, per execution mode and epoch. 
""" - data_split: ModelExecutionMode + execution_mode: ModelExecutionMode metrics: float def get_metrics_log_key(self) -> str: @@ -64,7 +66,7 @@ def get_metrics_log_key(self) -> str: Gets a string name for logging the metrics specific to the execution mode (train, val, test) :return: """ - return f"InferenceMetrics_{self.data_split.value}" + return f"InferenceMetrics_{self.execution_mode.value}" def log_metrics(self, run_context: Run = None) -> None: """ @@ -230,9 +232,10 @@ def calculate_metrics_per_class(segmentation: np.ndarray, Calculate the dice for all foreground structures (the background class is completely ignored). Returns a MetricsDict with metrics for each of the foreground structures. Metrics are NaN if both ground truth and prediction are all zero for a class. + If first element of a ground truth image channel is NaN, the image is flagged as NaN and not use. :param ground_truth_ids: The names of all foreground classes. :param segmentation: predictions multi-value array with dimensions: [Z x Y x X] - :param ground_truth: ground truth binary array with dimensions: [C x Z x Y x X] + :param ground_truth: ground truth binary array with dimensions: [C x Z x Y x X]. :param voxel_spacing: voxel_spacing in 3D Z x Y x X :param patient_id: for logging """ @@ -242,15 +245,34 @@ def calculate_metrics_per_class(segmentation: np.ndarray, f"the label tensor indicates that there are {number_of_classes - 1} classes.") binaries = binaries_from_multi_label_array(segmentation, number_of_classes) - all_classes_are_binary = [is_binary_array(ground_truth[label_id]) for label_id in range(ground_truth.shape[0])] - if not np.all(all_classes_are_binary): + binary_classes = [is_binary_array(ground_truth[label_id]) for label_id in range(ground_truth.shape[0])] + + # If ground truth image is nan, then will not be used for metrics computation. + nan_images = [is_missing_ground_truth(ground_truth[label_id]) for label_id in range(ground_truth.shape[0])] + + # Compares element-wise if not binary then nan and checks all elements are True. 
diff --git a/InnerEye/ML/model_testing.py b/InnerEye/ML/model_testing.py
index f9f709321..9b6c7521e 100644
--- a/InnerEye/ML/model_testing.py
+++ b/InnerEye/ML/model_testing.py
@@ -25,7 +25,7 @@
 from InnerEye.ML.dataset.sample import PatientMetadata, Sample
 from InnerEye.ML.metrics import InferenceMetrics, InferenceMetricsForClassification, InferenceMetricsForSegmentation, \
     compute_scalar_metrics
-from InnerEye.ML.metrics_dict import DataframeLogger, MetricsDict, ScalarMetricsDict, SequenceMetricsDict
+from InnerEye.ML.metrics_dict import DataframeLogger, FloatOrInt, MetricsDict, ScalarMetricsDict, SequenceMetricsDict
 from InnerEye.ML.model_config_base import ModelConfigBase
 from InnerEye.ML.pipelines.ensemble import EnsemblePipeline
 from InnerEye.ML.pipelines.inference import FullImageInferencePipelineBase, InferencePipeline, InferencePipelineBase
@@ -76,16 +76,17 @@ def model_test(config: ModelConfigBase,

 def segmentation_model_test(config: SegmentationModelBase,
-                            data_split: ModelExecutionMode,
+                            execution_mode: ModelExecutionMode,
                             checkpoint_handler: CheckpointHandler,
                             model_proc: ModelProcessing = ModelProcessing.DEFAULT) -> InferenceMetricsForSegmentation:
     """
     The main testing loop for segmentation models. It loads the model and datasets, then proceeds to test the model
     for all requested checkpoints.
     :param config: The arguments object which has a valid random seed attribute.
-    :param data_split: Indicates which of the 3 sets (training, test, or validation) is being processed.
-    :param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
-    :param model_proc: whether we are testing an ensemble or single model
+    :param execution_mode: Indicates which of the 3 sets (training, test, or validation) is being processed.
+    :param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization.
+    :param model_proc: Whether we are testing an ensemble or single model.
     :return: InferenceMetric object that contains metrics related for all of the checkpoint epochs.
     """
     checkpoints_to_test = checkpoint_handler.get_checkpoints_to_test()
@@ -93,12 +94,12 @@ def segmentation_model_test(config: SegmentationModelBase,

     if not checkpoints_to_test:
         raise ValueError("There were no checkpoints available for model testing.")

-    epoch_results_folder = config.outputs_folder / get_best_epoch_results_path(data_split, model_proc)
+    epoch_results_folder = config.outputs_folder / get_best_epoch_results_path(execution_mode, model_proc)
     # save the datasets.csv used
     config.write_dataset_files(root=epoch_results_folder)
-    epoch_and_split = f"{data_split.value} set"
+    epoch_and_split = f"{execution_mode.value} set"
     epoch_dice_per_image = segmentation_model_test_epoch(config=copy.deepcopy(config),
-                                                         data_split=data_split,
+                                                         execution_mode=execution_mode,
                                                          checkpoint_paths=checkpoints_to_test,
                                                          results_folder=epoch_results_folder,
                                                          epoch_and_split=epoch_and_split)
@@ -110,13 +111,13 @@ def segmentation_model_test(config: SegmentationModelBase,
     logging.info(f"Mean Dice: {epoch_average_dice:4f}")
     if model_proc == ModelProcessing.ENSEMBLE_CREATION:
         # For the upload, we want the path without the "OTHER_RUNS/ENSEMBLE" prefix.
-        name = str(get_best_epoch_results_path(data_split, ModelProcessing.DEFAULT))
+        name = str(get_best_epoch_results_path(execution_mode, ModelProcessing.DEFAULT))
         PARENT_RUN_CONTEXT.upload_folder(name=name, path=str(epoch_results_folder))
-    return InferenceMetricsForSegmentation(data_split=data_split, metrics=result)
+    return InferenceMetricsForSegmentation(execution_mode=execution_mode, metrics=result)


 def segmentation_model_test_epoch(config: SegmentationModelBase,
-                                  data_split: ModelExecutionMode,
+                                  execution_mode: ModelExecutionMode,
                                   checkpoint_paths: List[Path],
                                   results_folder: Path,
                                   epoch_and_split: str) -> Optional[List[float]]:
@@ -126,8 +127,8 @@ def segmentation_model_test_epoch(config: SegmentationModelBase,
     where the average is taken across all non-background structures in the image.
     :param checkpoint_paths: Checkpoint paths to run inference on.
     :param config: The arguments which specify all required information.
-    :param data_split: Is the model evaluated on train, test, or validation set?
-    :param results_folder: The folder where to store the results
+    :param execution_mode: Is the model evaluated on the train, test, or validation set?
+    :param results_folder: The folder in which to store the results.
     :param epoch_and_split: A string that should uniquely identify the epoch and the data split (train/val/test).
     :raises TypeError: If the arguments are of the wrong type.
     :raises ValueError: When there are issues loading the model.
@@ -136,8 +137,8 @@ def segmentation_model_test_epoch(config: SegmentationModelBase, ml_util.set_random_seed(config.get_effective_random_seed(), "Model testing") results_folder.mkdir(exist_ok=True) - test_dataframe = config.get_dataset_splits()[data_split] - test_csv_path = results_folder / STORED_CSV_FILE_NAMES[data_split] + test_dataframe = config.get_dataset_splits()[execution_mode] + test_csv_path = results_folder / STORED_CSV_FILE_NAMES[execution_mode] test_dataframe.to_csv(path_or_buf=test_csv_path, index=False) logging.info("Results directory: {}".format(results_folder)) logging.info(f"Starting evaluation of model {config.model_name} on {epoch_and_split}") @@ -145,7 +146,7 @@ def segmentation_model_test_epoch(config: SegmentationModelBase, # Write the dataset id and ground truth ids into the results folder store_run_information(results_folder, config.azure_dataset_id, config.ground_truth_ids, config.image_channels) - ds = config.get_torch_dataset_for_inference(data_split) + ds = config.get_torch_dataset_for_inference(execution_mode) inference_pipeline = create_inference_pipeline(config=config, checkpoint_paths=checkpoint_paths) @@ -182,25 +183,9 @@ def segmentation_model_test_epoch(config: SegmentationModelBase, results_folder=results_folder), range(len(ds))) - average_dice = list() - metrics_writer = MetricsPerPatientWriter() - for (patient_metadata, metrics_for_patient) in pool_outputs: - # Add the Dice score for the foreground classes, stored in the default hue - metrics.add_average_foreground_dice(metrics_for_patient) - average_dice.append(metrics_for_patient.get_single_metric(MetricType.DICE)) - # Structure names does not include the background class (index 0) - for structure_name in config.ground_truth_ids: - dice_for_struct = metrics_for_patient.get_single_metric(MetricType.DICE, hue=structure_name) - hd_for_struct = metrics_for_patient.get_single_metric(MetricType.HAUSDORFF_mm, hue=structure_name) - md_for_struct = metrics_for_patient.get_single_metric(MetricType.MEAN_SURFACE_DIST_mm, hue=structure_name) - metrics_writer.add(patient=str(patient_metadata.patient_id), - structure=structure_name, - dice=dice_for_struct, - hausdorff_distance_mm=hd_for_struct, - mean_distance_mm=md_for_struct) - + metrics_writer, average_dice = populate_metrics_writer(pool_outputs, config) metrics_writer.to_csv(results_folder / SUBJECT_METRICS_FILE_NAME) - metrics_writer.save_aggregates_to_csv(results_folder / METRICS_AGGREGATES_FILE) + metrics_writer.save_aggregates_to_csv(results_folder / METRICS_AGGREGATES_FILE, config.allow_incomplete_labels) if config.is_plotting_enabled: plt.figure() boxplot_per_structure(metrics_writer.to_data_frame(), @@ -231,6 +216,7 @@ def evaluate_model_predictions(process_id: int, """ sample = dataset.get_samples_at_index(index=process_id)[0] logging.info(f"Evaluating predictions for patient {sample.patient_id}") + patient_results_folder = get_patient_results_folder(results_folder, sample.patient_id) segmentation = load_nifti_image(patient_results_folder / DEFAULT_RESULT_IMAGE_NAME).image metrics_per_class = metrics.calculate_metrics_per_class(segmentation, @@ -248,6 +234,35 @@ def evaluate_model_predictions(process_id: int, return sample.metadata, metrics_per_class +def populate_metrics_writer( + model_prediction_evaluations: List[Tuple[PatientMetadata, MetricsDict]], + config: SegmentationModelBase) -> Tuple[MetricsPerPatientWriter, List[FloatOrInt]]: + """ + Populate a MetricsPerPatientWriter with the metrics for each patient + :param model_prediction_evaluations: 
The list of PatientMetadata/MetricsDict tuples obtained from evaluate_model_predictions.
+    :param config: The SegmentationModelBase config from which we read the ground_truth_ids.
+    :returns: A new MetricsPerPatientWriter and a list of foreground Dice score averages.
+    """
+    average_dice: List[FloatOrInt] = []
+    metrics_writer = MetricsPerPatientWriter()
+    for (patient_metadata, metrics_for_patient) in model_prediction_evaluations:
+        # Add the Dice score for the foreground classes, stored in the default hue
+        metrics.add_average_foreground_dice(metrics_for_patient)
+        average_dice.append(metrics_for_patient.get_single_metric(MetricType.DICE))
+        # Structure names do not include the background class (index 0)
+        for structure_name in config.ground_truth_ids:
+            dice_for_struct = metrics_for_patient.get_single_metric(MetricType.DICE, hue=structure_name)
+            hd_for_struct = metrics_for_patient.get_single_metric(MetricType.HAUSDORFF_mm, hue=structure_name)
+            md_for_struct = metrics_for_patient.get_single_metric(MetricType.MEAN_SURFACE_DIST_mm, hue=structure_name)
+            metrics_writer.add(patient=str(patient_metadata.patient_id),
+                               structure=structure_name,
+                               dice=dice_for_struct,
+                               hausdorff_distance_mm=hd_for_struct,
+                               mean_distance_mm=md_for_struct)
+    return metrics_writer, average_dice
+
+
 def get_patient_results_folder(results_folder: Path, patient_id: int) -> Path:
     """
     Gets a folder name that will contain all results for a given patient, like root/017 for patient 17.
diff --git a/InnerEye/ML/plotting.py b/InnerEye/ML/plotting.py
index d0412849f..b467eeb2c 100644
--- a/InnerEye/ML/plotting.py
+++ b/InnerEye/ML/plotting.py
@@ -7,6 +7,7 @@

 import matplotlib.pyplot as plt
 import numpy as np
+import sys
 from matplotlib import colors
 from matplotlib.pyplot import Axes

@@ -15,6 +16,7 @@
 from InnerEye.ML.photometric_normalization import PhotometricNormalization
 from InnerEye.ML.utils import plotting_util
 from InnerEye.ML.utils.image_util import binaries_from_multi_label_array, get_largest_z_slice
+from InnerEye.ML.utils.metrics_util import is_missing_ground_truth
 from InnerEye.ML.utils.ml_util import check_size_matches
 from InnerEye.ML.utils.surface_distance_utils import Plane, extract_border

@@ -98,6 +100,8 @@ def resize_and_save(width_inch: int, height_inch: int, filename: PathOrString, d
     """
     fig = plt.gcf()
     fig.set_size_inches(width_inch, height_inch)
+    # Workaround for "Exception in Tkinter callback" errors when saving figures.
+    fig.canvas.start_event_loop(sys.float_info.min)
     plt.savefig(filename, dpi=dpi, bbox_inches='tight', pad_inches=0.1)


@@ -303,6 +307,10 @@ def plot_contours_for_all_classes(sample: Sample,
         if class_index == 0:
             continue
         ground_truth = sample.labels[class_index, ...]
+
+        # Structures with missing ground truth cannot be plotted, so skip them.
+        if is_missing_ground_truth(ground_truth):
+            continue
+
         largest_gt_slice = get_largest_z_slice(ground_truth)
         labels_at_largest_gt = ground_truth[largest_gt_slice]
         segmentation_at_largest_gt = binary[largest_gt_slice, ...]
diff --git a/InnerEye/ML/utils/dataset_util.py b/InnerEye/ML/utils/dataset_util.py
index b2275f4a0..2e9154db0 100644
--- a/InnerEye/ML/utils/dataset_util.py
+++ b/InnerEye/ML/utils/dataset_util.py
@@ -201,7 +201,9 @@ def add_label_stats_to_dataframe(input_dataframe: pd.DataFrame,
             overlap_stats = metrics_util.get_label_overlap_stats(labels=labels[1:, ...],
                                                                  label_names=target_label_names)

-            header = io_util.load_nifti_image(dataset_sources[subject_id].ground_truth_channels[0]).header
+            ground_truth_channel = dataset_sources[subject_id].ground_truth_channels[0]
+            assert ground_truth_channel is not None
+            header = io_util.load_nifti_image(ground_truth_channel).header
             volume_stats = metrics_util.get_label_volume(labels=labels[1:, ...],
                                                          label_names=target_label_names,
                                                          label_spacing=header.spacing)
diff --git a/InnerEye/ML/utils/io_util.py b/InnerEye/ML/utils/io_util.py
index 705769d38..02faeafb6 100644
--- a/InnerEye/ML/utils/io_util.py
+++ b/InnerEye/ML/utils/io_util.py
@@ -29,6 +29,7 @@
 from InnerEye.ML.utils.hdf5_util import HDF5Object
 from InnerEye.ML.utils.image_util import ImageDataType, ImageHeader, check_array_range, get_center_crop, \
     get_unit_image_header, is_binary_array
+from InnerEye.ML.utils.metrics_util import is_missing_ground_truth
 from InnerEye.ML.utils.transforms import LinearTransform, get_range_for_window_level

 RESULTS_POSTERIOR_FILE_NAME_PREFIX = "posterior_"
@@ -412,24 +413,42 @@ def load_image_in_known_formats(file: Path,
         raise ValueError(f"Unsupported image file type for path {file}")


-def load_labels_from_dataset_source(dataset_source: PatientDatasetSource, check_exclusive: bool = True) -> np.ndarray:
+def load_labels_from_dataset_source(dataset_source: PatientDatasetSource, check_exclusive: bool = True,
+                                    image_size: Optional[Tuple[int, ...]] = None) -> np.ndarray:
     """
     Load labels containing segmentation binary labels in one-hot-encoding.
     In the future, this function will be used to load global class and non-imaging information as well.

     :param dataset_source: The dataset source for which channels are to be loaded into memory.
-    :param check_exclusive: Check that the labels are mutually exclusive (defaults to True)
+    :param check_exclusive: Check that the labels are mutually exclusive (defaults to True).
+    :param image_size: Image size as a tuple of integers; required whenever a ground truth channel is missing.
     :return: A label sample object containing ground-truth information.
     """
-    labels = np.stack(
-        [load_image(gt, ImageDataType.SEGMENTATION.value).image for gt in dataset_source.ground_truth_channels])
-    if check_exclusive and (sum(labels) > 1.).any():  # type: ignore
+    if dataset_source.ground_truth_channels.count(None) > 0:
+        assert image_size is not None
+
+    label_list = []
+    # Missing ground truth channels get NaN-filled placeholders so that the channel order is preserved.
+    for gt in dataset_source.ground_truth_channels:
+        if gt is None:
+            label_list.append(np.full(image_size, np.NAN, ImageDataType.SEGMENTATION.value))
+        else:
+            label_list.append(load_image(gt, ImageDataType.SEGMENTATION.value).image)
+    labels = np.stack(label_list)
+
+    # A channel is treated as missing if its voxel at index [0, 0, 0] is NaN; missing channels are
+    # excluded from the mutual-exclusivity check.
+    not_nan_label_images = [labels[label_id] for label_id in range(labels.shape[0])
+                            if not is_missing_ground_truth(labels[label_id])]
+
+    if check_exclusive and (sum(np.array(not_nan_label_images)) > 1.).any():  # type: ignore
         raise ValueError(f'The labels for patient {dataset_source.metadata.patient_id} are not mutually exclusive. '
                          'Some loss functions (e.g. SoftDice) may produce results on overlapping labels, while others '
-                         '(e.g. FocalLoss) will fail. If you are sure that you want to use labels that are not '
-                         'mutually exclusive, then re-run with the check_exclusive flag set to false in the model '
-                         'config. Note that this is the first error encountered, other patients may also have '
-                         'overlapping labels.')
+                         '(e.g. FocalLoss) will fail. '
+                         'If you are sure that you want to use labels that are not mutually exclusive, '
+                         'then re-run with the check_exclusive flag set to false in the settings file. '
+                         'Note that this is the first error encountered; other samples/patients may also have '
+                         'overlapping labels.')

     # Add the background binary map
@@ -502,7 +521,8 @@ def load_images_from_dataset_source(dataset_source: PatientDatasetSource, check_
     # create raw sample to return
     metadata = copy(dataset_source.metadata)
     metadata.image_header = images[0].header
-    labels = load_labels_from_dataset_source(dataset_source, check_exclusive=check_exclusive)
+    labels = load_labels_from_dataset_source(dataset_source, check_exclusive=check_exclusive,
+                                             image_size=image[0].shape)
+
     return Sample(image=image,
                   labels=labels,
                   mask=mask,
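To make the missing-channel representation concrete, a small sketch (shape and values invented) of what load_labels_from_dataset_source now builds internally:

    import numpy as np

    image_size = (4, 4, 4)                            # invented Z x Y x X shape
    region = np.full(image_size, np.nan, np.float32)  # placeholder for a missing channel
    region_1 = np.zeros(image_size, np.float32)       # a real binary label channel
    labels = np.stack([region, region_1])

    # The NaN at voxel [0, 0, 0] is the sentinel that is_missing_ground_truth tests for;
    # only labels[1] takes part in the mutual-exclusivity check.
    assert np.isnan(labels[0][0, 0, 0]) and not np.isnan(labels[1][0, 0, 0])
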
""" stats_columns = ['mean', 'std', 'min', 'max'] # get aggregates for all metrics - aggregates = self.to_data_frame().groupby(MetricsFileColumns.Structure.value).describe() + df = self.to_data_frame() + aggregates = df.groupby(MetricsFileColumns.Structure.value).describe() + + total_num_patients_column_name = f"total_{MetricsFileColumns.Patient.value}".lower() + if not total_num_patients_column_name.endswith("s"): + total_num_patients_column_name += "s" def filter_rename_metric_columns(_metric_column: str, is_count_column: bool = False) -> pd.DataFrame: _columns = ["count"] + stats_columns if is_count_column else stats_columns _df = aggregates[_metric_column][_columns] - _columns_to_rename = [x for x in _df.columns if x != "count"] + if is_count_column and allow_incomplete_labels: + # For this condition we add a total_patient count column so that readers can make + # more sense of aggregated metrics where some patients were missing the label (i.e. + # partial ground truth). + num_subjects = len(pd.unique(df[MetricsFileColumns.Patient.value])) + _df[total_num_patients_column_name] = num_subjects + _df = _df[["count", total_num_patients_column_name] + stats_columns] + _columns_to_rename = [x for x in _df.columns if x != "count" and x != total_num_patients_column_name] return _df.rename(columns={k: f"{_metric_column}_{k}" for k in _columns_to_rename}) def _merge_df(df1: pd.DataFrame, df2: pd.DataFrame) -> pd.DataFrame: @@ -100,14 +115,15 @@ def to_data_frame(self) -> DataFrame: # slow, and should be avoided. Hence, work with dictionary as long as possible, and only finally # convert to a DataFrame. - # dtype is specified as (an instance of) str, not the str class itself, but this seems correct. # noinspection PyTypeChecker + dtypes: Dict[str, Union[Type[float], Type[str]]] = {column: str for column in self.columns} + dtypes[MetricsFileColumns.Dice.value] = float + dtypes[MetricsFileColumns.HausdorffDistanceMM.value] = float + dtypes[MetricsFileColumns.MeanDistanceMM.value] = float df = DataFrame(self.columns, dtype=str) - df[MetricsFileColumns.DiceNumeric.value] = pd.Series(data=df[MetricsFileColumns.Dice.value].apply(float)) - df[MetricsFileColumns.HausdorffDistanceMM.value] = pd.Series( - data=df[MetricsFileColumns.HausdorffDistanceMM.value].apply(float)) - df[MetricsFileColumns.MeanDistanceMM.value] = pd.Series( - data=df[MetricsFileColumns.MeanDistanceMM.value].apply(float)) + df = df.astype(dtypes) + df[MetricsFileColumns.DiceNumeric.value] = df[MetricsFileColumns.Dice.value] + df = df.sort_values(by=[MetricsFileColumns.Patient.value, MetricsFileColumns.Structure.value]) return df @@ -245,3 +261,15 @@ def convert_input_and_label(model_output: Union[torch.Tensor, np.ndarray], if not torch.is_tensor(label): label = torch.tensor(label) return model_output.float(), label.float() + + +def is_missing_ground_truth(ground_truth: np.array) -> bool: + """ + calculate_metrics_per_class in metrics.py and plot_contours_for_all_classes in plotting.py both + check whether there is ground truth missing using this simple check for NaN value at 0, 0, 0. + To avoid duplicate code we bring it here as a utility function. + :param ground_truth: ground truth binary array with dimensions: [Z x Y x X]. + :param label_id: Integer index of the label to check. + :returns: True if the label is missing (signified by NaN), False otherwise. 
+ """ + return np.isnan(ground_truth[0, 0, 0]) diff --git a/Tests/ML/configs/DummyModel.py b/Tests/ML/configs/DummyModel.py index ba03f1819..5abcf19aa 100644 --- a/Tests/ML/configs/DummyModel.py +++ b/Tests/ML/configs/DummyModel.py @@ -16,6 +16,9 @@ class DummyModel(SegmentationModelBase): fg_ids = ["region"] + train_subject_ids = ['1', '2', '3'] + test_subject_ids = ['4', '7'] + val_subject_ids = ['5', '6'] def __init__(self, **kwargs: Any) -> None: super().__init__( @@ -60,16 +63,15 @@ def __init__(self, **kwargs: Any) -> None: weight_decay=1e-4, class_weights=[0.5, 0.5], detect_anomaly=False, - use_mixed_precision=False, - ) + use_mixed_precision=False) self.add_and_validate(kwargs) # Trying to run DDP from the test suite hangs, hence restrict to single GPU. self.max_num_gpus = 1 def get_model_train_test_dataset_splits(self, dataset_df: pd.DataFrame) -> DatasetSplits: - return DatasetSplits(train=dataset_df[dataset_df.subject.isin(['1', '2', '3'])], - test=dataset_df[dataset_df.subject.isin(['4', '7'])], - val=dataset_df[dataset_df.subject.isin(['5', '6'])]) + return DatasetSplits(train=dataset_df[dataset_df.subject.isin(self.train_subject_ids)], + test=dataset_df[dataset_df.subject.isin(self.test_subject_ids)], + val=dataset_df[dataset_df.subject.isin(self.val_subject_ids)]) def get_parameter_search_hyperdrive_config(self, run_config: ScriptRunConfig) -> HyperDriveConfig: return super().get_parameter_search_hyperdrive_config(run_config) diff --git a/Tests/ML/pipelines/test_inference.py b/Tests/ML/pipelines/test_inference.py index 566597d8a..01a3202b6 100644 --- a/Tests/ML/pipelines/test_inference.py +++ b/Tests/ML/pipelines/test_inference.py @@ -2,9 +2,11 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------ -from typing import Any, List +from InnerEye.ML.metrics_dict import MetricsDict +from typing import Any, List, Tuple import numpy as np +import pandas as pd import pytest import torch from torch.nn import Parameter @@ -18,6 +20,11 @@ from InnerEye.ML.pipelines.inference import InferencePipeline from InnerEye.ML.utils import image_util from Tests.ML.utils.test_model_util import create_model_and_store_checkpoint +from Tests.ML.configs.DummyModel import DummyModel +from InnerEye.ML.utils.split_dataset import DatasetSplits +from InnerEye.ML.dataset.sample import PatientMetadata, Sample +from InnerEye.ML.common import ModelExecutionMode +from InnerEye.ML.model_testing import store_inference_results, evaluate_model_predictions, populate_metrics_writer @pytest.mark.skipif(common_util.is_windows(), reason="Too slow on windows") @@ -198,3 +205,153 @@ def shrink_dim(i: int) -> int: def get_all_child_layers(self) -> List[torch.nn.Module]: return list() + + +def create_config_from_dataset(input_list: List[List[str]], train: List[str], val: List[str], test: List[str]) \ + -> DummyModel: + """ + Creates an "DummyModel(SegmentationModelBase)" object given patient list + and training, validation and test subjects id. 
+ """ + + class MyDummyModel(DummyModel): + def get_model_train_test_dataset_splits(self, dataset_df: pd.DataFrame) -> DatasetSplits: + return DatasetSplits(train=dataset_df[dataset_df.subject.isin(train)], + test=dataset_df[dataset_df.subject.isin(test)], + val=dataset_df[dataset_df.subject.isin(val)]) + + config = MyDummyModel() + # Sets two regions for ground truth + config.fg_ids = ["region", "region_1"] + config.ground_truth_ids = config.fg_ids + config.ground_truth_ids_display_names = config.fg_ids + config.colours = [(255, 255, 255)] * len(config.fg_ids) + config.fill_holes = [False] * len(config.fg_ids) + config.roi_interpreted_types = ["Organ"] * len(config.fg_ids) + config.check_exclusive = False + df = pd.DataFrame(input_list, columns=['subject', 'filePath', 'channel']) + config._dataset_data_frame = df + return config + + +def test_evaluate_model_predictions() -> None: + """ + Creates an 'InferencePipeline.Result' object using pre-defined volumes, stores results and evaluates metrics. + """ + # Patients 3, 4, and 5 are in test dataset such that: + # Patient 3 has one missing ground truth channel: "region" + # Patient 4 has all missing ground truth channels: "region", "region_1" + # Patient 5 has no missing ground truth channels. + input_list = [ + ["1", "train_and_test_data/id1_channel1.nii.gz", "channel1"], + ["1", "train_and_test_data/id1_channel1.nii.gz", "channel2"], + ["1", "train_and_test_data/id1_mask.nii.gz", "mask"], + ["1", "train_and_test_data/id1_region.nii.gz", "region"], + ["1", "train_and_test_data/id1_region.nii.gz", "region_1"], + ["2", "train_and_test_data/id2_channel1.nii.gz", "channel1"], + ["2", "train_and_test_data/id2_channel1.nii.gz", "channel2"], + ["2", "train_and_test_data/id2_mask.nii.gz", "mask"], + ["2", "train_and_test_data/id2_region.nii.gz", "region"], + ["2", "train_and_test_data/id2_region.nii.gz", "region_1"], + ["3", "train_and_test_data/id2_channel1.nii.gz", "channel1"], + ["3", "train_and_test_data/id2_channel1.nii.gz", "channel2"], + ["3", "train_and_test_data/id2_mask.nii.gz", "mask"], + # ["3", "train_and_test_data/id2_region.nii.gz", "region"], # commented on purpose + ["3", "train_and_test_data/id2_region.nii.gz", "region_1"], + ["4", "train_and_test_data/id2_channel1.nii.gz", "channel1"], + ["4", "train_and_test_data/id2_channel1.nii.gz", "channel2"], + ["4", "train_and_test_data/id2_mask.nii.gz", "mask"], + # ["4", "train_and_test_data/id2_region.nii.gz", "region"], # commented on purpose + # ["4", "train_and_test_data/id2_region.nii.gz", "region_1"], # commented on purpose + ["5", "train_and_test_data/id2_channel1.nii.gz", "channel1"], + ["5", "train_and_test_data/id2_channel1.nii.gz", "channel2"], + ["5", "train_and_test_data/id2_mask.nii.gz", "mask"], + ["5", "train_and_test_data/id2_region.nii.gz", "region"], + ["5", "train_and_test_data/id2_region.nii.gz", "region_1"]] + + config = create_config_from_dataset(input_list, train=['1'], val=['2'], test=['3', '4', '5']) + config.allow_incomplete_labels = True + ds = config.get_torch_dataset_for_inference(ModelExecutionMode.TEST) + results_folder = config.outputs_folder + if not results_folder.is_dir(): + results_folder.mkdir() + + model_prediction_evaluations: List[Tuple[PatientMetadata, MetricsDict]] = [] + + for sample_index, sample in enumerate(ds, 1): + sample = Sample.from_dict(sample=sample) + posteriors = np.zeros((3,) + sample.mask.shape, 'float32') + posteriors[0][:] = 0.2 + posteriors[1][:] = 0.6 + posteriors[2][:] = 0.2 + + assert config.dataset_expected_spacing_xyz is 
+
+        inference_result = InferencePipeline.Result(
+            patient_id=sample.patient_id,
+            posteriors=posteriors,
+            segmentation=np.argmax(posteriors, 0),
+            voxel_spacing_mm=config.dataset_expected_spacing_xyz
+        )
+        store_inference_results(inference_result=inference_result,
+                                config=config,
+                                results_folder=results_folder,
+                                image_header=sample.metadata.image_header)
+
+        metadata, metrics_per_class = evaluate_model_predictions(
+            sample_index - 1,
+            config=config,
+            dataset=ds,
+            results_folder=results_folder)
+
+        model_prediction_evaluations.append((metadata, metrics_per_class))
+
+        # Patient 3 has one missing ground truth channel: "region"
+        if sample.metadata.patient_id == '3':
+            assert 'Dice' in metrics_per_class.values('region_1').keys()
+            assert 'HausdorffDistance_millimeters' in metrics_per_class.values('region_1').keys()
+            assert 'MeanSurfaceDistance_millimeters' in metrics_per_class.values('region_1').keys()
+            for hue_name in ['region', 'Default']:
+                for metric_type in metrics_per_class.values(hue_name).keys():
+                    assert np.isnan(metrics_per_class.values(hue_name)[metric_type]).all()
+
+        # Patient 4 has all missing ground truth channels: "region", "region_1"
+        if sample.metadata.patient_id == '4':
+            for hue_name in ['region_1', 'region', 'Default']:
+                for metric_type in metrics_per_class.values(hue_name).keys():
+                    assert np.isnan(metrics_per_class.values(hue_name)[metric_type]).all()
+
+        # Patient 5 has no missing ground truth channels
+        if sample.metadata.patient_id == '5':
+            for metric_type in metrics_per_class.values('Default').keys():
+                assert np.isnan(metrics_per_class.values('Default')[metric_type]).all()
+            for hue_name in ['region_1', 'region']:
+                assert 'Dice' in metrics_per_class.values(hue_name).keys()
+                assert 'HausdorffDistance_millimeters' in metrics_per_class.values(hue_name).keys()
+                assert 'MeanSurfaceDistance_millimeters' in metrics_per_class.values(hue_name).keys()
+
+    metrics_writer, average_dice = populate_metrics_writer(model_prediction_evaluations, config)
+    # Patient 3 has only one missing ground truth channel
+    assert not np.isnan(average_dice[0])
+    assert np.isnan(float(metrics_writer.columns["Dice"][0]))
+    assert not np.isnan(float(metrics_writer.columns["Dice"][1]))
+    assert np.isnan(float(metrics_writer.columns["HausdorffDistance_mm"][0]))
+    assert not np.isnan(float(metrics_writer.columns["HausdorffDistance_mm"][1]))
+    assert np.isnan(float(metrics_writer.columns["MeanDistance_mm"][0]))
+    assert not np.isnan(float(metrics_writer.columns["MeanDistance_mm"][1]))
+    # Patient 4 has all missing ground truth channels
+    assert np.isnan(average_dice[1])
+    assert np.isnan(float(metrics_writer.columns["Dice"][2]))
+    assert np.isnan(float(metrics_writer.columns["Dice"][3]))
+    assert np.isnan(float(metrics_writer.columns["HausdorffDistance_mm"][2]))
+    assert np.isnan(float(metrics_writer.columns["HausdorffDistance_mm"][3]))
+    assert np.isnan(float(metrics_writer.columns["MeanDistance_mm"][2]))
+    assert np.isnan(float(metrics_writer.columns["MeanDistance_mm"][3]))
+    # Patient 5 has no missing ground truth channels.
+    assert average_dice[2] > 0
+    assert float(metrics_writer.columns["Dice"][4]) >= 0
+    assert float(metrics_writer.columns["Dice"][5]) >= 0
+    assert float(metrics_writer.columns["HausdorffDistance_mm"][4]) >= 0
+    assert float(metrics_writer.columns["HausdorffDistance_mm"][5]) >= 0
+    assert float(metrics_writer.columns["MeanDistance_mm"][4]) >= 0
+    assert float(metrics_writer.columns["MeanDistance_mm"][5]) >= 0
diff --git a/Tests/ML/reports/test_segmentation_report.py b/Tests/ML/reports/test_segmentation_report.py
index efabac6d0..12c51fa91 100644
--- a/Tests/ML/reports/test_segmentation_report.py
+++ b/Tests/ML/reports/test_segmentation_report.py
@@ -2,24 +2,35 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
 # ------------------------------------------------------------------------------------------
+import re
 from io import StringIO
 from pathlib import Path

 import pandas as pd
 import pytest
+from numpy.core.numeric import NaN

+from InnerEye.Common.common_util import is_windows
+from InnerEye.Common.fixed_paths_for_tests import tests_root_directory
 from InnerEye.Common.metrics_constants import MetricsFileColumns
 from InnerEye.Common.output_directories import OutputFolderForTests
-from InnerEye.Common.common_util import is_windows
 from InnerEye.ML.reports.notebook_report import generate_segmentation_notebook
 from InnerEye.ML.reports.segmentation_report import describe_score, worst_patients_and_outliers
 from InnerEye.ML.utils.csv_util import COL_IS_OUTLIER


 @pytest.mark.skipif(is_windows(), reason="Random timeout errors on windows.")
-def test_generate_segmentation_report(test_output_dirs: OutputFolderForTests) -> None:
-    reports_folder = Path(__file__).parent
+@pytest.mark.parametrize("use_partial_ground_truth", [False, True])
+def test_generate_segmentation_report(test_output_dirs: OutputFolderForTests, use_partial_ground_truth: bool) -> None:
+    reports_folder = tests_root_directory() / "ML" / "reports"
     metrics_file = reports_folder / "metrics_hn.csv"

+    if use_partial_ground_truth:
+        return _test_generate_segmentation_report_with_partial_ground_truth(test_output_dirs, metrics_file)
+    return _test_generate_segmentation_report_without_partial_ground_truth(test_output_dirs, metrics_file)
+
+
+def _test_generate_segmentation_report_without_partial_ground_truth(
+        test_output_dirs: OutputFolderForTests,
+        metrics_file: Path) -> None:
     current_dir = test_output_dirs.make_sub_dir("test_segmentation_report")
     result_file = current_dir / "report.ipynb"
     result_html = generate_segmentation_notebook(result_notebook=result_file,
@@ -31,6 +42,43 @@ def test_generate_segmentation_report(test_output_dirs: OutputFolderForTests) ->
     contents = result_html.read_text(encoding='utf-8')
     assert 'parotid_r' in contents

+
+def _test_generate_segmentation_report_with_partial_ground_truth(
+        test_output_dirs: OutputFolderForTests,
+        original_metrics_file: Path) -> None:
+    """
+    The test without partial ground truth covers more detail; here we just check that providing
+    partial ground truth results in some labels having a lower patient count.
+ """ + original_metrics = pd.read_csv(original_metrics_file) + partial_metrics = original_metrics + partial_metrics.loc[partial_metrics['Structure'].eq('brainstem') & partial_metrics['Patient'].isin([14, 15, 19]), + ['Dice', 'HausdorffDistance_mm', 'MeanDistance_mm']] = NaN + current_dir = test_output_dirs.make_sub_dir("test_segmentation_report") + partial_metrics_file = current_dir / "metrics_hn.csv" + result_file = current_dir / "report.ipynb" + partial_metrics.to_csv(partial_metrics_file, index=False, float_format="%.3f", na_rep="") + result_html = generate_segmentation_notebook(result_notebook=result_file, test_metrics=partial_metrics_file) + result_html_text = result_html.read_text(encoding='utf-8') + # Look for this row in the HTML Dice table: + # brainstem\n 0.82600\n 0.8570\n 0.87600\n 17.0\n + # It shows that for the brainstem label there are only 17, not 20, patients with that label, + # because we removed the brainstem label for patients 14, 15, and 19. + + def get_patient_count_for_structure(structure: str, text: str) -> float: + regex = f"{structure}" + r"<\/td>(\n\s*[0-9\.]*<\/td>){3}\n\s*([0-9\.]*)" + # which results in, for example, this regex: + # regex = "brainstem<\/td>(\n\s*[0-9\.]*<\/td>){3}\n\s*([0-9\.]*)" + match = re.search(regex, text) + if not match: + return NaN + patient_count_as_string = match.group(2) + return float(patient_count_as_string) + + num_patients_with_lacrimal_gland_l_label = get_patient_count_for_structure("lacrimal_gland_l", result_html_text) + num_patients_with_brainstem_label = get_patient_count_for_structure("brainstem", result_html_text) + assert num_patients_with_lacrimal_gland_l_label == 20.0 + assert num_patients_with_brainstem_label == 17.0 + def test_describe_metric() -> None: data = """Patient,Structure,Dice,HausdorffDistance_mm,MeanDistance_mm diff --git a/Tests/ML/test_lightning_containers.py b/Tests/ML/test_lightning_containers.py index 191564f0c..5cf370ae0 100644 --- a/Tests/ML/test_lightning_containers.py +++ b/Tests/ML/test_lightning_containers.py @@ -4,7 +4,7 @@ # ------------------------------------------------------------------------------------------ from io import StringIO from pathlib import Path -from typing import List +from typing import List, Optional, Tuple from unittest import mock import pandas as pd @@ -283,6 +283,7 @@ def test_container_hooks(test_output_dirs: OutputFolderForTests) -> None: for file in ["global_rank_zero.txt", "local_rank_zero.txt", "all_ranks.txt"]: assert (runner.container.outputs_folder / file).is_file(), f"Missing file: {file}" + @pytest.mark.parametrize("number_of_cross_validation_splits", [0, 2]) def test_get_hyperdrive_config(number_of_cross_validation_splits: int, test_output_dirs: OutputFolderForTests) -> None: @@ -312,3 +313,36 @@ def test_get_hyperdrive_config(number_of_cross_validation_splits: int, else: hd_config = container.get_hyperdrive_config(run_config=run_config) assert isinstance(hd_config, HyperDriveConfig) + + +@pytest.mark.parametrize("allow_partial_ground_truth", [True, False]) +def test_innereyecontainer_setup_passes_on_allow_incomplete_labels( + test_output_dirs: OutputFolderForTests, + allow_partial_ground_truth: bool) -> None: + """ + Test that InnerEyeContainer.setup passes on the correct value of allow_incomplete_labels to + full_image_dataset.convert_channels_to_file_paths + :param test_output_dirs: Test fixture. + :param allow_partial_ground_truth: The value to set allow_incomplete_labels to and check it is + passed through. 
+ """ + config = DummyModel() + config.set_output_to(test_output_dirs.root_dir) + config.allow_incomplete_labels = allow_partial_ground_truth + container = InnerEyeContainer(config) + test_done_message = "Stop now, the test has passed." + + def mocked_convert_channels_to_file_paths( + _: List[str], + __: pd.DataFrame, + ___: Path, + ____: str, + allow_incomplete_labels: bool) -> Tuple[List[Optional[Path]], str]: + assert allow_incomplete_labels == allow_partial_ground_truth + raise RuntimeError(test_done_message) + + with pytest.raises(RuntimeError) as runtime_error: + with mock.patch("InnerEye.ML.lightning_base.convert_channels_to_file_paths") as convert_channels_to_file_paths_mock: + convert_channels_to_file_paths_mock.side_effect = mocked_convert_channels_to_file_paths + container.setup() + assert str(runtime_error.value) == test_done_message diff --git a/Tests/ML/test_model_testing.py b/Tests/ML/test_model_testing.py index ebda3d3b8..39ab580e5 100644 --- a/Tests/ML/test_model_testing.py +++ b/Tests/ML/test_model_testing.py @@ -2,14 +2,16 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------ + import numpy as np import pandas as pd import pytest from pytorch_lightning import seed_everything from InnerEye.Common import common_util -from InnerEye.Common.common_util import get_best_epoch_results_path +from InnerEye.Common.common_util import METRICS_AGGREGATES_FILE, SUBJECT_METRICS_FILE_NAME, get_best_epoch_results_path from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path +from InnerEye.Common.metrics_constants import MetricsFileColumns from InnerEye.Common.output_directories import OutputFolderForTests from InnerEye.ML import model_testing from InnerEye.ML.common import BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, DATASET_CSV_FILE_NAME, ModelExecutionMode @@ -24,89 +26,152 @@ from InnerEye.ML.visualizers.plot_cross_validation import get_config_and_results_for_offline_runs from Tests.ML.configs.ClassificationModelForTesting import ClassificationModelForTesting from Tests.ML.configs.DummyModel import DummyModel -from Tests.ML.util import assert_file_contains_string, assert_nifti_content, assert_text_files_match, \ - get_default_checkpoint_handler, get_image_shape +from Tests.ML.util import (assert_file_contains_string, assert_nifti_content, assert_text_files_match, + csv_column_contains_value, get_default_checkpoint_handler, get_image_shape) from Tests.ML.utils.test_model_util import create_model_and_store_checkpoint @pytest.mark.skipif(common_util.is_windows(), reason="Too slow on windows") -def test_model_test(test_output_dirs: OutputFolderForTests) -> None: +@pytest.mark.parametrize(["use_partial_ground_truth", "allow_partial_ground_truth"], [[True, True], [True, False], [False, False]]) +def test_model_test( + test_output_dirs: OutputFolderForTests, + use_partial_ground_truth: bool, + allow_partial_ground_truth: bool) -> None: + """ + Check the CSVs (and image files) output by InnerEye.ML.model_testing.segmentation_model_test + :param test_output_dirs: The fixture in conftest.py + :param use_partial_ground_truth: Whether to remove some ground truth labels from some test users + :param allow_partial_ground_truth: What to set the allow_incomplete_labels flag to + """ train_and_test_data_dir = full_ml_test_data_path("train_and_test_data") seed_everything(42) config = DummyModel() + 
+    config.allow_incomplete_labels = allow_partial_ground_truth
     config.set_output_to(test_output_dirs.root_dir)
     placeholder_dataset_id = "place_holder_dataset_id"
     config.azure_dataset_id = placeholder_dataset_id
     transform = config.get_full_image_sample_transforms().test
     df = pd.read_csv(full_ml_test_data_path(DATASET_CSV_FILE_NAME))
-    df = df[df.subject.isin([1, 2])]
-    # noinspection PyTypeHints
-    config._datasets_for_inference = \
-        {ModelExecutionMode.TEST: FullImageDataset(config, df, full_image_sample_transforms=transform)}  # type: ignore
+
+    if use_partial_ground_truth:
+        config.check_exclusive = False
+        config.ground_truth_ids = ["region", "region_1"]
+
+        # As in Tests.ML.pipelines.test.inference.test_evaluate_model_predictions, patients 3, 4,
+        # and 5 are in the test dataset:
+        # Patient 3 has one missing ground truth channel: "region"
+        df = df[df["subject"].ne(3) | df["channel"].ne("region")]
+        # Patient 4 has all missing ground truth channels: "region", "region_1"
+        df = df[df["subject"].ne(4) | df["channel"].ne("region")]
+        df = df[df["subject"].ne(4) | df["channel"].ne("region_1")]
+        # Patient 5 has no missing ground truth channels.
+
+        config.dataset_data_frame = df
+
+        df = df[df.subject.isin([3, 4, 5])]
+
+        config.train_subject_ids = ['1', '2']
+        config.test_subject_ids = ['3', '4', '5']
+        config.val_subject_ids = ['6', '7']
+    else:
+        df = df[df.subject.isin([1, 2])]
+
+    if use_partial_ground_truth and not allow_partial_ground_truth:
+        with pytest.raises(ValueError) as value_error:
+            # noinspection PyTypeHints
+            config._datasets_for_inference = {
+                ModelExecutionMode.TEST:
+                    FullImageDataset(
+                        config,
+                        df,
+                        full_image_sample_transforms=transform)}  # type: ignore
+        assert "Patient 3 does not have channel 'region'" in str(value_error.value)
+        return
+    else:
+        # noinspection PyTypeHints
+        config._datasets_for_inference = {
+            ModelExecutionMode.TEST:
+                FullImageDataset(
+                    config,
+                    df,
+                    full_image_sample_transforms=transform)}  # type: ignore
 
     execution_mode = ModelExecutionMode.TEST
-    checkpoint_handler = get_default_checkpoint_handler(model_config=config,
-                                                        project_root=test_output_dirs.root_dir)
+    checkpoint_handler = get_default_checkpoint_handler(model_config=config, project_root=test_output_dirs.root_dir)
     # Mimic the behaviour that checkpoints are downloaded from blob storage into the checkpoints folder.
create_model_and_store_checkpoint(config, config.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX) checkpoint_handler.additional_training_done() inference_results = model_testing.segmentation_model_test(config, - data_split=execution_mode, + execution_mode=execution_mode, checkpoint_handler=checkpoint_handler) epoch_dir = config.outputs_folder / get_best_epoch_results_path(execution_mode) - assert inference_results.metrics == pytest.approx(0.66606902, abs=1e-6) - - assert config.outputs_folder.is_dir() - assert epoch_dir.is_dir() - patient1 = io_util.load_nifti_image(train_and_test_data_dir / "id1_channel1.nii.gz") - patient2 = io_util.load_nifti_image(train_and_test_data_dir / "id2_channel1.nii.gz") - - assert_file_contains_string(epoch_dir / DATASET_ID_FILE, placeholder_dataset_id) - assert_file_contains_string(epoch_dir / GROUND_TRUTH_IDS_FILE, "region") - assert_text_files_match(epoch_dir / model_testing.SUBJECT_METRICS_FILE_NAME, - train_and_test_data_dir / model_testing.SUBJECT_METRICS_FILE_NAME) - assert_text_files_match(epoch_dir / model_testing.METRICS_AGGREGATES_FILE, - train_and_test_data_dir / model_testing.METRICS_AGGREGATES_FILE) - # Plotting results vary between platforms. Can only check if the file is generated, but not its contents. - assert (epoch_dir / model_testing.BOXPLOT_FILE).exists() - - assert_nifti_content(epoch_dir / "001" / "posterior_region.nii.gz", get_image_shape(patient1), - patient1.header, - [137], np.ubyte) - assert_nifti_content(epoch_dir / "002" / "posterior_region.nii.gz", get_image_shape(patient2), - patient2.header, - [137], np.ubyte) - assert_nifti_content(epoch_dir / "001" / DEFAULT_RESULT_IMAGE_NAME, get_image_shape(patient1), - patient1.header, - [1], np.ubyte) - assert_nifti_content(epoch_dir / "002" / DEFAULT_RESULT_IMAGE_NAME, get_image_shape(patient2), - patient2.header, - [1], np.ubyte) - assert_nifti_content(epoch_dir / "001" / "posterior_background.nii.gz", get_image_shape(patient1), - patient1.header, - [117], np.ubyte) - assert_nifti_content(epoch_dir / "002" / "posterior_background.nii.gz", get_image_shape(patient2), - patient2.header, - [117], np.ubyte) - thumbnails_folder = epoch_dir / model_testing.THUMBNAILS_FOLDER - assert thumbnails_folder.is_dir() - png_files = list(thumbnails_folder.glob("*.png")) - overlays = [f for f in png_files if "_region_slice_" in str(f)] - assert len(overlays) == len(df.subject.unique()), "There should be one overlay/contour file per subject" - - # Writing dataset.csv normally happens at the beginning of training, - # but this test reads off a saved checkpoint file. - # Dataset.csv must be present for plot_cross_validation. 
- config.write_dataset_files() - # Test if the metrics files can be picked up correctly by the cross validation code - config_and_files = get_config_and_results_for_offline_runs(config) - result_files = config_and_files.files - assert len(result_files) == 1 - for file in result_files: - assert file.execution_mode == execution_mode - assert file.dataset_csv_file is not None - assert file.dataset_csv_file.exists() - assert file.metrics_file is not None - assert file.metrics_file.exists() + total_num_patients_column_name = f"total_{MetricsFileColumns.Patient.value}".lower() + if not total_num_patients_column_name.endswith("s"): + total_num_patients_column_name += "s" + + if use_partial_ground_truth: + num_subjects = len(pd.unique(df["subject"])) + if allow_partial_ground_truth: + assert csv_column_contains_value( + csv_file_path=epoch_dir / METRICS_AGGREGATES_FILE, + column_name=total_num_patients_column_name, + value=num_subjects, + contains_only_value=True) + assert csv_column_contains_value( + csv_file_path=epoch_dir / SUBJECT_METRICS_FILE_NAME, + column_name=MetricsFileColumns.Dice.value, + value='', + contains_only_value=False) + else: + aggregates_df = pd.read_csv(epoch_dir / METRICS_AGGREGATES_FILE) + assert total_num_patients_column_name not in aggregates_df.columns # Only added if using partial ground truth + + assert not csv_column_contains_value( + csv_file_path=epoch_dir / SUBJECT_METRICS_FILE_NAME, + column_name=MetricsFileColumns.Dice.value, + value='', + contains_only_value=False) + + assert inference_results.metrics == pytest.approx(0.66606902, abs=1e-6) + assert config.outputs_folder.is_dir() + assert epoch_dir.is_dir() + patient1 = io_util.load_nifti_image(train_and_test_data_dir / "id1_channel1.nii.gz") + patient2 = io_util.load_nifti_image(train_and_test_data_dir / "id2_channel1.nii.gz") + + assert_file_contains_string(epoch_dir / DATASET_ID_FILE, placeholder_dataset_id) + assert_file_contains_string(epoch_dir / GROUND_TRUTH_IDS_FILE, "region") + assert_text_files_match(epoch_dir / model_testing.SUBJECT_METRICS_FILE_NAME, + train_and_test_data_dir / model_testing.SUBJECT_METRICS_FILE_NAME) + assert_text_files_match(epoch_dir / model_testing.METRICS_AGGREGATES_FILE, + train_and_test_data_dir / model_testing.METRICS_AGGREGATES_FILE) + # Plotting results vary between platforms. Can only check if the file is generated, but not its contents. 
+    assert (epoch_dir / model_testing.BOXPLOT_FILE).exists()
+
+    assert_nifti_content(epoch_dir / "001" / "posterior_region.nii.gz", get_image_shape(patient1), patient1.header, [137], np.ubyte)
+    assert_nifti_content(epoch_dir / "002" / "posterior_region.nii.gz", get_image_shape(patient2), patient2.header, [137], np.ubyte)
+    assert_nifti_content(epoch_dir / "001" / DEFAULT_RESULT_IMAGE_NAME, get_image_shape(patient1), patient1.header, [1], np.ubyte)
+    assert_nifti_content(epoch_dir / "002" / DEFAULT_RESULT_IMAGE_NAME, get_image_shape(patient2), patient2.header, [1], np.ubyte)
+    assert_nifti_content(epoch_dir / "001" / "posterior_background.nii.gz", get_image_shape(patient1), patient1.header, [117], np.ubyte)
+    assert_nifti_content(epoch_dir / "002" / "posterior_background.nii.gz", get_image_shape(patient2), patient2.header, [117], np.ubyte)
+    thumbnails_folder = epoch_dir / model_testing.THUMBNAILS_FOLDER
+    assert thumbnails_folder.is_dir()
+    png_files = list(thumbnails_folder.glob("*.png"))
+    overlays = [f for f in png_files if "_region_slice_" in str(f)]
+    assert len(overlays) == len(df.subject.unique()), "There should be one overlay/contour file per subject"
+
+    # Writing dataset.csv normally happens at the beginning of training,
+    # but this test reads off a saved checkpoint file.
+    # Dataset.csv must be present for plot_cross_validation.
+    config.write_dataset_files()
+    # Test if the metrics files can be picked up correctly by the cross validation code
+    config_and_files = get_config_and_results_for_offline_runs(config)
+    result_files = config_and_files.files
+    assert len(result_files) == 1
+    for file in result_files:
+        assert file.execution_mode == execution_mode
+        assert file.dataset_csv_file is not None
+        assert file.dataset_csv_file.exists()
+        assert file.metrics_file is not None
+        assert file.metrics_file.exists()
 
 
 @pytest.mark.parametrize("config", [DummyModel(), ClassificationModelForTesting()])
diff --git a/Tests/ML/util.py b/Tests/ML/util.py
index 34b21d8b2..38397ad96 100644
--- a/Tests/ML/util.py
+++ b/Tests/ML/util.py
@@ -7,6 +7,7 @@
 from typing import Any, List, Optional, Tuple, Union
 
 import numpy as np
+import pandas as pd
 import pytest
 import torch
 from PIL import Image
@@ -73,7 +74,8 @@ def load_train_and_test_data_channels(patient_ids: List[int],
                                metadata=PatientMetadata(patient_id=z),
                                image_channels=[file_name(z, c) for c in TEST_CHANNEL_IDS],
                                mask_channel=file_name(z, TEST_MASK_ID),
-                               ground_truth_channels=[file_name(z, TEST_GT_ID)]
+                               ground_truth_channels=[file_name(z, TEST_GT_ID)],
+                               allow_incomplete_labels=False
                                ))
 
     samples = []
@@ -98,7 +100,7 @@ def assert_file_contains_string(full_file: Union[str, Path], expected: Any = None) -> None:
     file_path = full_file if isinstance(full_file, Path) else Path(full_file)
     assert_file_exists(file_path)
     if expected is not None:
-        _assert_line(file_path.read_text(), expected)
+        assert expected.strip() in file_path.read_text()
 
 
 def assert_text_files_match(full_file: Path, expected_file: Path) -> None:
@@ -188,6 +190,37 @@ def assert_binary_files_match(actual_file: Path, expected_file: Path) -> None:
     assert False, f"File contents does not match: len(actual)={len(actual)}, len(expected)={len(expected)}"
 
 
+def csv_column_contains_value(
+        csv_file_path: Path,
+        column_name: str,
+        value: Any,
+        contains_only_value: bool = True) -> bool:
+    """
+    Checks whether the given column in the CSV file contains the value (and, optionally, only that value).
+    :param csv_file_path: The path to the CSV
+    :param column_name: The name of the column in which we look for the value
+    :param value: The value to look for
+    :param contains_only_value: Check that this is the only value in the column (default True)
+    :returns: Boolean, whether the CSV column contains the value (and perhaps only the value)
+    """
+    result = True
+    if not csv_file_path.exists():
+        raise ValueError(f"The CSV at {csv_file_path} does not exist.")
+    df = pd.read_csv(csv_file_path)
+    if column_name not in df.columns:
+        raise ValueError(f"The column {column_name} is not in the CSV at {csv_file_path}, "
+                         f"which has columns {df.columns}.")
+    if value:
+        result = result and value in df[column_name].unique()
+    else:
+        # A falsy value (e.g. the empty string) means: look for missing entries (NaN) in the column.
+        result = result and df[column_name].isnull().any()
+    if contains_only_value:
+        if value:
+            result = result and df[column_name].nunique(dropna=True) == 1
+        else:
+            result = result and df[column_name].nunique(dropna=True) == 0
+    return result
+
+
 DummyPatientMetadata = PatientMetadata(patient_id='42')
diff --git a/Tests/ML/utils/test_io_util.py b/Tests/ML/utils/test_io_util.py
index 141c997b8..b494e5588 100644
--- a/Tests/ML/utils/test_io_util.py
+++ b/Tests/ML/utils/test_io_util.py
@@ -85,14 +85,17 @@ def test_load_images_from_dataset_source(
     # metadata, image and GT channels must be present. Mask is optional
     if None in [metadata, image_channel, ground_truth_channel]:
         with pytest.raises(Exception):
-            _test_load_images_from_channels(metadata, image_channel, ground_truth_channel, mask_channel, check_exclusive)
+            _test_load_images_from_channels(metadata, image_channel, ground_truth_channel, mask_channel,
+                                            check_exclusive)
     else:
         if check_exclusive:
             with pytest.raises(ValueError) as mutually_exclusive_labels_error:
-                _test_load_images_from_channels(metadata, image_channel, ground_truth_channel, mask_channel, check_exclusive)
+                _test_load_images_from_channels(metadata, image_channel, ground_truth_channel, mask_channel,
+                                                check_exclusive)
             assert 'not mutually exclusive' in str(mutually_exclusive_labels_error.value)
         else:
-            _test_load_images_from_channels(metadata, image_channel, ground_truth_channel, mask_channel, check_exclusive)
+            _test_load_images_from_channels(metadata, image_channel, ground_truth_channel, mask_channel,
+                                            check_exclusive)
 
 
 def _test_load_images_from_channels(
diff --git a/docs/sample_tasks.md b/docs/sample_tasks.md
index 665403086..92d6e567d 100644
--- a/docs/sample_tasks.md
+++ b/docs/sample_tasks.md
@@ -3,13 +3,14 @@
 This document contains two sample tasks for the classification and segmentation pipelines.
 The document will walk through the steps in [Training Steps](building_models.md), but with specific examples for each task.
 
-Before trying tp train these models, you should have followed steps to set up an [environment](environment.md) and [AzureML](setting_up_aml.md)
+Before trying to train these models, you should have followed the steps to set up an [environment](environment.md) and [AzureML](setting_up_aml.md).
 
 ## Sample classification task: Glaucoma Detection on OCT volumes
 
 This example is based on the paper [A feature agnostic approach for glaucoma detection in OCT volumes](https://arxiv.org/pdf/1807.04855v3.pdf).
 
 ### Downloading and preparing the dataset
+
 The dataset is available [here](https://zenodo.org/record/1481223#.Xs-ehzPiuM_) [[1]](#1).
 
 After downloading and extracting the zip file, run the [create_glaucoma_dataset_csv.py](https://github.com/microsoft/InnerEye-DeepLearning/blob/main/InnerEye/Scripts/create_glaucoma_dataset_csv.py)
@@ -26,7 +27,6 @@ description below).
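Before uploading the dataset, it can help to sanity-check the CSV that the script produces. The snippet below is a minimal sketch, not part of the walkthrough itself: the folder location and the `filePath` column name are illustrative assumptions, not guaranteed output of create_glaucoma_dataset_csv.py.

```python
from pathlib import Path

import pandas as pd

# Hypothetical locations: adjust to wherever the zip file was extracted and
# to the CSV file name that the script actually wrote.
dataset_folder = Path("datasets/glaucoma")
df = pd.read_csv(dataset_folder / "dataset.csv")

# Basic checks: non-empty, no duplicate rows, and every referenced image
# file exists on disk ("filePath" is an assumed column name).
assert not df.empty, "dataset CSV contains no rows"
assert not df.duplicated().any(), "dataset CSV contains duplicate rows"
missing = [p for p in df["filePath"] if not (dataset_folder / p).exists()]
assert not missing, f"Missing image files, first few: {missing[:5]}"
```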
 ### Creating the model configuration and starting training
 
-
 Next, you need to create a configuration file `InnerEye/ML/configs/MyGlaucoma.py`
 that extends the GlaucomaPublic class like this:
 ```python
@@ -75,6 +75,7 @@ into a folder in the `datasets` container, for example `my_lung_dataset`. This folder name will later be supplied in the
 `azure_dataset_id` field of the model configuration, see below.
 
 ### Creating the model configuration and starting training
+
 You can then create a new model configuration, based on the template
 [Lung.py](../InnerEye/ML/configs/segmentation/Lung.py). To do this, create a file
 `InnerEye/ML/configs/segmentation/MyLungModel.py`, where you create a subclass of the template Lung model, and
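A minimal sketch of such a subclass follows; it assumes, as the surrounding text suggests, that the `Lung` config accepts the dataset folder name via an `azure_dataset_id` constructor argument (`my_lung_dataset` is the example folder name from above):

```python
from InnerEye.ML.configs.segmentation.Lung import Lung


class MyLungModel(Lung):
    def __init__(self) -> None:
        # Point the configuration at the dataset folder that was uploaded
        # to the "datasets" container, as described above.
        super().__init__(azure_dataset_id="my_lung_dataset")
```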