diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2ee607ae1..7debf19a7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,6 +13,7 @@ created.
 ## Upcoming
 
 ### Added
+- ([#565](https://github.com/microsoft/InnerEye-DeepLearning/pull/565)) All `LightningContainer` models have two new command-line flags, `pl_limit_train_batches` and `pl_limit_val_batches`, to set the number of batches per epoch. Use this to speed up training (for example, when debugging).
 - ([#465](https://github.com/microsoft/InnerEye-DeepLearning/pull/465/)) Adding ability to run segmentation inference module on test data with partial ground truth files. (Also [522](https://github.com/microsoft/InnerEye-DeepLearning/pull/522).)
 - ([#502](https://github.com/microsoft/InnerEye-DeepLearning/pull/502)) More flags for fine control of when to run inference.
@@ -44,6 +45,7 @@ gets uploaded to AzureML, by skipping all test folders.
 - ([#554](https://github.com/microsoft/InnerEye-DeepLearning/pull/554)) Updated report in CovidModel. Set parameters in the config to run inference on both the validation and test sets by default.
 - ([#566](https://github.com/microsoft/InnerEye-DeepLearning/pull/566)) Update `hi-ml` dependency to `hi-ml-azure`.
+- ([#565](https://github.com/microsoft/InnerEye-DeepLearning/pull/565)) The semantics of the SSL parameter `ssl_training_batch_size` changed from "effective batch size" (across all GPUs) to "batch size per GPU".
 
 ### Fixed
 - ([#537](https://github.com/microsoft/InnerEye-DeepLearning/pull/537)) Print warning if inference is disabled but comparison requested.
@@ -69,6 +71,7 @@ in inference-only runs when using lightning containers.
 correctly in the SimCLR module
 - ([#558](https://github.com/microsoft/InnerEye-DeepLearning/pull/558)) Fix issue with the CovidModel config where model weights from a finetuning run were incompatible with the model architecture created for non-finetuning runs.
+- ([#565](https://github.com/microsoft/InnerEye-DeepLearning/pull/565)) Checkpoints from SSL training now contain both optimizers, hence restarts after low-priority preemption will correctly continue training of the linear head.
 
 ### Removed
diff --git a/InnerEye-DataQuality/InnerEyeDataQuality/deep_learning/self_supervised/simclr_module.py b/InnerEye-DataQuality/InnerEyeDataQuality/deep_learning/self_supervised/simclr_module.py
index 8f3998302..6752ccacc 100644
--- a/InnerEye-DataQuality/InnerEyeDataQuality/deep_learning/self_supervised/simclr_module.py
+++ b/InnerEye-DataQuality/InnerEyeDataQuality/deep_learning/self_supervised/simclr_module.py
@@ -2,7 +2,6 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
 # ------------------------------------------------------------------------------------------
-
 from typing import Any
 
 import torch
diff --git a/InnerEye/Common/Statistics/report_structure_extremes.py b/InnerEye/Common/Statistics/report_structure_extremes.py
index 10740d04b..7f8363612 100644
--- a/InnerEye/Common/Statistics/report_structure_extremes.py
+++ b/InnerEye/Common/Statistics/report_structure_extremes.py
@@ -140,7 +140,7 @@ def report_structure_extremes_for_subject(subj_dir: str, series_id: str) -> Iter
         yield line_for_structure(subject, series_prefix, base, data)
 
 
-def line_for_structure(subject: str, series_prefix: str, base: str, data: np.array) -> str:
+def line_for_structure(subject: str, series_prefix: str, base: str, data: np.ndarray) -> str:
     """
     :param subject: a subject, to include in the result
     :param series_prefix: first 8 characters (if any) of the series ID of the subject
@@ -169,7 +169,7 @@ def line_for_structure(subject: str, series_prefix: str, base: str, data: np.arr
     return line
 
 
-def extent_list(presence: np.array, max_value: int) -> Tuple[List[int], List[str]]:
+def extent_list(presence: np.ndarray, max_value: int) -> Tuple[List[int], List[str]]:
     """
     :param presence: a 1-D array of distinct integers in increasing order.
     :param max_value: any integer, not necessarily related to presence
@@ -186,7 +186,7 @@ def extent_list(presence: np.array, max_value: int) -> Tuple[List[int], List[str
     return result, missing_ranges
 
 
-def derive_missing_ranges(presence: np.array) -> List[str]:
+def derive_missing_ranges(presence: np.ndarray) -> List[str]:
     """
     :param presence: a 1-D array of distinct integers in increasing order.
     :return: a list of strings, each denoting a missing range of values within "presence".
diff --git a/InnerEye/Common/type_annotations.py b/InnerEye/Common/type_annotations.py
index 7c768437f..775161029 100644
--- a/InnerEye/Common/type_annotations.py
+++ b/InnerEye/Common/type_annotations.py
@@ -3,7 +3,7 @@
 # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
 # ------------------------------------------------------------------------------------------
 from pathlib import Path
-from typing import Dict, Iterable, Optional, Tuple, TypeVar, Union
+from typing import Dict, Iterable, List, Optional, Tuple, TypeVar, Union
 
 T = TypeVar('T')
 PathOrString = Union[Path, str]
@@ -15,3 +15,4 @@ TupleFloat9 = Tuple[float, float, float, float, float, float, float, float,
                     float]
 IntOrTuple3 = Union[int, TupleInt3, Iterable]
 DictStrFloat = Dict[str, float]
+DictStrFloatOrFloatList = Dict[str, Union[float, List[float]]]
diff --git a/InnerEye/ML/SSL/datamodules_and_datasets/datamodules.py b/InnerEye/ML/SSL/datamodules_and_datasets/datamodules.py
index 3641a0c98..4a023203b 100644
--- a/InnerEye/ML/SSL/datamodules_and_datasets/datamodules.py
+++ b/InnerEye/ML/SSL/datamodules_and_datasets/datamodules.py
@@ -136,6 +136,9 @@ def train_dataloader(self, *args: Any, **kwargs: Any) -> Dict[SSLDataModuleType,
         """
         The train dataloaders
         """
+        # This code may be superseded in current versions of PL. Using this dictionary syntax will effectively
+        # use a CombinedLoader(dataloaders, mode="max_size_cycle"), similar to what we need to do explicitly for
+        # the validation data loader.
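
To make the added comment concrete: below is a minimal sketch of the explicit equivalent, assuming PyTorch Lightning 1.x, where `CombinedLoader` is importable from `pytorch_lightning.trainer.supporters`. The toy datasets and string keys are illustrative; the real code uses `SSLDataModuleType` members as keys.

```python
import torch
from pytorch_lightning.trainer.supporters import CombinedLoader
from torch.utils.data import DataLoader, TensorDataset

# Two loaders of different lengths, standing in for the encoder and linear head data.
encoder_loader = DataLoader(TensorDataset(torch.rand(8, 3)), batch_size=2)      # 4 batches
linear_head_loader = DataLoader(TensorDataset(torch.rand(4, 3)), batch_size=2)  # 2 batches

# Returning a dictionary of loaders from train_dataloader() is treated by PL like the
# explicit wrapper below: "max_size_cycle" cycles the shorter loader until the longest
# one is exhausted, so every training step sees one batch from each loader.
combined = CombinedLoader({"encoder": encoder_loader, "linear_head": linear_head_loader},
                          mode="max_size_cycle")
for batch in combined:
    print({name: item[0].shape for name, item in batch.items()})
```
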
         dataloaders = {
             SSLDataModuleType.ENCODER: self.encoder_module.train_dataloader(),
             SSLDataModuleType.LINEAR_HEAD: self.linear_head_module.train_dataloader()}
diff --git a/InnerEye/ML/SSL/lightning_containers/ssl_container.py b/InnerEye/ML/SSL/lightning_containers/ssl_container.py
index d3f934042..ee034f146 100644
--- a/InnerEye/ML/SSL/lightning_containers/ssl_container.py
+++ b/InnerEye/ML/SSL/lightning_containers/ssl_container.py
@@ -75,9 +75,9 @@ class SSLContainer(LightningContainer):
                                          "augmentations. Ignored for CIFAR10 example")
     ssl_training_dataset_name = param.ClassSelector(class_=SSLDatasetName, doc="The name of the dataset")
     ssl_training_batch_size = param.Integer(
-        doc="Total training batch size, will be divided across the number of gpus used for training. For example: if "
-            "you specify ssl_training_batch_size=1600 and use 4 nodes with 4 gpus each (i.e. total of 16 GPUs), "
-            "the code will provide a per-gpu batch size of 100")
+        doc="Training batch size per GPU. The effective batch size will be the number of GPUs times this number. "
+            "For example, if you specify ssl_training_batch_size=100 and use 4 nodes with 4 gpus each, "
+            "the effective batch size will be 1600.")
     ssl_training_type = param.ClassSelector(class_=SSLTrainingType, doc="Which algorithm to use for SSL training")
     ssl_encoder = param.ClassSelector(class_=EncoderName, doc="Which encoder to use for SSL")
     use_balanced_binary_loss_for_linear_head = param.Boolean(default=False,
@@ -100,6 +100,9 @@ class SSLContainer(LightningContainer):
 
     def setup(self) -> None:
         from InnerEye.ML.SSL.lightning_containers.ssl_image_classifier import SSLClassifierContainer
+        if self.is_debug_model:
+            self.pl_limit_train_batches = 1
+            self.pl_limit_val_batches = 1
         self.total_num_gpus = self.num_gpus_per_node * self.num_nodes
         self._load_config()
         # If you're using the same data for training and linear head, allow the user to specify the dataset only
@@ -169,6 +172,13 @@ def create_model(self) -> LightningModule:
                                      "num_classes": self.data_module.num_classes})
         self.encoder_output_dim = get_encoder_output_dim(model, self.data_module)
+        self.online_eval_callback = \
+            SSLOnlineEvaluatorInnerEye(class_weights=self.data_module.class_weights,  # type: ignore
+                                       z_dim=self.encoder_output_dim,
+                                       num_classes=self.data_module.num_classes,  # type: ignore
+                                       dataset=self.linear_head_dataset_name.value,  # type: ignore
+                                       drop_p=0.2,
+                                       learning_rate=self.learning_rate_linear_head_during_ssl_training)
         return model
 
     def get_data_module(self) -> InnerEyeDataModuleTypes:
@@ -199,16 +209,17 @@ def _create_ssl_data_modules(self, is_ssl_encoder_module: bool) -> InnerEyeVisio
         train_transforms, val_transforms = self._get_transforms(datamodule_args.augmentation_params,
                                                                 datamodule_args.dataset_name,
                                                                 is_ssl_encoder_module)
-        batch_size_per_gpu = datamodule_args.batch_size // self.total_num_gpus if self.total_num_gpus > 0 else \
-            datamodule_args.batch_size
-        logging.info(f"Batch size per gpu: {batch_size_per_gpu}")
+        batch_multiplier = self.total_num_gpus if self.total_num_gpus > 0 else 1
+        effective_batch_size = datamodule_args.batch_size * batch_multiplier
+        logging.info(f"Batch size per GPU: {datamodule_args.batch_size}")
+        logging.info(f"Effective batch size on {batch_multiplier} GPUs: {effective_batch_size}")
         dm = InnerEyeVisionDataModule(dataset_cls=self._SSLDataClassMappings[datamodule_args.dataset_name],
                                       return_index=not is_ssl_encoder_module,  # index is only needed for linear head
                                       train_transforms=train_transforms,
                                       val_split=0.1,
                                       val_transforms=val_transforms,
                                      data_dir=str(datamodule_args.dataset_path),
-                                      batch_size=batch_size_per_gpu,
+                                      batch_size=datamodule_args.batch_size,
                                      num_workers=self.num_workers,
                                      seed=self.random_seed,
                                      drop_last=self.drop_last)
@@ -226,9 +237,9 @@ def _get_transforms(self, augmentation_config: Optional[CfgNode],
         :param dataset_name: name of the dataset, value has to be in SSLDatasetName, determines which transformation
             pipeline to return.
         :param is_ssl_encoder_module: if True the transformation pipeline will yield two versions of the image it is
-        applied on and it applies the training transformations also at validation time. Note that if your transformation
-        does not contain any randomness, the pipeline will return two identical copies. If False, it will return only one
-        transformation.
+        applied on and it applies the training transformations also at validation time. Note that if your transformation
+        does not contain any randomness, the pipeline will return two identical copies. If False, it will return only
+        one transformation.
         :return: training transformation pipeline and validation transformation pipeline.
         """
         if dataset_name in [SSLDatasetName.RSNAKaggleCXR.value,
@@ -262,13 +273,5 @@ def _get_transforms(self, augmentation_config: Optional[CfgNode],
         return train_transforms, val_transforms
 
     def get_trainer_arguments(self) -> Dict[str, Any]:
-        self.online_eval = SSLOnlineEvaluatorInnerEye(class_weights=self.data_module.class_weights,  # type: ignore
-                                                      z_dim=self.encoder_output_dim,
-                                                      num_classes=self.data_module.num_classes,  # type: ignore
-                                                      dataset=self.linear_head_dataset_name.value,  # type: ignore
-                                                      drop_p=0.2,
-                                                      learning_rate=self.learning_rate_linear_head_during_ssl_training)
-        trainer_kwargs: Dict[str, Any] = {"callbacks": self.online_eval}
-        if self.is_debug_model:
-            trainer_kwargs.update({"limit_train_batches": 1, "limit_val_batches": 1})
+        trainer_kwargs: Dict[str, Any] = {"callbacks": self.online_eval_callback}
         return trainer_kwargs
diff --git a/InnerEye/ML/SSL/lightning_containers/ssl_image_classifier.py b/InnerEye/ML/SSL/lightning_containers/ssl_image_classifier.py
index 76c8de85f..e890b75ca 100644
--- a/InnerEye/ML/SSL/lightning_containers/ssl_image_classifier.py
+++ b/InnerEye/ML/SSL/lightning_containers/ssl_image_classifier.py
@@ -64,7 +64,4 @@ def get_data_module(self) -> InnerEyeDataModuleTypes:
         return self.data_module
 
     def get_trainer_arguments(self) -> Dict[str, Any]:
-        trained_kwargs = {}
-        if self.is_debug_model:
-            trained_kwargs.update({"limit_train_batches": 1, "limit_val_batches": 1})
-        return trained_kwargs
+        return {}
diff --git a/InnerEye/ML/SSL/lightning_modules/byol/byol_module.py b/InnerEye/ML/SSL/lightning_modules/byol/byol_module.py
index 603f45063..f897d68e7 100644
--- a/InnerEye/ML/SSL/lightning_modules/byol/byol_module.py
+++ b/InnerEye/ML/SSL/lightning_modules/byol/byol_module.py
@@ -10,13 +10,14 @@
 import torch
 import torch.nn.functional as F
 from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
+from pytorch_lightning import Trainer
 from torch import Tensor as T
 from torch.optim import Adam
 
 from InnerEye.ML.SSL.lightning_modules.byol.byol_models import SiameseArm
 from InnerEye.ML.SSL.lightning_modules.byol.byol_moving_average import ByolMovingAverageWeightUpdate
 from InnerEye.ML.SSL.utils import SSLDataModuleType
-from pytorch_lightning import Trainer
+from InnerEye.ML.lightning_loggers import log_learning_rate, log_on_epoch
 
 SingleBatchType = Tuple[List, T]
 BatchType = Union[Dict[SSLDataModuleType, SingleBatchType], SingleBatchType]
@@ -98,14 +99,15 @@ def shared_step(self, batch: BatchType, batch_idx: int) -> T:
         return loss
 
-    def training_step(self, batch: BatchType, batch_idx: int, **kwargs: Any) -> T:  # type: ignore
+    def training_step(self, batch: BatchType, batch_idx: int, **kwargs: Any) -> torch.Tensor:  # type: ignore
         loss = self.shared_step(batch, batch_idx)
-        self.log_dict({'byol/train/loss': loss, 'byol/tau': self.weight_callback.current_tau})
+        log_on_epoch(self, metrics={'byol/train/loss': loss, 'byol/tau': self.weight_callback.current_tau})
+        log_learning_rate(self, name="byol/learning_rate")
         return loss
 
     def validation_step(self, batch: BatchType, batch_idx: int, **kwargs: Any) -> T:  # type: ignore
         loss = self.shared_step(batch, batch_idx)
-        self.log_dict({'byol/val/loss': loss})
+        log_on_epoch(self, 'byol/val/loss', loss)
         return loss
 
     def setup(self, *args: Any, **kwargs: Any) -> None:
@@ -116,9 +118,10 @@ def configure_optimizers(self) -> Any:
         # exclude certain parameters
         parameters = self.exclude_from_wt_decay(self.online_network.named_parameters(),
                                                 weight_decay=self.hparams.weight_decay)  # type: ignore
-        optimizer = Adam(parameters, lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)  # type: ignore
+        optimizer = Adam(parameters, lr=self.hparams.learning_rate,  # type: ignore
+                         weight_decay=self.hparams.weight_decay)  # type: ignore
         scheduler = LinearWarmupCosineAnnealingLR(
-            optimizer, warmup_epochs=self.hparams.warmup_epochs, max_epochs=self.hparams.max_epochs)  # type: ignore
+            optimizer, warmup_epochs=self.hparams.warmup_epochs, max_epochs=self.hparams.max_epochs)  # type: ignore
         return [optimizer], [scheduler]
 
     def exclude_from_wt_decay(self,
@@ -144,4 +147,3 @@ def exclude_from_wt_decay(self,
             {'params': params, 'weight_decay': weight_decay},
             {'params': excluded_params, 'weight_decay': 0.}
         ]
-
diff --git a/InnerEye/ML/SSL/lightning_modules/simclr_module.py b/InnerEye/ML/SSL/lightning_modules/simclr_module.py
index f53446a9b..c9f707a28 100644
--- a/InnerEye/ML/SSL/lightning_modules/simclr_module.py
+++ b/InnerEye/ML/SSL/lightning_modules/simclr_module.py
@@ -2,7 +2,6 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
 # ------------------------------------------------------------------------------------------
-
 from typing import Any, Dict, List, Tuple, Union
 
 import torch
@@ -13,6 +12,7 @@
 from InnerEye.ML.SSL.encoders import SSLEncoder
 from InnerEye.ML.SSL.utils import SSLDataModuleType
+from InnerEye.ML.lightning_loggers import log_learning_rate, log_on_epoch
 
 SingleBatchType = Tuple[List, T]
 BatchType = Union[Dict[SSLDataModuleType, SingleBatchType], SingleBatchType]
@@ -57,6 +57,17 @@ def __init__(self, encoder_name: str, dataset_name: str, use_7x7_first_conv_in_r
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.encoder(x)
 
+    def training_step(self, batch: BatchType, batch_idx: int) -> torch.Tensor:
+        loss = self.shared_step(batch)
+        log_on_epoch(self, "simclr/train/loss", loss, sync_dist=False)
+        log_learning_rate(self, name="simclr/learning_rate")
+        return loss
+
+    def validation_step(self, batch: BatchType, batch_idx: int) -> T:  # type: ignore
+        loss = self.shared_step(batch)
+        log_on_epoch(self, "simclr/val/loss", loss, sync_dist=False)
+        return loss
+
     def shared_step(self, batch: BatchType) -> T:
         batch = batch[SSLDataModuleType.ENCODER] if isinstance(batch, dict) else batch
@@ -72,6 +83,3 @@ def shared_step(self, batch: BatchType) -> T:
         loss = self.nt_xent_loss(z1, z2, self.temperature)
 
         return loss
-
-
-
diff --git a/InnerEye/ML/SSL/lightning_modules/ssl_classifier_module.py b/InnerEye/ML/SSL/lightning_modules/ssl_classifier_module.py
index 805366f22..95c89e59a 100644
--- a/InnerEye/ML/SSL/lightning_modules/ssl_classifier_module.py
+++ b/InnerEye/ML/SSL/lightning_modules/ssl_classifier_module.py
@@ -5,11 +5,12 @@
 from typing import Any, List, Optional
 
 import torch
-from torchmetrics import Metric
 from pl_bolts.models.self_supervised import SSLEvaluator
 from torch.nn import functional as F
+from torchmetrics import Metric
 
 from InnerEye.ML.SSL.encoders import get_encoder_output_dim
+from InnerEye.ML.lightning_loggers import log_on_epoch
 from InnerEye.ML.dataset.scalar_sample import ScalarItem
 from InnerEye.ML.lightning_container import LightningModuleWithOptimizer
 from InnerEye.ML.lightning_metrics import Accuracy05, AreaUnderPrecisionRecallCurve, AreaUnderRocCurve
@@ -79,16 +80,16 @@ def shared_step(self, batch: Any, is_training: bool) -> Any:
 
     def training_step(self, batch: Any, batch_id: int, *args: Any, **kwargs: Any) -> Any:  # type: ignore
         loss = self.shared_step(batch, True)
-        self.log("train/loss", loss, on_step=False, on_epoch=True)
+        log_on_epoch(self, "train/loss", loss)
         for metric in self.train_metrics:
-            self.log(f"train/{metric.name}", metric, on_epoch=True, on_step=False)
+            log_on_epoch(self, f"train/{metric.name}", metric)
         return loss
 
     def validation_step(self, batch: Any, batch_id: int, *args: Any, **kwargs: Any) -> None:  # type: ignore
         loss = self.shared_step(batch, is_training=False)
-        self.log('val/loss', loss, on_step=False, on_epoch=True, sync_dist=True)
+        log_on_epoch(self, 'val/loss', loss)
         for metric in self.val_metrics:
-            self.log(f"val/{metric.name}", metric, on_epoch=True, on_step=False)
+            log_on_epoch(self, f"val/{metric.name}", metric)
 
     def get_input_tensors(self, item: ScalarItem) -> List[torch.Tensor]:
         """
diff --git a/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py b/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py
index 3eddaff53..75229de27 100644
--- a/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py
+++ b/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py
@@ -9,15 +9,19 @@
 import torch
 from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
 from pl_bolts.models.self_supervised.evaluator import SSLEvaluator
-from torchmetrics import Metric
 from torch import Tensor as T
 from torch.nn import functional as F
+from torchmetrics import Metric
 
 from InnerEye.ML.SSL.utils import SSLDataModuleType
+from InnerEye.ML.lightning_loggers import log_on_epoch
 from InnerEye.ML.lightning_metrics import Accuracy05, AreaUnderPrecisionRecallCurve, AreaUnderRocCurve
 
 BatchType = Union[Dict[SSLDataModuleType, Any], Any]
 
+OPTIMIZER_STATE_NAME = "evaluator_optimizer"
+EVALUATOR_STATE_NAME = "evaluator_weights"
+
 
 class SSLOnlineEvaluatorInnerEye(SSLOnlineEvaluator):
     def __init__(self,
@@ -29,7 +33,6 @@ def __init__(self,
 
         :param class_weights: The class weights to use when computing the cross entropy loss. If set to None,
                               no weighting will be done.
-        :param length_linear_head_loader: The maximum number of batches in the dataloader for the linear head.
         """
 
         super().__init__(**kwargs)
@@ -43,6 +46,28 @@ def __init__(self,
                              Accuracy05()] \
             if self.num_classes == 2 else [Accuracy05()]
         self.class_weights = class_weights
+        self.non_linear_evaluator = SSLEvaluator(n_input=self.z_dim,
+                                                 n_classes=self.num_classes,
+                                                 p=self.drop_p,
+                                                 n_hidden=self.hidden_dim)
+        self.optimizer = torch.optim.Adam(self.non_linear_evaluator.parameters(),
+                                          lr=self.learning_rate,
+                                          weight_decay=self.weight_decay)
+
+    def on_save_checkpoint(self, trainer: pl.Trainer, pl_module: pl.LightningModule,
+                           checkpoint: Dict[str, Any]) -> Dict[str, Any]:
+        # Each callback gets its own state dictionary, which is fed back in during load
+        return {
+            OPTIMIZER_STATE_NAME: self.optimizer.state_dict(),
+            EVALUATOR_STATE_NAME: self.non_linear_evaluator.state_dict()
+        }
+
+    def on_load_checkpoint(self,
+                           trainer: pl.Trainer,
+                           pl_module: pl.LightningModule,
+                           callback_state: Dict[str, Any]) -> None:
+        self.optimizer.load_state_dict(callback_state[OPTIMIZER_STATE_NAME])
+        self.non_linear_evaluator.load_state_dict(callback_state[EVALUATOR_STATE_NAME])
 
     def on_pretrain_routine_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
         """
@@ -50,15 +75,7 @@ def on_pretrain_routine_start(self, trainer: pl.Trainer, pl_module: pl.Lightning
         """
         for metric in [*self.train_metrics, *self.val_metrics]:
             metric.to(device=pl_module.device)  # type: ignore
-
-        pl_module.non_linear_evaluator = SSLEvaluator(n_input=self.z_dim,
-                                                      n_classes=self.num_classes,
-                                                      p=self.drop_p,
-                                                      n_hidden=self.hidden_dim).to(pl_module.device)
-        assert isinstance(pl_module.non_linear_evaluator, torch.nn.Module)
-        self.optimizer = torch.optim.Adam(pl_module.non_linear_evaluator.parameters(),
-                                          lr=self.learning_rate,
-                                          weight_decay=self.weight_decay)
+        self.non_linear_evaluator = self.non_linear_evaluator.to(pl_module.device)
 
     @staticmethod
     def to_device(batch: Any, device: Union[str, torch.device]) -> Tuple[T, T]:
@@ -86,10 +103,9 @@ def shared_step(self, batch: BatchType, pl_module: pl.LightningModule, is_traini
         with torch.no_grad():
             representations = self.get_representations(pl_module, x)
         representations = representations.detach()
-        assert isinstance(pl_module.non_linear_evaluator, torch.nn.Module)
 
         # Run the linear-head with SSL embeddings.
-        mlp_preds = pl_module.non_linear_evaluator(representations)
+        mlp_preds = self.non_linear_evaluator(representations)
 
         weights = None if self.class_weights is None else self.class_weights.to(device=pl_module.device)
         mlp_loss = F.cross_entropy(mlp_preds, y, weight=weights)
@@ -112,16 +128,20 @@ def on_validation_batch_end(self, trainer: pl.Trainer,
         ids_linear_head = tuple(batch[SSLDataModuleType.LINEAR_HEAD][0].tolist())
         if ids_linear_head not in self.visited_ids:
             self.visited_ids.add(ids_linear_head)
+            old_mode = self.non_linear_evaluator.training
+            self.non_linear_evaluator.eval()
             loss = self.shared_step(batch, pl_module, is_training=False)
-            pl_module.log('ssl_online_evaluator/val/loss', loss, on_step=False, on_epoch=True, sync_dist=False)
+            log_on_epoch(pl_module, 'ssl_online_evaluator/val/loss', loss, sync_dist=False)
             for metric in self.val_metrics:
-                pl_module.log(f"ssl_online_evaluator/val/{metric.name}", metric, on_epoch=True,
-                              on_step=False)  # type: ignore
+                log_on_epoch(pl_module, f"ssl_online_evaluator/val/{metric.name}", metric, sync_dist=False)
+            self.non_linear_evaluator.train(old_mode)
 
     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) -> None:  # type: ignore
         """
         Get and log training metrics, perform network update.
         """
+        # Similar code should also live in the encoder training.
+        # There is a silent assumption here that the SSL dataset is larger than the linear head dataset.
         ids_linear_head = tuple(batch[SSLDataModuleType.LINEAR_HEAD][0].tolist())
         if ids_linear_head not in self.visited_ids:
             self.visited_ids.add(ids_linear_head)
@@ -131,7 +151,6 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data
             self.optimizer.step()
 
         # log metrics
-        pl_module.log('ssl_online_evaluator/train/loss', loss)
+        log_on_epoch(pl_module, 'ssl_online_evaluator/train/loss', loss, sync_dist=False)
         for metric in self.train_metrics:
-            pl_module.log(f"ssl_online_evaluator/train/online_{metric.name}", metric, on_epoch=True,
-                          on_step=False)  # type: ignore
+            log_on_epoch(pl_module, f"ssl_online_evaluator/train/online_{metric.name}", metric, sync_dist=False)
diff --git a/InnerEye/ML/SSL/utils.py b/InnerEye/ML/SSL/utils.py
index 8cc757abb..ec779cae0 100644
--- a/InnerEye/ML/SSL/utils.py
+++ b/InnerEye/ML/SSL/utils.py
@@ -9,6 +9,7 @@
 from typing import Any, Optional
 
 import torch
+from pytorch_lightning import LightningModule
 from yacs.config import CfgNode
 
 from InnerEye.ML.SSL import ssl_augmentation_config
@@ -81,14 +82,10 @@ def create_ssl_image_classifier(num_classes: int,
     logging.info(f"Loading pretrained {ssl_type} weights from:\n {pl_checkpoint_path}")
 
     if ssl_type == SSLTrainingType.BYOL.value or ssl_type == SSLTrainingType.BYOL:
-        # Here we need to indicate how many classes where used for linear evaluator at training time, to load the
-        # checkpoint (incl. linear evaluator) with strict = True
-        byol_module = SSLModelLoader(BYOLInnerEye, loaded_params["num_classes"]).load_from_checkpoint(
-            pl_checkpoint_path)
+        byol_module = BYOLInnerEye.load_from_checkpoint(pl_checkpoint_path)
         encoder = byol_module.target_network.encoder
     elif ssl_type == SSLTrainingType.SimCLR.value or ssl_type == SSLTrainingType.SimCLR:
-        simclr_module = SSLModelLoader(SimCLRInnerEye, loaded_params["num_classes"]).load_from_checkpoint(
-            pl_checkpoint_path)
+        simclr_module = SimCLRInnerEye.load_from_checkpoint(pl_checkpoint_path)
         encoder = simclr_module.encoder
     else:
         raise NotImplementedError(f"Unknown unsupervised model: {ssl_type}")
@@ -101,25 +98,57 @@ def create_ssl_image_classifier(num_classes: int,
     return model
 
 
-def SSLModelLoader(ssl_class: Any, num_classes: int) -> Any:
+def get_from_list_or_singleton(items: Any, index: int, fail_if_out_of_range: bool = True) -> Any:
     """
-    This class is a helper class for SSL model loading from checkpoints with strict=True.
-    We cannot simply load the class directly via do BYOLInnerEye().load_from_checkpoint("ckpt") with strict loading
-    because the checkpoint will contain the weights of the linear evaluator, but this one is defined outside of the
-    BYOLInnerEye class (as it is defined as a callback), hence we can only load the checkpoint if we manually re-add
-    the linear evaluator prior to loading.
-
-    :param ssl_class: SSL object either BYOL or SimCLR.
-    :param num_classes: Number of target classes for the linear head.
+    Get the item with the given index from a list. If `items` is not a list, it is treated as a single-item list,
+    so that item can be retrieved with index 0. This is due to PL's handling of optimizers: self.optimizers() is a
+    single Optimizer object if only one is used, but a list if multiple are provided.
+
+    :param fail_if_out_of_range: If True, raise an IndexError if the given index is outside the bounds of the list.
+        If False, return None if the index is outside the bounds.
+    :param index: The index of the item to retrieve.
+    :param items: A list of items, or a single item.
     """
-    from pl_bolts.models.self_supervised import SSLEvaluator
-    from InnerEye.ML.SSL.encoders import get_encoder_output_dim
+    if not isinstance(items, list):
+        items = [items]
+    if index < len(items):
+        return items[index]
+    if fail_if_out_of_range:
+        raise IndexError(f"Requested index {index}, but there are only {len(items)} items available.")
+    else:
+        return None
+
 
-    class _wrap(ssl_class):  # type: ignore
-        def __init__(self, **kwargs: Any) -> None:
-            super().__init__(**kwargs)
-            self.non_linear_evaluator = SSLEvaluator(n_input=get_encoder_output_dim(self),
-                                                     n_classes=num_classes,
-                                                     n_hidden=None)
+def manual_optimization_step(pl: LightningModule, loss: torch.Tensor, optimizer_idx: int = 0) -> None:
+    """
+    Execute a manual optimization step in the given PL module, with the provided loss value. This will ONLY update
+    the optimizer with the given index. The corresponding learning rate scheduler will be updated too, whether it
+    is configured to update at step or at epoch level.
 
-    return _wrap
+    :param pl: The module on which the optimization step should be run.
+    :param loss: The loss tensor.
+    :param optimizer_idx: The index of the optimizer where the optimization step should be taken.
+    """
+    optimizer = get_from_list_or_singleton(pl.optimizers(), optimizer_idx)
+    optimizer.zero_grad()
+    pl.manual_backward(loss)
+    optimizer.step()
+    assert pl.trainer is not None, "No trainer has been set for this module yet?"
+    # Read out the full information about the LR scheduler from the trainer object - at module level from
+    # pl.lr_schedulers() we don't see the update frequency
+    scheduler_dict = get_from_list_or_singleton(pl.trainer.lr_schedulers, optimizer_idx, fail_if_out_of_range=False)
+    # If there is no scheduler, just skip. This should account for cases where the second optimizer has a
+    # fixed LR and no scheduler.
+    if scheduler_dict is None:
+        return
+    if scheduler_dict["frequency"] != 1:
+        raise NotImplementedError(f"Updates every {scheduler_dict['frequency']} steps/epochs is not implemented.")
+    interval = scheduler_dict["interval"]
+    scheduler = scheduler_dict["scheduler"]
+    if interval == "step":
+        scheduler.step()
+    elif interval == "epoch":
+        if pl.trainer.is_last_batch:
+            scheduler.step()
+    else:
+        raise ValueError(f"Update interval not recognized: {interval}")
diff --git a/InnerEye/ML/configs/ssl/CIFAR_SSL_configs.py b/InnerEye/ML/configs/ssl/CIFAR_SSL_configs.py
index aa2eeb741..8ca2cb58c 100644
--- a/InnerEye/ML/configs/ssl/CIFAR_SSL_configs.py
+++ b/InnerEye/ML/configs/ssl/CIFAR_SSL_configs.py
@@ -15,7 +15,7 @@ class CIFAR10SimCLR(SSLContainer):
     def __init__(self) -> None:
         super().__init__(ssl_training_dataset_name=SSLDatasetName.CIFAR10,
                          linear_head_dataset_name=SSLDatasetName.CIFAR10,
-                         ssl_training_batch_size=512,
+                         ssl_training_batch_size=128,
                          ssl_encoder=EncoderName.resnet50,
                          ssl_training_type=SSLTrainingType.SimCLR,
                          random_seed=1,
@@ -32,7 +32,7 @@ class CIFAR10BYOL(SSLContainer):
     def __init__(self) -> None:
         super().__init__(ssl_training_dataset_name=SSLDatasetName.CIFAR10,
                          linear_head_dataset_name=SSLDatasetName.CIFAR10,
-                         ssl_training_batch_size=512,
+                         ssl_training_batch_size=128,
                          ssl_encoder=EncoderName.resnet50,
                          ssl_training_type=SSLTrainingType.BYOL,
                          random_seed=1,
@@ -49,7 +49,7 @@ class CIFAR10CIFAR100BYOL(SSLContainer):
     def __init__(self) -> None:
         super().__init__(ssl_training_dataset_name=SSLDatasetName.CIFAR10,
                          linear_head_dataset_name=SSLDatasetName.CIFAR100,
-                         ssl_training_batch_size=512,
+                         ssl_training_batch_size=64,
                          ssl_encoder=EncoderName.resnet50,
                          ssl_training_type=SSLTrainingType.BYOL,
                          random_seed=1,
diff --git a/InnerEye/ML/configs/ssl/CXR_SSL_configs.py b/InnerEye/ML/configs/ssl/CXR_SSL_configs.py
index 4a47f1905..aea7ba4ad 100644
--- a/InnerEye/ML/configs/ssl/CXR_SSL_configs.py
+++ b/InnerEye/ML/configs/ssl/CXR_SSL_configs.py
@@ -29,13 +29,14 @@ def __init__(self) -> None:
                          random_seed=1,
                          recovery_checkpoint_save_interval=200,
                          num_epochs=1000,
-                         ssl_training_batch_size=1200,
+                         ssl_training_batch_size=75,
                          ssl_encoder=EncoderName.resnet50,
                          ssl_training_type=SSLTrainingType.BYOL,
                          use_balanced_binary_loss_for_linear_head=True,
                          ssl_augmentation_config=path_encoder_augmentation_cxr,
                          extra_azure_dataset_ids=[RSNA_AZURE_DATASET_ID],
                          linear_head_augmentation_config=path_linear_head_augmentation_cxr)
+        self.pl_find_unused_parameters = True
 
 
 class NIH_RSNA_SimCLR(SSLContainer):
     def __init__(self) -> None:
@@ -45,13 +46,14 @@ def __init__(self) -> None:
                          random_seed=1,
                          recovery_checkpoint_save_interval=200,
                          num_epochs=1000,
-                         ssl_training_batch_size=1200,
+                         ssl_training_batch_size=75,
                          ssl_encoder=EncoderName.resnet50,
                          ssl_training_type=SSLTrainingType.SimCLR,
                          use_balanced_binary_loss_for_linear_head=True,
                          ssl_augmentation_config=path_encoder_augmentation_cxr,
                          extra_azure_dataset_ids=[RSNA_AZURE_DATASET_ID],
                          linear_head_augmentation_config=path_linear_head_augmentation_cxr)
+        self.pl_find_unused_parameters = True
 
 
 class CXRImageClassifier(SSLClassifierContainer):
diff --git a/InnerEye/ML/configs/ssl/CovidContainers.py b/InnerEye/ML/configs/ssl/CovidContainers.py
index 2e79d35b8..f091876d5 100644
--- a/InnerEye/ML/configs/ssl/CovidContainers.py
+++ b/InnerEye/ML/configs/ssl/CovidContainers.py
@@ -23,7 +23,7 @@ def __init__(self,
                          recovery_checkpoint_save_interval=50,
                          recovery_checkpoints_save_last_k=3,
                          num_epochs=500,
-                         ssl_training_batch_size=1200,  # This runs with 16 gpus (4 nodes)
+                         ssl_training_batch_size=75,  # This runs with 16 gpus (4 nodes)
                          num_workers=12,
                          ssl_encoder=EncoderName.densenet121,
                          ssl_training_type=SSLTrainingType.BYOL,
diff --git a/InnerEye/ML/dataset/scalar_dataset.py b/InnerEye/ML/dataset/scalar_dataset.py
index 7b7d1382a..977310604 100644
--- a/InnerEye/ML/dataset/scalar_dataset.py
+++ b/InnerEye/ML/dataset/scalar_dataset.py
@@ -102,7 +102,7 @@ def extract_label_classification(label_string: str, sample_id: str, num_classes:
             raise ValueError(f"Subject {sample_id}: Indices {out_of_range} are out of range, for number of classes "
                              f"= {num_classes}")
 
-        one_hot_array = np.zeros(num_classes, dtype=np.float)
+        one_hot_array = np.zeros(num_classes, dtype=np.float)  # type: ignore
         one_hot_array[classes] = 1.0
         return one_hot_array.tolist()
     else:
diff --git a/InnerEye/ML/deep_learning_config.py b/InnerEye/ML/deep_learning_config.py
index 4a8161045..f86376ec2 100644
--- a/InnerEye/ML/deep_learning_config.py
+++ b/InnerEye/ML/deep_learning_config.py
@@ -590,6 +590,14 @@ class TrainerParams(param.Parameterized):
         param.Boolean(default=False,
                       doc="Controls the PyTorch Lightning flag 'find_unused_parameters' for the DDP plugin. "
                           "Setting it to True comes with a performance hit.")
+    pl_limit_train_batches: Optional[int] = \
+        param.Integer(default=None,
+                      doc="PyTorch Lightning trainer flag 'limit_train_batches': Limit the training dataset to the "
+                          "given number of batches.")
+    pl_limit_val_batches: Optional[int] = \
+        param.Integer(default=None,
+                      doc="PyTorch Lightning trainer flag 'limit_val_batches': Limit the validation dataset to the "
+                          "given number of batches.")
 
     @property
     def use_gpu(self) -> bool:
@@ -609,8 +617,10 @@ def num_gpus_per_node(self) -> int:
         or restrict it to max_num_gpu, whichever is smaller. Returns 0 if running on a CPU device.
         """
         import torch
-        num_gpus = torch.cuda.device_count() if self.use_gpu else 0
-        logging.info(f"Number of available GPUs: {num_gpus}")
+        available_gpus = torch.cuda.device_count()
+        num_gpus = available_gpus if self.use_gpu else 0
+        message_suffix = "" if self.use_gpu else ", but not using them because use_gpu == False"
+        logging.info(f"Number of available GPUs: {available_gpus}{message_suffix}")
         if 0 <= self.max_num_gpus < num_gpus:
             num_gpus = self.max_num_gpus
             logging.info(f"Restricting the number of GPUs to {num_gpus}")
diff --git a/InnerEye/ML/lightning_base.py b/InnerEye/ML/lightning_base.py
index 878379c09..e69c944c7 100644
--- a/InnerEye/ML/lightning_base.py
+++ b/InnerEye/ML/lightning_base.py
@@ -3,7 +3,6 @@
 # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
 # ------------------------------------------------------------------------------------------
 import logging
-import numbers
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -20,21 +19,22 @@
 from InnerEye.Common.type_annotations import DictStrFloat
 from InnerEye.ML.common import ModelExecutionMode
 from InnerEye.ML.config import SegmentationModelBase
+from InnerEye.ML.dataset.full_image_dataset import convert_channels_to_file_paths
 from InnerEye.ML.deep_learning_config import DatasetParams, DeepLearningConfig, OutputParams, TrainerParams, \
     WorkflowParams
 from InnerEye.ML.lightning_container import LightningContainer
-from InnerEye.ML.lightning_loggers import StoringLogger
+from InnerEye.ML.lightning_loggers import StoringLogger, log_on_epoch
 from InnerEye.ML.metrics import EpochTimers, MAX_ITEM_LOAD_TIME_SEC, store_epoch_metrics
 from InnerEye.ML.metrics_dict import DataframeLogger
 from InnerEye.ML.model_config_base import ModelConfigBase
 from InnerEye.ML.utils import model_util
+from InnerEye.ML.utils.csv_util import CSV_SUBJECT_HEADER
 from InnerEye.ML.utils.device_aware_module import DeviceAwareModule
 from InnerEye.ML.utils.lr_scheduler import SchedulerWithWarmUp
 from InnerEye.ML.utils.ml_util import RandomStateSnapshot, set_random_seed, validate_dataset_paths
 from InnerEye.ML.utils.model_util import generate_and_print_model_summary
 from InnerEye.ML.visualizers.patch_sampling import visualize_random_crops_for_dataset
-from InnerEye.ML.utils.csv_util import CSV_SUBJECT_HEADER
-from InnerEye.ML.dataset.full_image_dataset import convert_channels_to_file_paths
+
 
 class TrainAndValDataLightning(LightningDataModule):
     """
@@ -314,7 +314,7 @@ def read_epoch_results_from_logger_and_store(self, epoch: int) -> None:
         Training and Validation metrics.
         """
         if epoch >= 0:
-            if epoch in self.storing_logger.results:
+            if epoch in self.storing_logger.results_per_epoch:
                 for is_training, prefix in [(True, TRAIN_PREFIX), (False, VALIDATION_PREFIX)]:
                     metrics = self.storing_logger.extract_by_prefix(epoch, prefix)
                     self.store_epoch_results(metrics, epoch, is_training)
@@ -375,20 +375,19 @@ def log_on_epoch(self,
         :param name: The name of the metric to log
         :param value: The value of the metric. This can be a tensor, floating point value, or a Metric class.
         :param is_training: If true, give the metric a "train/" prefix, otherwise a "val/" prefix.
-        :param reduce_fx: The reduce function to apply after synchronizing the tensors across GPUs.
+        :param reduce_fx: The reduce function to apply to step values. Default: torch.mean
         :param sync_dist_op: The reduce operation to use when synchronizing the tensors across GPUs.
             This must be a value recognized by sync_ddp: Either 'None' to use 'sum' as aggregate, or 'mean' or 'avg'
         """
         metric_name = name if isinstance(name, str) else name.value
-        if isinstance(value, numbers.Number):
-            value = torch.tensor(value, dtype=torch.float, device=self.device)
         prefix = TRAIN_PREFIX if is_training else VALIDATION_PREFIX
         sync_dist = self.use_sync_dist if sync_dist_override is None else sync_dist_override
-        self.log(prefix + metric_name, value,
-                 sync_dist=sync_dist,
-                 on_step=False, on_epoch=True,
-                 reduce_fx=reduce_fx,
-                 sync_dist_op=sync_dist_op)
+        log_on_epoch(self,
+                     name=prefix + metric_name,
+                     value=value,
+                     sync_dist=sync_dist,
+                     reduce_fx=reduce_fx,
+                     sync_dist_op=sync_dist_op)
 
     def store_epoch_results(self, metrics: DictStrFloat, epoch: int, is_training: bool) -> None:
         """
diff --git a/InnerEye/ML/lightning_loggers.py b/InnerEye/ML/lightning_loggers.py
index 073bd3065..16235bb2a 100644
--- a/InnerEye/ML/lightning_loggers.py
+++ b/InnerEye/ML/lightning_loggers.py
@@ -2,46 +2,64 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
 # ------------------------------------------------------------------------------------------
-from typing import Any, Dict, Iterable, List, Optional
+import logging
+import numbers
+import operator
+from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional
 
+import torch
+from pytorch_lightning import LightningModule
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.utilities import rank_zero_only
+from torch.optim.lr_scheduler import _LRScheduler
 
 from InnerEye.Azure.azure_util import RUN_CONTEXT, is_offline_run_context
 from InnerEye.Common.metrics_constants import TRAIN_PREFIX, VALIDATION_PREFIX
-from InnerEye.Common.type_annotations import DictStrFloat
+from InnerEye.Common.type_annotations import DictStrFloat, DictStrFloatOrFloatList
 
 
 class StoringLogger(LightningLoggerBase):
     """
-    A Pytorch Lightning logger that simply stores the metrics that are written to it.
+    A Pytorch Lightning logger that simply stores the metrics that are written to it, grouped by epoch.
     Used for diagnostic purposes in unit tests.
""" def __init__(self) -> None: super().__init__() - self.results: Dict[int, DictStrFloat] = {} + self.results_per_epoch: Dict[int, DictStrFloatOrFloatList] = {} self.hyperparams: Any = None # Fields to store diagnostics for unit testing self.train_diagnostics: List[Any] = [] self.val_diagnostics: List[Any] = [] + self.results_without_epoch: List[DictStrFloat] = [] @rank_zero_only def log_metrics(self, metrics: DictStrFloat, step: Optional[int] = None) -> None: + logging.debug(f"StoringLogger step={step}: {metrics}") epoch_name = "epoch" if epoch_name not in metrics: - raise ValueError("Each of the logged metrics should have an 'epoch' key.") + # Metrics without an "epoch" key are logged during testing, for example + self.results_without_epoch.append(metrics) + return epoch = int(metrics[epoch_name]) del metrics[epoch_name] - if epoch in self.results: - current_results = self.results[epoch] - overlapping_keys = set(metrics.keys()).intersection(current_results.keys()) - if len(overlapping_keys) > 0: - raise ValueError(f"Unable to log metric with same name twice for epoch {epoch}: " - f"{', '.join(overlapping_keys)}") - current_results.update(metrics) + for key, value in metrics.items(): + if isinstance(value, int): + metrics[key] = float(value) + if epoch in self.results_per_epoch: + current_results = self.results_per_epoch[epoch] + for key, value in metrics.items(): + if key in current_results: + logging.debug(f"StoringLogger: appending results for metric {key}") + current_metrics = current_results[key] + if isinstance(current_metrics, list): + current_metrics.append(value) + else: + current_results[key] = [current_metrics, value] + else: + current_results[key] = value else: - self.results[epoch] = metrics + self.results_per_epoch[epoch] = metrics # type: ignore @rank_zero_only def log_hyperparams(self, params: Any) -> None: @@ -61,7 +79,7 @@ def epochs(self) -> Iterable[int]: """ Gets the epochs for which the present object holds any results. """ - return self.results.keys() + return self.results_per_epoch.keys() def extract_by_prefix(self, epoch: int, prefix_filter: str = "") -> DictStrFloat: """ @@ -73,7 +91,7 @@ def extract_by_prefix(self, epoch: int, prefix_filter: str = "") -> DictStrFloat have a name starting with `prefix`, and strip off the prefix. :return: A metrics dictionary. """ - epoch_results = self.results.get(epoch, None) + epoch_results = self.results_per_epoch.get(epoch, None) if epoch_results is None: raise KeyError(f"No results are stored for epoch {epoch}") filtered = {} @@ -83,8 +101,8 @@ def extract_by_prefix(self, epoch: int, prefix_filter: str = "") -> DictStrFloat # filter is supplied and really matches the metric name if (not prefix_filter) or key.startswith(prefix_filter): stripped_key = key[len(prefix_filter):] - filtered[stripped_key] = value - return filtered + filtered[stripped_key] = value # type: ignore + return filtered # type: ignore def to_metrics_dicts(self, prefix_filter: str = "") -> Dict[int, DictStrFloat]: """ @@ -107,7 +125,14 @@ def get_metric(self, is_training: bool, metric_type: str) -> List[float]: :return: A list of floating point numbers, with one entry per entry in the the training or validation results. 
""" full_metric_name = (TRAIN_PREFIX if is_training else VALIDATION_PREFIX) + metric_type - return [self.results[epoch][full_metric_name] for epoch in self.epochs] + result = [] + for epoch in self.epochs: + value = self.results_per_epoch[epoch][full_metric_name] + if not isinstance(value, float): + raise ValueError(f"Expected a floating point value for metric {full_metric_name}, but got: " + f"{value}") + result.append(value) + return result def get_train_metric(self, metric_type: str) -> List[float]: """ @@ -152,6 +177,7 @@ def __init__(self) -> None: @rank_zero_only def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: + logging.debug(f"AzureMLLogger step={step}: {metrics}") if self.is_azureml_run: for key, value in metrics.items(): RUN_CONTEXT.log(key, value) @@ -168,3 +194,77 @@ def name(self) -> Any: def version(self) -> int: return 0 + + +def log_on_epoch(module: LightningModule, + name: Optional[str] = None, + value: Optional[Any] = None, + metrics: Optional[Mapping[str, Any]] = None, + reduce_fx: Callable = torch.mean, + sync_dist: Optional[bool] = None, + sync_dist_op: Any = "mean") -> None: + """ + Write a dictionary with metrics and/or an individual metric as a name/value pair to the loggers of the given module. + Metrics are always logged upon epoch completion. + The metrics in question first synchronized across GPUs if DDP with >1 node is used, using the sync_dist_op + (default: mean). Afterwards, they are aggregated across all steps via the reduce_fx (default: mean). + Metrics that are fed in as plain numbers rather than tensors (for example, plain Python integers) are converted + to tensors before logging, to enable synchronization across GPUs if needed. + + :param name: The name of the metric to log. + :param value: The actual value of the metric to log. + :param metrics: A dictionary with metrics to log. + :param module: The PyTorch Lightning module where the metrics should be logged. + :param sync_dist: If not None, use this value for the sync_dist argument to module.log. If None, + set it automatically depending on the use of DDP. Set this to False if you want to log metrics that are only + available on Rank 0 of a DDP job. + :param reduce_fx: The reduce function to apply to the per-step values, after synchronizing the tensors across GPUs. + Default: torch.mean + :param sync_dist_op: The reduce operation to use when synchronizing the tensors across GPUs. This must be + a value recognized by sync_ddp: 'sum', 'mean', 'avg' + """ + assert module.trainer is not None, "No trainer is set for this module." + if operator.xor(name is None, value is None): + raise ValueError("Both or neither of 'name' and 'value' must be provided.") + sync_dist = module.trainer.world_size > 1 if sync_dist is None else sync_dist + metrics = metrics or {} + if name is not None: + metrics[name] = value # type: ignore + metrics_as_tensors = { + key: torch.tensor(value, dtype=torch.float, device=module.device) + if isinstance(value, numbers.Number) + else value + for key, value in metrics.items() + } + module.log_dict(metrics_as_tensors, + on_epoch=True, + on_step=False, + sync_dist=sync_dist, + reduce_fx=reduce_fx, + sync_dist_op=sync_dist_op) + + +def log_learning_rate(module: LightningModule, name: str = "learning_rate") -> None: + """ + Logs the learning rate(s) used by the given module. Multiple learning rate schedulers and/or multiple rates per + scheduler are supported. The learning rates are logged under the given name. 
+    rates are used, a suffix like "/0/1" is added, to indicate the learning rate for scheduler 0, index 1, for example.
+    Learning rates are logged on epoch.
+
+    :param module: The module for which the learning rates should be logged.
+    :param name: The name to use when logging the learning rates.
+    """
+    schedulers = module.lr_schedulers()
+    if schedulers is None:
+        raise ValueError("Learning rate logging can only be used during training.")
+    single_scheduler = not isinstance(schedulers, list)
+    if single_scheduler:
+        schedulers = [schedulers]
+    lr_0 = schedulers[0].get_last_lr()  # type: ignore
+    singleton_lr = single_scheduler and len(lr_0) == 1
+    logged = {}
+    for i, scheduler in enumerate(schedulers):
+        for j, lr_j in enumerate(scheduler.get_last_lr()):  # type: ignore
+            full_name = name if singleton_lr else f"{name}/{i}/{j}"
+            logged[full_name] = lr_j
+    log_on_epoch(module, metrics=logged)
diff --git a/InnerEye/ML/metrics.py b/InnerEye/ML/metrics.py
index a87dde368..ea8cb19b0 100644
--- a/InnerEye/ML/metrics.py
+++ b/InnerEye/ML/metrics.py
@@ -12,7 +12,7 @@
 import SimpleITK as sitk
 import numpy as np
-from numpy.core.numeric import NaN
+from numpy.core.numeric import NaN  # type: ignore
 import torch
 import torch.nn.functional as F
 from azureml.core import Run
diff --git a/InnerEye/ML/metrics_dict.py b/InnerEye/ML/metrics_dict.py
index fd2dc8fad..f519c482e 100644
--- a/InnerEye/ML/metrics_dict.py
+++ b/InnerEye/ML/metrics_dict.py
@@ -420,13 +420,13 @@ def get_accuracy_at05(self, hue: str = DEFAULT_HUE_KEY) -> float:
                                label=self.get_labels(hue=hue))
 
     @classmethod
-    def get_optimal_idx(cls, fpr: np.ndarray, tpr: np.ndarray) -> np.ndarray:
+    def get_optimal_idx(cls, fpr: np.ndarray, tpr: np.ndarray) -> int:
         """
         Given a list of FPR and TPR values corresponding to different thresholds, compute the index which corresponds
         to the optimal threshold.
""" optimal_idx = np.argmax(tpr - fpr) - return optimal_idx + return optimal_idx # type: ignore def get_metrics_at_optimal_cutoff(self, hue: str = DEFAULT_HUE_KEY) -> Tuple: """ @@ -705,8 +705,8 @@ def load_execution_mode_metrics_from_df(df: pd.DataFrame, result[mode][epoch] = ScalarMetricsDict(is_classification_metrics=is_classification_metrics, hues=hues) subjects = list(group[LoggingColumns.Patient.value].values) - predictions = group[LoggingColumns.ModelOutput.value].to_numpy(dtype=np.float) - labels = group[LoggingColumns.Label.value].to_numpy(dtype=np.float) + predictions = group[LoggingColumns.ModelOutput.value].to_numpy(dtype=np.float) # type: ignore + labels = group[LoggingColumns.Label.value].to_numpy(dtype=np.float) # type: ignore result[mode][epoch].add_predictions(subjects, predictions, labels, hue=hue) return result diff --git a/InnerEye/ML/model_training.py b/InnerEye/ML/model_training.py index 287f4a0c1..e52d6b400 100644 --- a/InnerEye/ML/model_training.py +++ b/InnerEye/ML/model_training.py @@ -21,7 +21,7 @@ from InnerEye.ML.deep_learning_config import ARGS_TXT, VISUALIZATION_FOLDER from InnerEye.ML.lightning_base import InnerEyeContainer, InnerEyeLightning from InnerEye.ML.lightning_container import LightningContainer -from InnerEye.ML.lightning_loggers import AzureMLLogger, StoringLogger +from InnerEye.ML.lightning_loggers import AzureMLLogger, StoringLogger, log_on_epoch from InnerEye.ML.lightning_models import SUBJECT_OUTPUT_PER_RANK_PREFIX, ScalarLightning, \ get_subject_output_file_per_rank @@ -87,22 +87,23 @@ class InnerEyeRecoveryCheckpointCallback(ModelCheckpoint): def __init__(self, container: LightningContainer): super().__init__(dirpath=str(container.checkpoint_folder), - monitor="epoch", + monitor="epoch_started", filename=RECOVERY_CHECKPOINT_FILE_NAME + "_{epoch}", period=container.recovery_checkpoint_save_interval, save_top_k=container.recovery_checkpoints_save_last_k, mode="max", save_last=False) - def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule, unused: bool = None) -> None: - pl_module.log(name="epoch", value=trainer.current_epoch) # type: ignore + def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule, unused: bool = None) -> None: + # The metric to monitor must be logged on all ranks in distributed training + log_on_epoch(pl_module, name="epoch_started", value=trainer.current_epoch, sync_dist=False) # type: ignore def create_lightning_trainer(container: LightningContainer, resume_from_checkpoint: Optional[Path] = None, num_nodes: int = 1, **kwargs: Dict[str, Any]) -> \ - Tuple[Trainer, Optional[StoringLogger]]: + Tuple[Trainer, StoringLogger]: """ Creates a Pytorch Lightning Trainer object for the given model configuration. It creates checkpoint handlers and loggers. That includes a diagnostic logger for use in unit tests, that is also returned as the second @@ -139,12 +140,8 @@ def create_lightning_trainer(container: LightningContainer, logging.info(f"Using {num_gpus} GPUs per node with accelerator '{accelerator}'") tensorboard_logger = TensorBoardLogger(save_dir=str(container.logs_folder), name="Lightning", version="") loggers = [tensorboard_logger, AzureMLLogger()] - storing_logger: Optional[StoringLogger] - if isinstance(container, InnerEyeContainer): - storing_logger = StoringLogger() - loggers.append(storing_logger) - else: - storing_logger = None + storing_logger = StoringLogger() + loggers.append(storing_logger) # Use 32bit precision when running on CPU. 
     precision = 32 if num_gpus == 0 else 16 if container.use_mixed_precision else 32
     # The next two flags control the settings in torch.backends.cudnn.deterministic and torch.backends.cudnn.benchmark
@@ -177,6 +174,8 @@ def create_lightning_trainer(container: LightningContainer,
                       accelerator=accelerator,
                       plugins=plugins,
                       max_epochs=container.num_epochs,
+                      limit_train_batches=container.pl_limit_train_batches or 1.0,
+                      limit_val_batches=container.pl_limit_val_batches or 1.0,
                       num_sanity_val_steps=container.pl_num_sanity_val_steps,
                       callbacks=callbacks,
                       logger=loggers,
@@ -209,7 +208,7 @@ def start_resource_monitor(config: LightningContainer) -> ResourceMonitor:
 
 def model_train(checkpoint_path: Optional[Path],
                 container: LightningContainer,
-                num_nodes: int = 1) -> Tuple[Trainer, Optional[StoringLogger]]:
+                num_nodes: int = 1) -> Tuple[Trainer, StoringLogger]:
     """
     The main training loop. It creates the Pytorch model based on the configuration options passed in,
     creates a Pytorch Lightning trainer, and trains the model.
diff --git a/InnerEye/ML/photometric_normalization.py b/InnerEye/ML/photometric_normalization.py
index 4d34a8ac7..5116b307c 100644
--- a/InnerEye/ML/photometric_normalization.py
+++ b/InnerEye/ML/photometric_normalization.py
@@ -107,7 +107,7 @@ def transform(self, image: Union[np.ndarray, torch.Tensor],
             image_out = CTRange.transform(data=image, output_range=self.output_range, level=self.level,
                                           window=self.window, use_gpu=self.use_gpu)
         elif self.norm_method == PhotometricNormalizationMethod.TrimmedNorm:
-            image_out, status = normalize_trim(image, mask,
+            image_out, status = normalize_trim(image, mask,  # type: ignore
                                                self.output_range, self.sharpen, self.trim_percentiles, self.debug_mode)
             self.status_of_most_recent_call = status
@@ -120,7 +120,7 @@ def transform(self, image: Union[np.ndarray, torch.Tensor],
         return image_out
 
 
-def simple_norm(image_in: np.ndarray, mask: np.ndarray, debug_mode: bool = False) -> np.array:
+def simple_norm(image_in: np.ndarray, mask: np.ndarray, debug_mode: bool = False) -> np.ndarray:
     """
     Normalizes a single image to have mean 0 and standard deviation 1
 
@@ -160,7 +160,7 @@ def normalize_trim(image: np.ndarray,
                    output_range: Tuple[float, float] = (-1.0, 1.0),
                    sharpen: float = 1.9,
                    trim_percentiles: Tuple[float, float] = (2.0, 98.0),
-                   debug_mode: bool = False) -> np.array:
+                   debug_mode: bool = False) -> np.ndarray:
     """
     Normalizes a single image to have mean 0 and standard deviation 1
     Normalising occurs after percentile thresholds have been applied to strip out extreme values
@@ -258,7 +258,7 @@ def mri_window(image_in: np.ndarray,
                output_range: Tuple[float, float] = (-1.0, 1.0),
                sharpen: float = 1.9,
                tail: Union[List[float], float] = 1.0,
-               debug_mode: bool = False) -> Tuple[np.array, str]:
+               debug_mode: bool = False) -> Tuple[np.ndarray, str]:
     """
     This function takes an MRI Image, removes the first peak of values (air).
     Then a window range is found centered around the mean of the remaining values and with a range controlled by the standard deviation and the sharpen
@@ -290,7 +290,7 @@ def mri_window(image_in: np.ndarray,
         in_mask = False
     else:
         maflat = mask.flatten()
-        in_mask = mask > 0
+        in_mask = mask > 0  # type: ignore
     # Find Otsu's threshold for the values of the input image
     threshold = threshold_otsu(imflat)
     # Find window level
@@ -309,8 +309,8 @@ def mri_window(image_in: np.ndarray,
         if mask is None:
             no_thresh = np.sum(imflat < threshold)
             no_high = np.sum(imout == output_range[1])
-            pc_thresh = no_thresh / np.numel(imflat) * 100
-            pc_high = no_high / np.numel(imflat) * 100
+            pc_thresh = no_thresh / np.numel(imflat) * 100  # type: ignore
+            pc_high = no_high / np.numel(imflat) * 100  # type: ignore
         else:
             no_thresh = np.sum(imflat[maflat == 1] < threshold)
             no_high = np.sum(imout == output_range[1])
diff --git a/InnerEye/ML/run_ml.py b/InnerEye/ML/run_ml.py
index 38dca7724..bcdb86bc8 100644
--- a/InnerEye/ML/run_ml.py
+++ b/InnerEye/ML/run_ml.py
@@ -16,6 +16,9 @@
 from azureml._restclient.constants import RunStatus
 from azureml.core import Model, Run, model
 from azureml.data import FileDataset
+from health.azure.azure_util import ENVIRONMENT_VERSION, create_run_recovery_id, merge_conda_files
+from health.azure.datasets import get_or_create_dataset
+from health.azure.himl import AzureRunInfo
 from pytorch_lightning import LightningModule, seed_everything
 from pytorch_lightning.core.datamodule import LightningDataModule
 from torch.utils.data import DataLoader
@@ -43,6 +46,7 @@
     FINAL_ENSEMBLE_MODEL_FOLDER, FINAL_MODEL_FOLDER, ModelCategory, MultiprocessingStartMethod, load_checkpoint
 from InnerEye.ML.lightning_base import InnerEyeContainer
 from InnerEye.ML.lightning_container import InnerEyeInference, LightningContainer
+from InnerEye.ML.lightning_loggers import StoringLogger
 from InnerEye.ML.metrics import InferenceMetrics, InferenceMetricsForSegmentation
 from InnerEye.ML.model_config_base import ModelConfigBase
 from InnerEye.ML.model_inference_config import ModelInferenceConfig
@@ -58,9 +62,6 @@
 from InnerEye.ML.visualizers import activation_maps
 from InnerEye.ML.visualizers.plot_cross_validation import \
     get_config_and_results_for_offline_runs, plot_cross_validation_from_files
-from health.azure.azure_util import ENVIRONMENT_VERSION, create_run_recovery_id, merge_conda_files
-from health.azure.datasets import get_or_create_dataset
-from health.azure.himl import AzureRunInfo
 
 ModelDeploymentHookSignature = Callable[[LightningContainer, AzureConfig, Model, ModelProcessing], Any]
 PostCrossValidationHookSignature = Callable[[ModelConfigBase, Path], None]
@@ -189,6 +190,7 @@ def __init__(self,
         self.model_deployment_hook = model_deployment_hook
         self.output_subfolder = output_subfolder
         self._has_setup_run = False
+        self.storing_logger: Optional[StoringLogger] = None
 
     def setup(self, azure_run_info: Optional[AzureRunInfo] = None) -> None:
         """
@@ -389,9 +391,10 @@ def run(self) -> None:
         # train a new model if required
         if self.azure_config.train:
             with logging_section("Model training"):
-                model_train(self.checkpoint_handler.get_recovery_or_checkpoint_path_train(),
-                            container=self.container,
-                            num_nodes=self.azure_config.num_nodes)
+                _, storing_logger = model_train(self.checkpoint_handler.get_recovery_or_checkpoint_path_train(),
+                                                container=self.container,
+                                                num_nodes=self.azure_config.num_nodes)
+                self.storing_logger = storing_logger
             # Since we have trained the model further, let the checkpoint_handler object know so it can handle
can handle # checkpoints correctly. self.checkpoint_handler.additional_training_done() diff --git a/InnerEye/ML/runner.py b/InnerEye/ML/runner.py index b992a4bf7..f3ab0a8d0 100755 --- a/InnerEye/ML/runner.py +++ b/InnerEye/ML/runner.py @@ -133,6 +133,8 @@ def __init__(self, self.model_config: Optional[DeepLearningConfig] = None self.azure_config: AzureConfig = AzureConfig() self.lightning_container: LightningContainer = None # type: ignore + # This field stores the MLRunner object that has been created in the most recent call to the run() method. + self.ml_runner: Optional[MLRunner] = None def parse_and_load_model(self) -> ParserResult: """ @@ -379,11 +381,11 @@ def run_in_situ(self, azure_run_info: AzureRunInfo) -> None: # Set environment variables for multi-node training if needed. This function will terminate early # if it detects that it is not in a multi-node environment. set_environment_variables_for_multi_node() - ml_runner = self.create_ml_runner() - ml_runner.setup(azure_run_info) - ml_runner.start_logging_to_file() + self.ml_runner = self.create_ml_runner() + self.ml_runner.setup(azure_run_info) + self.ml_runner.start_logging_to_file() try: - ml_runner.run() + self.ml_runner.run() finally: disable_logging_to_file() diff --git a/InnerEye/ML/utils/dataset_util.py b/InnerEye/ML/utils/dataset_util.py index e71ed1e31..86e254d8d 100644 --- a/InnerEye/ML/utils/dataset_util.py +++ b/InnerEye/ML/utils/dataset_util.py @@ -41,7 +41,7 @@ def __init__(self, columns_and_possible_categories: OrderedDict[str, List[str]]) for col, value in columns_and_possible_categories.items(): # Fit only once during initialization with all possible values. if np.inf in value: - value.remove(np.inf) + value.remove(np.inf) # type: ignore self._encoder[col] = OneHotEncoder(handle_unknown='ignore').fit(np.array(value).reshape(-1, 1)) self._feature_length[col] = len(value) diff --git a/InnerEye/ML/utils/image_util.py b/InnerEye/ML/utils/image_util.py index dff703631..11caf4aff 100644 --- a/InnerEye/ML/utils/image_util.py +++ b/InnerEye/ML/utils/image_util.py @@ -120,8 +120,8 @@ def create_padding_vector() -> Tuple[TupleInt2, TupleInt2, TupleInt2]: """ Creates the padding vector. 
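Both `create_padding_vector` hunks in `image_util.py` (the first one continues just below) split the total padding `diff = crop_size - output_size` into a `(before, after)` pair per dimension, putting `ceil(diff / 2)` before. The newly added `.tolist()` is what makes the `List[int]` annotation true: `np.ceil(...).astype(int)` returns an `ndarray`, not a list. A minimal sketch of the arithmetic, with a hypothetical helper name:

```python
import numpy as np
from typing import List, Tuple

def padding_per_dim(crop_size: Tuple[int, int, int],
                    output_size: Tuple[int, int, int]) -> Tuple[Tuple[int, int], ...]:
    # Total padding needed in each spatial dimension
    diff = np.subtract(crop_size, output_size)
    # Split as evenly as possible: ceil(diff / 2) before, the remainder after.
    # .tolist() converts numpy integers to plain Python ints.
    pad: List[int] = np.ceil(diff / 2.0).astype(int).tolist()
    return tuple((pad[i], int(diff[i]) - pad[i]) for i in range(3))

# Odd differences put the extra voxel into the "before" half
assert padding_per_dim((64, 64, 64), (60, 59, 64)) == ((2, 2), (3, 2), (0, 0))
```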
""" - diff = np.subtract(crop_size, output_size) - pad: List[int] = np.ceil(diff / 2.0).astype(int) + diff = np.subtract(crop_size, output_size) # type: ignore + pad: List[int] = np.ceil(diff / 2.0).astype(int).tolist() # type: ignore return (pad[0], diff[0] - pad[0]), (pad[1], diff[1] - pad[1]), (pad[2], diff[2] - pad[2]) if images is None: @@ -159,7 +159,7 @@ def create_padding_vector() -> Tuple[TupleInt2, TupleInt2, TupleInt2]: Creates the padding vector ceil(crop_size - output_size / 2) """ image_spatial_shape = images.shape[-3:] - diff = np.clip(np.subtract(output_size, image_spatial_shape), a_min=0, a_max=None) + diff = np.clip(np.subtract(output_size, image_spatial_shape), a_min=0, a_max=None) # type: ignore pad: List[int] = np.ceil(diff / 2.0).astype(int) return (pad[0], diff[0] - pad[0]), (pad[1], diff[1] - pad[1]), (pad[2], diff[2] - pad[2]) @@ -222,9 +222,9 @@ def posteriors_to_segmentation(posteriors: NumpyOrTorch) -> NumpyOrTorch: # add a batch dimension if required argmax_dim = 1 if len(posteriors.shape) == 5 else 0 - if torch.is_tensor(posteriors): + if isinstance(posteriors, torch.Tensor): try: - segmentation = posteriors.argmax(dim=argmax_dim) + segmentation = posteriors.argmax(dim=argmax_dim) # type: ignore except RuntimeError: # CUDA out of memory, presumably, so we move it to CPU and try again posteriors = posteriors.cpu() @@ -253,7 +253,7 @@ def largest_connected_components(img: np.ndarray, component_sizes = np.bincount(labeled_array.flatten()) # We don't want to count background component_sizes[0] = 0 - largest_component_indices: List[Union[int, np.array]] = [] + largest_component_indices: List[Union[int, np.ndarray]] = [] if deletion_limit is not None and deletion_limit < 0.5: # Find the indices of all components with sizes over the threshold - there can be more than one # (or there might be none, if all components are small). @@ -262,8 +262,8 @@ def largest_connected_components(img: np.ndarray, if not largest_component_indices: # We can get here either if we didn't run the "if" clause above, or if we did but found no components # of the required size. In either case, we want to return the largest component, whatever its size. 
- largest_component_indices = [np.argmax(component_sizes)] - out = np.zeros(img.shape, np.bool) + largest_component_indices = [np.argmax(component_sizes)] # type: ignore + out = np.zeros(img.shape, np.bool) # type: ignore for idx in largest_component_indices: out[labeled_array == idx] = True voxels_left = out.sum() @@ -686,8 +686,8 @@ def find_intersection_array_indices(indices1: Union[np.ndarray, Tuple[np.ndarray if len(indices1) != len(indices2) or len(indices1) != len(shape): raise ValueError("find_intersection_array_indices: " "Trying to compare indices from incompatible array shapes") - row_major_indices1 = np.ravel_multi_index(indices1, shape) - row_major_indices2 = np.ravel_multi_index(indices2, shape) + row_major_indices1 = np.ravel_multi_index(indices1, shape) # type: ignore + row_major_indices2 = np.ravel_multi_index(indices2, shape) # type: ignore intersection_in_row_major = np.intersect1d(row_major_indices1, row_major_indices2, assume_unique=True) intersection_indices = np.unravel_index(intersection_in_row_major, shape) return intersection_indices diff --git a/InnerEye/ML/utils/io_util.py b/InnerEye/ML/utils/io_util.py index cc13155c1..7b7b510ef 100644 --- a/InnerEye/ML/utils/io_util.py +++ b/InnerEye/ML/utils/io_util.py @@ -271,7 +271,7 @@ def load_dicom_image(path: PathOrString) -> np.ndarray: else: raise ValueError("Unknown value for DICOM tag 0028,0103 PixelRepresentation") # Return a float array, we may resize this in load_3d_images_and_stack, and interpolation will not work on int - return pixels.astype(np.float) + return pixels.astype(np.float) # type: ignore def load_hdf5_dataset_from_file(path_str: Path, dataset_name: str) -> np.ndarray: @@ -371,7 +371,7 @@ def from_numpy_crop_and_resize(array: np.ndarray) -> torch.Tensor: if load_segmentation: # Segmentations are loaded as UInt8. 
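Two notes on the hunks above. First, `np.zeros(img.shape, np.bool)` only gets a `# type: ignore` here, but `np.bool` is a deprecated alias for the builtin `bool` (removed in NumPy 1.24), so `dtype=bool` would remove the warning and the ignore alike; the same holds for the `np.float` casts further down in `io_util.py`. Second, `find_intersection_array_indices` intersects two sets of N-dimensional indices by flattening them into row-major scalars first; a small worked example of that technique, with arbitrary values:

```python
import numpy as np

shape = (4, 5)
# Index tuples in the format returned by np.where: one array per dimension
idx1 = (np.array([0, 1, 2]), np.array([1, 2, 3]))  # points (0,1), (1,2), (2,3)
idx2 = (np.array([1, 2, 3]), np.array([2, 0, 3]))  # points (1,2), (2,0), (3,3)

flat1 = np.ravel_multi_index(idx1, shape)  # [1, 7, 13]
flat2 = np.ravel_multi_index(idx2, shape)  # [7, 10, 18]
common = np.intersect1d(flat1, flat2, assume_unique=True)  # [7]
rows, cols = np.unravel_index(common, shape)
assert (rows.tolist(), cols.tolist()) == ([1], [2])  # only (1, 2) is shared
```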
Convert to one-hot encoding as late as possible, # that is only before feeding into the model - segmentations.append(from_numpy_crop_and_resize(image_and_segmentation.segmentations)) + segmentations.append(from_numpy_crop_and_resize(image_and_segmentation.segmentations)) # type: ignore image_tensor = torch.stack(images, dim=0) if len(images) > 0 else torch.empty(0) segmentation_tensor = torch.stack(segmentations, dim=0) if len(segmentations) > 0 else torch.empty(0) @@ -431,7 +431,7 @@ def load_labels_from_dataset_source(dataset_source: PatientDatasetSource, check_ # label_list keeps track of missing ground truth channels for gt in dataset_source.ground_truth_channels: if gt is None: - label_list.append(np.full(image_size, np.NAN, ImageDataType)) + label_list.append(np.full(image_size, np.NAN, ImageDataType)) # type: ignore else: label_list.append(load_image(gt, ImageDataType.SEGMENTATION.value).image) labels = np.stack(label_list) @@ -496,7 +496,7 @@ def load_image(path: PathOrString, image_type: Optional[Type] = float) -> ImageW return ImageWithHeader(image, header) elif is_png(path): import imageio - image = imageio.imread(path).astype(np.float) + image = imageio.imread(path).astype(np.float) # type: ignore header = get_unit_image_header() return ImageWithHeader(image, header) raise ValueError(f"Invalid file type {path}") @@ -677,11 +677,12 @@ def store_as_nifti(image: np.ndarray, else: image = (image + 1) * 255 - image = sitk.GetImageFromArray(image.astype(image_type)) - image.SetSpacing(sitk.VectorDouble(reverse_tuple_float3(header.spacing))) # Spacing needs to be X Y Z - image.SetOrigin(header.origin) - image.SetDirection(header.direction) - sitk.WriteImage(image, str(file_name)) + itk_image = sitk.GetImageFromArray(image.astype(image_type)) + # Spacing needs to be X Y Z + itk_image.SetSpacing(sitk.VectorDouble(reverse_tuple_float3(header.spacing))) + itk_image.SetOrigin(header.origin) + itk_image.SetDirection(header.direction) + sitk.WriteImage(itk_image, str(file_name)) return Path(file_name) @@ -763,7 +764,7 @@ def create_dicom_series(folder: Path, size: TupleInt3, spacing: TupleFloat3) -> :param spacing: Final image spacing, as (column spacing, row spacing, slice spacing) (in mm). :return: The test data, a 3d ndarray of floats in the range [0, 1000.0). """ - data = np.random.uniform(high=1000, size=size).astype(np.float) + data = np.random.uniform(high=1000, size=size).astype(np.float) # type: ignore image = sitk.GetImageFromArray(data) image.SetSpacing(spacing) diff --git a/InnerEye/ML/utils/metrics_util.py b/InnerEye/ML/utils/metrics_util.py index 63a812dd1..e95f7ffbd 100644 --- a/InnerEye/ML/utils/metrics_util.py +++ b/InnerEye/ML/utils/metrics_util.py @@ -224,9 +224,9 @@ def r2_score(model_output: Union[torch.Tensor, np.ndarray], label: Union[torch.T Computes the coefficient of determination R2. Represents the proportion of variance explained by the (independent) variables in the model. R2 = 1 - Mean(SquaredErrors) / Variance(Labels) """ - if torch.is_tensor(label): + if isinstance(label, torch.Tensor): label = label.detach().cpu().numpy() - if torch.is_tensor(model_output): + if isinstance(model_output, torch.Tensor): model_output = model_output.detach().cpu().numpy() return sklearn_r2_score(label, model_output) @@ -256,14 +256,14 @@ def convert_input_and_label(model_output: Union[torch.Tensor, np.ndarray], Ensures that both model_output and label are tensors of dtype float32. :return a Tuple with model_output, label as float tensors. 
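Several hunks above swap `torch.is_tensor(x)` for `isinstance(x, torch.Tensor)`. The two behave the same at runtime, but only the `isinstance` form narrows a `Union` type for mypy (at least with the versions pinned in this repo), which is presumably why the change travels with the other typing fixes in this PR. A minimal illustration:

```python
from typing import Union
import numpy as np
import torch

def to_numpy(value: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
    if isinstance(value, torch.Tensor):
        # mypy narrows 'value' to torch.Tensor here, so .detach() type-checks;
        # after 'if torch.is_tensor(value)' it would still see the full Union.
        return value.detach().cpu().numpy()
    return value

assert to_numpy(torch.ones(2)).shape == (2,)
assert to_numpy(np.ones(2)).shape == (2,)
```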
""" - if not torch.is_tensor(model_output): + if not isinstance(model_output, torch.Tensor): model_output = torch.tensor(model_output) - if not torch.is_tensor(label): + if not isinstance(label, torch.Tensor): label = torch.tensor(label) return model_output.float(), label.float() -def is_missing_ground_truth(ground_truth: np.array) -> bool: +def is_missing_ground_truth(ground_truth: np.ndarray) -> bool: """ calculate_metrics_per_class in metrics.py and plot_contours_for_all_classes in plotting.py both check whether there is ground truth missing using this simple check for NaN value at 0, 0, 0. diff --git a/InnerEye/ML/visualizers/reliability_curve.py b/InnerEye/ML/visualizers/reliability_curve.py index 261879b20..a8acc99ff 100644 --- a/InnerEye/ML/visualizers/reliability_curve.py +++ b/InnerEye/ML/visualizers/reliability_curve.py @@ -32,7 +32,7 @@ def plot_reliability_curve( if not isinstance(y_predict, list): y_predict = [y_predict] - y_true = [y_true] + y_true = [y_true] # type: ignore if not len(y_true) == len(y_predict): raise ValueError("y_true and y_predict are not of the same length") diff --git a/Tests/Common/test_commandline_parsing.py b/Tests/Common/test_commandline_parsing.py index df6779e92..eedb5874c 100644 --- a/Tests/Common/test_commandline_parsing.py +++ b/Tests/Common/test_commandline_parsing.py @@ -65,7 +65,7 @@ def test_create_ml_runner_args(is_container: bool, azure_run_info = AzureRunInfo(input_datasets=[None], output_datasets=[None], run=None, - is_running_in_azure=False, + is_running_in_azure_ml=False, output_folder=Path.cwd(), logs_folder=Path.cwd()) runner.run_in_situ(azure_run_info) diff --git a/Tests/ML/models/test_instantiate_models.py b/Tests/ML/models/test_instantiate_models.py index d8e3f6582..aaf211d3b 100644 --- a/Tests/ML/models/test_instantiate_models.py +++ b/Tests/ML/models/test_instantiate_models.py @@ -10,6 +10,7 @@ from InnerEye.Common.common_util import logging_to_stdout, namespace_to_path from InnerEye.Common.output_directories import OutputFolderForTests +from InnerEye.ML.configs.ssl.CXR_SSL_configs import NIH_RSNA_BYOL from InnerEye.ML.utils.config_loader import ModelConfigLoader from InnerEye.ML.utils.model_util import create_model_with_temperature_scaling, generate_and_print_model_summary from Tests.ML.configs.DummyModel import DummyModel @@ -171,3 +172,15 @@ def test_run_model_with_invalid_trainer_arguments(test_output_dirs: OutputFolder with pytest.raises(Exception) as ex: model_train_unittest(config=None, dirs=test_output_dirs, lightning_container=container) assert "no_such_argument" in str(ex) + + +def test_load_container_with_limit_batches() -> None: + """ + Test if we can load a container and override the pl_limit_train_batches flag + """ + runner = default_runner() + args = ["", "--model=NIH_RSNA_BYOL", "--pl_limit_train_batches=1"] + with mock.patch("sys.argv", args): + runner.parse_and_load_model() + assert isinstance(runner.lightning_container, NIH_RSNA_BYOL) + assert runner.lightning_container.pl_limit_train_batches == 1 diff --git a/Tests/ML/test_download_upload.py b/Tests/ML/test_download_upload.py index c7d6884ce..35ae0fb84 100644 --- a/Tests/ML/test_download_upload.py +++ b/Tests/ML/test_download_upload.py @@ -164,7 +164,7 @@ def _test_mount_for_lightning_container(test_output_dirs: OutputFolderForTests, runner.setup(azure_run_info=AzureRunInfo(input_datasets=path_from_aml, output_datasets=[], run=None, - is_running_in_azure=False, + is_running_in_azure_ml=False, output_folder=Path(), logs_folder=Path() )) diff --git 
a/Tests/ML/test_model_training.py b/Tests/ML/test_model_training.py index 413de42e0..0990a2a27 100644 --- a/Tests/ML/test_model_training.py +++ b/Tests/ML/test_model_training.py @@ -7,11 +7,13 @@ import shutil from pathlib import Path from typing import Any, Dict, List +from unittest import mock import h5py import numpy as np import pandas as pd import pytest +import torch from torch.utils.data import DataLoader from InnerEye.Common import fixed_paths @@ -26,7 +28,7 @@ from InnerEye.ML.configs.classification.DummyClassification import DummyClassification from InnerEye.ML.dataset.sample import CroppedSample from InnerEye.ML.deep_learning_config import DeepLearningConfig -from InnerEye.ML.lightning_loggers import StoringLogger +from InnerEye.ML.lightning_loggers import StoringLogger, log_learning_rate, log_on_epoch from InnerEye.ML.model_training import aggregate_and_create_subject_metrics_file from InnerEye.ML.models.losses.mixture import MixtureLoss from InnerEye.ML.utils.io_util import load_nifti_image @@ -359,3 +361,135 @@ def test_aggregate_and_create_subject_metrics_file(test_output_dirs: OutputFolde written_lines = pd.read_csv(outputs_folder / mode / SUBJECT_METRICS_FILE_NAME) expected_lines = pd.read_csv(outputs_folder / mode / "expected_metrics.csv") assert written_lines.equals(expected_lines) + + +def test_storing_logger() -> None: + """ + Test if the StoringLogger can correctly handle multiple metrics of the same name logged per epoch. + """ + logger = StoringLogger() + key1 = "key" + key2 = "key2" + value1 = 3.14 + value2 = 2.71 + value3 = 100.0 + assert value1 != value2 + epoch = 1 + # Add metrics in the same epoch in two calls, so that we test both the cases where the epoch is already present, + # and where not + logger.log_metrics({"epoch": 1, key1: value1}) + logger.log_metrics({"epoch": 1, key2: value2}) + # All results for epoch 1 should be collated into a single dictionary + assert logger.extract_by_prefix(epoch=epoch) == {key1: value1, key2: value2} + # When updating a metric that already exists, the result should not be a float anymore but a list. + logger.log_metrics({"epoch": epoch, key1: value3}) + assert logger.extract_by_prefix(epoch=epoch) == {key1: [value1, value3], key2: value2} + # Add more metrics for key1, so that we also test the case that the results are already a list + logger.log_metrics({"epoch": epoch, key1: value3}) + assert logger.extract_by_prefix(epoch=epoch) == {key1: [value1, value3, value3], key2: value2} + # Add metrics that don't have an epoch key: This happens for example during testing with trainer.test + other_metrics1 = {"foo": 1.0} + other_metrics2 = {"foo": 2.0} + logger.log_metrics(other_metrics1) + logger.log_metrics(other_metrics2) + assert logger.results_without_epoch == [other_metrics1, other_metrics2] + + +def test_log_on_epoch() -> None: + """ + Tests if the helper function to log metrics per epoch works. 
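`test_storing_logger` above pins down the semantics behind the new `DictStrFloatOrFloatList` alias: the first value logged under a key in an epoch stays a plain float, and logging the same key again promotes it to a list. A simplified sketch of that accumulation rule (hypothetical helper; the real logic lives in `InnerEye.ML.lightning_loggers.StoringLogger`):

```python
from typing import Dict, List, Union

DictStrFloatOrFloatList = Dict[str, Union[float, List[float]]]

def accumulate(epoch_results: DictStrFloatOrFloatList, key: str, value: float) -> None:
    existing = epoch_results.get(key)
    if existing is None:
        # First occurrence of this metric in the epoch: keep a plain float
        epoch_results[key] = value
    elif isinstance(existing, list):
        # Already promoted: keep appending
        existing.append(value)
    else:
        # Second occurrence: promote float -> list
        epoch_results[key] = [existing, value]

results: DictStrFloatOrFloatList = {}
for v in (3.14, 100.0, 100.0):
    accumulate(results, "key", v)
assert results == {"key": [3.14, 100.0, 100.0]}
```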
+ """ + module = mock.MagicMock() + module.trainer = None + with pytest.raises(AssertionError) as ex1: + log_on_epoch(module, metrics={"foo": 1}) + assert "No trainer is set" in str(ex1) + module.trainer = mock.MagicMock() + module.trainer.world_size = 1 + with pytest.raises(ValueError) as ex2: + log_on_epoch(module, name="foo") + assert "'name' and 'value' must be provided" in str(ex2) + with pytest.raises(ValueError) as ex3: + log_on_epoch(module, value=1.0) + assert "'name' and 'value' must be provided" in str(ex3) + foo_value = 1 + metrics = {"bar": torch.tensor(2.0)} + module.device = 'cpu' + module.log_dict = mock.MagicMock() + log_on_epoch(module, name="foo", value=foo_value, metrics=metrics) + # Test if all metrics that are not tensors are converted to floating point tensors + actual_args = module.log_dict.call_args + actual_metrics = actual_args[0][0] + for metric_name in ["foo", "bar"]: + assert metric_name in actual_metrics, f"Metric missing: {metric_name}" + assert torch.is_tensor(actual_metrics[metric_name]), f"Metric {metric_name}: not a tensor" + assert actual_metrics[metric_name].dtype == torch.float, f"Metric {metric_name}: should be float tensor" + assert actual_metrics["foo"].item() == float(foo_value) + # Default arguments for the call to module.log + assert actual_args[1] == {'on_epoch': True, + 'on_step': False, + 'reduce_fx': torch.mean, + 'sync_dist': False, + 'sync_dist_op': 'mean'}, "Failed for world_size==1" + # Test if sync_dist is computed correctly from world size: world size is now 2, so sync_dist should be True + module.trainer.world_size = 2 + log_on_epoch(module, metrics=metrics) + assert module.log_dict.call_args[1] == {'on_epoch': True, + 'on_step': False, + 'reduce_fx': torch.mean, + 'sync_dist': True, + 'sync_dist_op': 'mean'}, "Failed for world_size==2" + # Test if overrides for sync_dist and the other aggregation args are passed correctly + module.trainer.world_size = 2 + log_on_epoch(module, metrics=metrics, reduce_fx="reduce", sync_dist=False, sync_dist_op="nothing") # type: ignore + assert module.log_dict.call_args[1] == {'on_epoch': True, + 'on_step': False, + 'sync_dist': False, + 'reduce_fx': "reduce", + 'sync_dist_op': "nothing"}, "Failed for sync_dist==True" + module.trainer.world_size = 1 + log_on_epoch(module, metrics=metrics, reduce_fx="reduce", sync_dist=True, sync_dist_op="nothing") # type: ignore + assert module.log_dict.call_args[1] == {'on_epoch': True, + 'on_step': False, + 'sync_dist': True, + 'reduce_fx': "reduce", + 'sync_dist_op': "nothing"}, "Failed for sync_dist==True" + + +def test_log_learning_rate_singleton() -> None: + """ + Test the method that logs learning rates, when there is a single LR scheduler. + """ + module = mock.MagicMock() + module.lr_schedulers = mock.MagicMock(return_value=None) + with pytest.raises(ValueError) as ex: + log_learning_rate(module) + assert "can only be used during training" in str(ex) + scheduler = mock.MagicMock() + lr = 1.234 + scheduler.get_last_lr = mock.MagicMock(return_value=[lr]) + module.lr_schedulers = mock.MagicMock(return_value=scheduler) + with mock.patch("InnerEye.ML.lightning_loggers.log_on_epoch") as mock_log_on_epoch: + log_learning_rate(module) + assert mock_log_on_epoch.call_args[0] == (module,) + assert mock_log_on_epoch.call_args[1] == {'metrics': {'learning_rate': lr}} + + +def test_log_learning_rate_multiple() -> None: + """ + Test the method that logs learning rates, when there are multiple schedulers with non-scalar return values. 
+ """ + scheduler1 = mock.MagicMock() + lr1 = [1] + scheduler1.get_last_lr = mock.MagicMock(return_value=lr1) + scheduler2 = mock.MagicMock() + lr2 = [2, 3] + scheduler2.get_last_lr = mock.MagicMock(return_value=lr2) + module = mock.MagicMock() + module.lr_schedulers = mock.MagicMock(return_value=[scheduler1, scheduler2]) + with mock.patch("InnerEye.ML.lightning_loggers.log_on_epoch") as mock_log_on_epoch: + log_learning_rate(module, name="foo") + assert mock_log_on_epoch.call_args[0] == (module,) + assert mock_log_on_epoch.call_args[1] == {'metrics': {'foo/0/0': lr1[0], + 'foo/1/0': lr2[0], + 'foo/1/1': lr2[1]}} diff --git a/Tests/SSL/test_ssl_containers.py b/Tests/SSL/test_ssl_containers.py index 556305e5d..b2b356714 100644 --- a/Tests/SSL/test_ssl_containers.py +++ b/Tests/SSL/test_ssl_containers.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------ +import math from pathlib import Path from typing import Dict from unittest import mock @@ -11,6 +12,7 @@ import pytest import torch from pl_bolts.models.self_supervised.resnets import ResNet +from torch.optim.lr_scheduler import _LRScheduler from InnerEye.Common import fixed_paths from InnerEye.Common.common_util import is_windows @@ -20,6 +22,8 @@ from InnerEye.ML.SSL.lightning_modules.byol.byol_module import BYOLInnerEye from InnerEye.ML.SSL.lightning_modules.simclr_module import SimCLRInnerEye from InnerEye.ML.SSL.lightning_modules.ssl_classifier_module import SSLClassifier +from InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator import EVALUATOR_STATE_NAME, OPTIMIZER_STATE_NAME, \ + SSLOnlineEvaluatorInnerEye from InnerEye.ML.SSL.utils import SSLDataModuleType, SSLTrainingType from InnerEye.ML.common import BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX from InnerEye.ML.configs.ssl.CXR_SSL_configs import CXRImageClassifier @@ -55,6 +59,29 @@ def default_runner() -> Runner: yaml_config_file=fixed_paths.SETTINGS_YAML_FILE) +def _compare_stored_metrics(runner: Runner, expected_metrics: Dict[str, float], abs: float = 1e-5) -> None: + """ + Checks if the StoringLogger in the given runner holds all the expected metrics as results of training + epoch 0, up to a given absolute precision. + + :param runner: The Innereye runner. + :param expected_metrics: A dictionary with all metrics that are expected to be present. + """ + assert runner.ml_runner is not None + assert runner.ml_runner.storing_logger is not None + print(f"Actual metrics in epoch 0: {runner.ml_runner.storing_logger.results_per_epoch[0]}") + print(f"Expected metrics: {expected_metrics}") + for metric, expected in expected_metrics.items(): + actual = runner.ml_runner.storing_logger.results_per_epoch[0][metric] + if isinstance(actual, float): + if math.isnan(expected): + assert math.isnan(actual), f"Metric {metric}: Expected NaN, but got: {actual}" + else: + assert actual == pytest.approx(expected, abs=abs), f"Mismatch for metric {metric}" + else: + assert actual == expected, f"Mismatch for metric {metric}" + + common_test_args = ["", "--is_debug_model=True", "--num_epochs=1", "--ssl_training_batch_size=10", "--linear_head_batch_size=5", "--num_workers=0"] @@ -70,8 +97,9 @@ def test_innereye_ssl_container_cifar10_resnet_simclr() -> None: - training of image classifier for one epoch. 
""" args = common_test_args + ["--model=CIFAR10SimCLR"] + runner = default_runner() with mock.patch("sys.argv", args): - loaded_config, actual_run = default_runner().run() + loaded_config, actual_run = runner.run() assert loaded_config is not None assert isinstance(loaded_config.model, SimCLRInnerEye) assert loaded_config.encoder_output_dim == 2048 @@ -79,12 +107,38 @@ def test_innereye_ssl_container_cifar10_resnet_simclr() -> None: assert loaded_config.num_epochs == 1 assert loaded_config.recovery_checkpoint_save_interval == 200 assert loaded_config.ssl_training_type == SSLTrainingType.SimCLR - assert loaded_config.online_eval.num_classes == 10 + assert loaded_config.online_eval_callback.num_classes == 10 + assert loaded_config.online_eval_callback.dataset == SSLDatasetName.CIFAR10.value assert loaded_config.ssl_training_dataset_name == SSLDatasetName.CIFAR10 - assert loaded_config.online_eval.dataset == SSLDatasetName.CIFAR10.value assert not loaded_config.use_balanced_binary_loss_for_linear_head assert isinstance(loaded_config.model.encoder.cnn_model, ResNet) + + # Check the metrics that were recorded during training + expected_metrics = { + 'epoch_started': 0.0, + 'simclr/train/loss': 2.953442335128784, + 'simclr/learning_rate': 0.0, + 'ssl_online_evaluator/train/loss': 2.285637378692627, + 'ssl_online_evaluator/train/online_AccuracyAtThreshold05': 0.0, + 'simclr/val/loss': 2.8646411895751953, + 'ssl_online_evaluator/val/loss': 2.2882637977600098, + 'ssl_online_evaluator/val/AccuracyAtThreshold05': 0.0 + } + + _compare_stored_metrics(runner, expected_metrics) + + # Check that the checkpoint contains both the optimizer for the embedding and for the linear head checkpoint_path = loaded_config.outputs_folder / "checkpoints" / "best_checkpoint.ckpt" + checkpoint = torch.load(checkpoint_path) + assert len(checkpoint["optimizer_states"]) == 1 + assert len(checkpoint["lr_schedulers"]) == 1 + assert "callbacks" in checkpoint + assert SSLOnlineEvaluatorInnerEye in checkpoint["callbacks"] + callback_state = checkpoint["callbacks"][SSLOnlineEvaluatorInnerEye] + assert OPTIMIZER_STATE_NAME in callback_state + assert EVALUATOR_STATE_NAME in callback_state + + # Now run the actual SSL classifier off the stored checkpoint args = common_test_args + ["--model=SSLClassifierCIFAR", f"--local_ssl_weights_path={checkpoint_path}"] with mock.patch("sys.argv", args): loaded_config, actual_run = default_runner().run() @@ -93,6 +147,7 @@ def test_innereye_ssl_container_cifar10_resnet_simclr() -> None: assert loaded_config.model.class_weights is None assert loaded_config.model.num_classes == 10 + @pytest.mark.skipif(is_windows(), reason="Too slow on windows") def test_load_innereye_ssl_container_cifar10_cifar100_resnet_byol() -> None: """ @@ -127,8 +182,8 @@ def test_innereye_ssl_container_rsna() -> None: loaded_config, actual_run = runner.run() assert loaded_config is not None assert isinstance(loaded_config.model, BYOLInnerEye) - assert loaded_config.online_eval.dataset == SSLDatasetName.RSNAKaggleCXR.value - assert loaded_config.online_eval.num_classes == 2 + assert loaded_config.online_eval_callback.dataset == SSLDatasetName.RSNAKaggleCXR.value + assert loaded_config.online_eval_callback.num_classes == 2 assert loaded_config.ssl_training_dataset_name == SSLDatasetName.NIHCXR assert loaded_config.ssl_training_type == SSLTrainingType.BYOL assert loaded_config.encoder_output_dim == 1024 # DenseNet output size @@ -145,6 +200,21 @@ def test_innereye_ssl_container_rsna() -> None: 
SSLDataModuleType.ENCODER].augmentation_params.preprocess.center_crop_size == 224 assert loaded_config.datamodule_args[SSLDataModuleType.ENCODER].augmentation_params.augmentation.use_random_crop assert loaded_config.datamodule_args[SSLDataModuleType.ENCODER].augmentation_params.augmentation.use_random_affine + expected_metrics = { + 'epoch_started': 0.0, + 'byol/train/loss': 0.00401744619011879, + 'byol/tau': 0.9899999499320984, + 'ssl_online_evaluator/train/loss': 0.6889733672142029, + 'ssl_online_evaluator/train/online_AreaUnderRocCurve': 0.5, + 'ssl_online_evaluator/train/online_AreaUnderPRCurve': 0.699999988079071, + 'ssl_online_evaluator/train/online_AccuracyAtThreshold05': 0.4000000059604645, + 'byol/val/loss': -0.07644838094711304, + 'ssl_online_evaluator/val/loss': 0.69798344373703, + 'ssl_online_evaluator/val/AreaUnderRocCurve': math.nan, + 'ssl_online_evaluator/val/AreaUnderPRCurve': math.nan, + 'ssl_online_evaluator/val/AccuracyAtThreshold05': 0.0 + } + _compare_stored_metrics(runner, expected_metrics) # Check that we are able to load the checkpoint and create classifier model checkpoint_path = loaded_config.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX @@ -159,3 +229,41 @@ def test_innereye_ssl_container_rsna() -> None: assert loaded_config.model.freeze_encoder assert torch.isclose(loaded_config.model.class_weights, torch.tensor([0.21, 0.79]), atol=1e-6).all() # type: ignore assert loaded_config.model.num_classes == 2 + + +def test_simclr_lr_scheduler() -> None: + """ + Test if the LR scheduler has the expected warmup behaviour. + """ + num_samples = 100 + batch_size = 20 + gpus = 1 + max_epochs = 10 + warmup_epochs = 2 + model = SimCLRInnerEye(encoder_name="resnet18", dataset_name="CIFAR10", + gpus=gpus, num_samples=num_samples, batch_size=batch_size, + max_epochs=max_epochs, warmup_epochs=warmup_epochs) + # The LR scheduler used here works per step. Scheduler computes the total number of steps, in this example that's 5 + train_iters_per_epoch = num_samples / (batch_size * gpus) + assert model.train_iters_per_epoch == train_iters_per_epoch + # Mock a second optimizer that is normally created in the SSL container + linear_head_optimizer = mock.MagicMock() + model.online_eval_optimizer = linear_head_optimizer + # Retrieve the scheduler and iterate it + _, scheduler_list = model.configure_optimizers() + assert isinstance(scheduler_list[0], dict) + assert scheduler_list[0]["interval"] == "step" + scheduler = scheduler_list[0]["scheduler"] + assert isinstance(scheduler, _LRScheduler) + lr = [] + for i in range(0, int(max_epochs * train_iters_per_epoch)): + scheduler.step() + lr.append(scheduler.get_last_lr()[0]) + # The highest learning rate is expected after the warmup epochs + highest_lr = np.argmax(lr) + assert highest_lr == int(warmup_epochs * train_iters_per_epoch - 1) + + for i in range(0, highest_lr): + assert lr[i] < lr[i + 1], f"Not strictly monotonically increasing at index {i}" + for i in range(highest_lr, len(lr) - 1): + assert lr[i] > lr[i + 1], f"Not strictly monotonically decreasing at index {i}" diff --git a/docs/WSL.md b/docs/WSL.md index c0b581ec6..db61458a6 100644 --- a/docs/WSL.md +++ b/docs/WSL.md @@ -73,11 +73,13 @@ find * -name '*.pyc' | xargs -d'\n' rm` - https://www.jetbrains.com/help/pycharm/using-wsl-as-a-remote-interpreter.html - You might need to reset all your firewall settings to make the debugger work with PyCharm. 
  This can be done with these PowerShell commands (as Administrator):
 ```
-Remove-NetFirewallRule
-$myIp = (Ubuntu1804 run "cat /etc/resolv.conf | grep nameserver | cut -d' ' -f2")
+$myIp = (Ubuntu2004 run "cat /etc/resolv.conf | grep nameserver | cut -d' ' -f2")
 New-NetFirewallRule -DisplayName "WSL" -Direction Inbound -LocalAddress $myIp -Action Allow
 ```
 - Then (re)start PyCharm. If asked whether to give it permission to communicate over domain, private and public networks, make sure all three are ticked.
+- If you are still struggling with the firewall rules, consider removing all your current firewall rules by running
+  `Remove-NetFirewallRule` in PowerShell. WARNING: This removes all existing firewall rules, and you may need to
+  repeat the firewall setup for other programs that you have installed!

 ## Configure VSCode
 - https://code.visualstudio.com/docs/remote/wsl
diff --git a/environment.yml b/environment.yml
index b089ad486..c98883695 100644
--- a/environment.yml
+++ b/environment.yml
@@ -23,7 +23,7 @@ dependencies:
   - gitpython==3.1.7
   - gputil==1.4.0
   - h5py==2.10.0
-  - hi-ml-azure>=0.1
+  - hi-ml-azure>=0.1.8
   - InnerEye-DICOM-RT==1.0.1
   - joblib==0.16.0
   - jupyter==1.0.0
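For `test_simclr_lr_scheduler` in `Tests/SSL/test_ssl_containers.py` above: with 100 samples, a batch size of 20 and one GPU there are 5 steps per epoch, so 2 warm-up epochs put the learning-rate peak at step index 9 and training runs for 50 steps in total. The shape the test asserts — strictly increasing up to the peak, strictly decreasing afterwards — can be reproduced with a linear warm-up followed by cosine decay; the sketch below assumes that schedule shape (the exact schedule comes from the pl_bolts SimCLR implementation):

```python
import math
import torch
from torch.optim.lr_scheduler import LambdaLR

num_samples, batch_size, gpus = 100, 20, 1
max_epochs, warmup_epochs, base_lr = 10, 2, 1e-3
iters_per_epoch = num_samples // (batch_size * gpus)   # 5 steps per epoch
warmup_steps = warmup_epochs * iters_per_epoch         # peak after 10 steps
total_steps = max_epochs * iters_per_epoch             # 50 steps overall

def warmup_cosine(step: int) -> float:
    if step < warmup_steps:
        return step / warmup_steps                     # linear warm-up
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))  # cosine decay to zero

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=base_lr)
scheduler = LambdaLR(optimizer, lr_lambda=warmup_cosine)
lrs = []
for _ in range(total_steps):
    scheduler.step()
    lrs.append(scheduler.get_last_lr()[0])
# The peak sits at the last warm-up step, as the test asserts
assert lrs.index(max(lrs)) == warmup_steps - 1
```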