10 changes: 6 additions & 4 deletions CHANGELOG.md
@@ -31,10 +31,11 @@ jobs that run in AzureML.
- ([#559](https://github.com/microsoft/InnerEye-DeepLearning/pull/559)) Adding the accompanying code for the ["Active label cleaning: Improving dataset quality under resource constraints"](https://arxiv.org/abs/2109.00574) paper. The code can be found in the [InnerEye-DataQuality](InnerEye-DataQuality/README.md) subfolder. It provides tools for training noise robust models, running label cleaning simulation and loading our label cleaning benchmark datasets.
- ([#589](https://github.com/microsoft/InnerEye-DeepLearning/pull/589)) Add `LightningContainer.update_azure_config()`
hook to enable overriding `AzureConfig` parameters from a container (e.g. `experiment_name`, `cluster`, `num_nodes`).
-([#603](https://github.com/microsoft/InnerEye-DeepLearning/pull/603)) Add histopathology module
-([#614](https://github.com/microsoft/InnerEye-DeepLearning/pull/614)) Checkpoint downloading falls back to looking into AzureML if no checkpoints on disk
-([#613](https://github.com/microsoft/InnerEye-DeepLearning/pull/613)) Add additional tests for histopathology datasets
-([#616](https://github.com/microsoft/InnerEye-DeepLearning/pull/616)) Add more histopathology configs and tests
- ([#617](https://github.com/microsoft/InnerEye-DeepLearning/pull/617)) Command-line flag `pl_check_val_every_n_epoch` to control how often validation runs
- ([#603](https://github.com/microsoft/InnerEye-DeepLearning/pull/603)) Add histopathology module
- ([#614](https://github.com/microsoft/InnerEye-DeepLearning/pull/614)) Checkpoint downloading falls back to looking into AzureML if no checkpoints on disk
- ([#613](https://github.com/microsoft/InnerEye-DeepLearning/pull/613)) Add additional tests for histopathology datasets
- ([#616](https://github.com/microsoft/InnerEye-DeepLearning/pull/616)) Add more histopathology configs and tests

### Changed
- ([#588](https://github.com/microsoft/InnerEye-DeepLearning/pull/588)) Replace SciPy with PIL.PngImagePlugin.PngImageFile to load png files.
@@ -61,6 +62,7 @@ gets uploaded to AzureML, by skipping all test folders.
- ([#594](https://github.com/microsoft/InnerEye-DeepLearning/pull/594)) Pytorch is now non-deterministic by default. Upgrade to AzureML-SDK 1.36
- ([#566](https://github.com/microsoft/InnerEye-DeepLearning/pull/566)) Update `hi-ml` dependency to `hi-ml-azure`.
- ([#572](https://github.com/microsoft/InnerEye-DeepLearning/pull/572)) Updated to new version of hi-ml package
- ([#617](https://github.com/microsoft/InnerEye-DeepLearning/pull/617)) Provide an easier way for LightningContainers to add callbacks.
- ([#596](https://github.com/microsoft/InnerEye-DeepLearning/pull/596)) Add `cudatoolkit=11.1` specification to environment.yml.
- ([#615](https://github.com/microsoft/InnerEye-DeepLearning/pull/615)) Minor changes to checkpoint download from AzureML.
- ([#605](https://github.com/microsoft/InnerEye-DeepLearning/pull/605)) Make build jobs deterministic for regression testing.
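The `pl_check_val_every_n_epoch` flag added in [#617](https://github.com/microsoft/InnerEye-DeepLearning/pull/617) is a regular `TrainerParams` field, so a model configuration can set it directly. A minimal sketch, assuming a hypothetical container that inherits the `TrainerParams` fields as InnerEye's `LightningContainer` does:

```python
from InnerEye.ML.lightning_container import LightningContainer


class SparseValidationContainer(LightningContainer):  # hypothetical example
    def __init__(self) -> None:
        super().__init__()
        # Forwarded to Trainer(check_val_every_n_epoch=...); validation now
        # runs only after every 5th training epoch.
        self.pl_check_val_every_n_epoch = 5
```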
9 changes: 4 additions & 5 deletions InnerEye/ML/SSL/lightning_containers/ssl_container.py
@@ -6,10 +6,10 @@
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, List, Optional, Tuple, Union

import param
from pytorch_lightning import LightningModule
from pytorch_lightning import Callback, LightningModule
from yacs.config import CfgNode

from InnerEye.ML.SSL.datamodules_and_datasets.cifar_datasets import InnerEyeCIFAR10, InnerEyeCIFAR100
@@ -266,12 +266,11 @@ def _get_transforms(self, augmentation_config: Optional[CfgNode],

return train_transforms, val_transforms

def get_trainer_arguments(self) -> Dict[str, Any]:
def get_callbacks(self) -> List[Callback]:
self.online_eval = SSLOnlineEvaluatorInnerEye(class_weights=self.data_module.class_weights, # type: ignore
z_dim=self.encoder_output_dim,
num_classes=self.data_module.num_classes, # type: ignore
dataset=self.linear_head_dataset_name.value, # type: ignore
drop_p=0.2,
learning_rate=self.learning_rate_linear_head_during_ssl_training)
trainer_kwargs: Dict[str, Any] = {"callbacks": self.online_eval}
return trainer_kwargs
return [self.online_eval]
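The change above is the pattern the new hook encourages: instead of packing the online evaluator into an untyped trainer-arguments dict, the container returns it in a typed list. A simplified before/after sketch (the `online_eval` attribute stands in for the `SSLOnlineEvaluatorInnerEye` instance created above):

```python
from typing import Any, Dict, List

from pytorch_lightning import Callback


class OldStyleContainer:
    online_eval: Callback

    def get_trainer_arguments(self) -> Dict[str, Any]:
        # Legacy pattern: the callback travels inside an untyped kwargs dict.
        return {"callbacks": self.online_eval}


class NewStyleContainer:
    online_eval: Callback

    def get_callbacks(self) -> List[Callback]:
        # New pattern: a typed list that the trainer merges with its own callbacks.
        return [self.online_eval]
```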
11 changes: 6 additions & 5 deletions InnerEye/ML/SSL/lightning_containers/ssl_image_classifier.py
@@ -4,9 +4,10 @@
# ------------------------------------------------------------------------------------------
import logging
from pathlib import Path
from typing import Any, Dict
from typing import List

import param
from pytorch_lightning import Callback

from InnerEye.ML.SSL.datamodules_and_datasets.datamodules import InnerEyeVisionDataModule
from InnerEye.ML.SSL.lightning_containers.ssl_container import InnerEyeDataModuleTypes, SSLContainer
@@ -63,7 +64,7 @@ def get_data_module(self) -> InnerEyeDataModuleTypes:
self.data_module.class_weights = self.data_module.compute_class_weights()
return self.data_module

def get_trainer_arguments(self) -> Dict[str, Any]:
# This class inherits from SSLContainer, where the get_trainer_arguments adds the online evaluator callback.
# We don't need that for the classifier, hence need to return an empty set of trainer arguments.
return {}
def get_callbacks(self) -> List[Callback]:
# This class inherits from SSLContainer, where the get_callbacks method adds the online evaluator.
# We don't need that for the classifier.
return []
3 changes: 3 additions & 0 deletions InnerEye/ML/deep_learning_config.py
@@ -600,6 +600,9 @@ class TrainerParams(param.Parameterized):
param.String(default=None,
doc="The value to use for the 'profiler' argument for the Lightning trainer. "
"Set to either 'simple', 'advanced', or 'pytorch'")
pl_check_val_every_n_epoch: int = \
param.Integer(default=1,
doc="PyTorch Lightning trainer flag 'check_val_every_n_epoch': Run validation every N epochs.")
monitor_gpu: bool = param.Boolean(default=False,
doc="If True, add the GPUStatsMonitor callback to the Lightning trainer object. "
"This will write GPU utilization metrics every 50 batches by default.")
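The new parameter maps one-to-one onto the PyTorch Lightning `Trainer` flag of the same name; `model_training.py` below forwards it verbatim. A minimal illustration of the underlying Lightning behaviour:

```python
from pytorch_lightning import Trainer

# Runs the validation loop only after every 5th training epoch; in InnerEye
# the value would come from container.pl_check_val_every_n_epoch.
trainer = Trainer(max_epochs=20, check_val_every_n_epoch=5)
```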
8 changes: 7 additions & 1 deletion InnerEye/ML/lightning_container.py
@@ -8,7 +8,7 @@

import param
import torch
from pytorch_lightning import LightningDataModule, LightningModule
from pytorch_lightning import Callback, LightningDataModule, LightningModule
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from azureml.core import ScriptRunConfig
@@ -208,6 +208,12 @@ def get_trainer_arguments(self) -> Dict[str, Any]:
"""
return dict()

def get_callbacks(self) -> List[Callback]:
"""
Gets additional callbacks that the trainer should use when training this model.
"""
return []

def get_parameter_search_hyperdrive_config(self, _: ScriptRunConfig) -> HyperDriveConfig: # type: ignore
"""
Parameter search is not implemented. It should be implemented in a sub class if needed.
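A minimal sketch of how a downstream container could use the new hook (the container class is hypothetical; `LearningRateMonitor` is a standard PyTorch Lightning callback):

```python
from typing import List

from pytorch_lightning import Callback
from pytorch_lightning.callbacks import LearningRateMonitor

from InnerEye.ML.lightning_container import LightningContainer


class LRLoggingContainer(LightningContainer):  # hypothetical example container
    def get_callbacks(self) -> List[Callback]:
        # Log the learning rate once per epoch, on top of the trainer's defaults.
        return [LearningRateMonitor(logging_interval="epoch")]
```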
23 changes: 12 additions & 11 deletions InnerEye/ML/model_training.py
@@ -6,10 +6,8 @@
import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, TypeVar
from typing import Any, List, Optional, Tuple, TypeVar

from health_azure.utils import is_global_rank_zero, is_local_rank_zero
from health_ml.utils import AzureMLLogger, AzureMLProgressBar, BatchTimeCallback, log_on_epoch
from pytorch_lightning import Callback, LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import GPUStatsMonitor, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
@@ -26,6 +24,8 @@
from InnerEye.ML.lightning_loggers import StoringLogger
from InnerEye.ML.lightning_models import SUBJECT_OUTPUT_PER_RANK_PREFIX, ScalarLightning, \
get_subject_output_file_per_rank
from health_azure.utils import is_global_rank_zero, is_local_rank_zero
from health_ml.utils import AzureMLLogger, AzureMLProgressBar, BatchTimeCallback, log_on_epoch

TEMP_PREFIX = "temp/"

@@ -78,8 +78,7 @@ def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule, unuse

def create_lightning_trainer(container: LightningContainer,
resume_from_checkpoint: Optional[Path] = None,
num_nodes: int = 1,
**kwargs: Dict[str, Any]) -> \
num_nodes: int = 1) -> \
Tuple[Trainer, StoringLogger]:
"""
Creates a Pytorch Lightning Trainer object for the given model configuration. It creates checkpoint handlers
Expand All @@ -88,7 +87,6 @@ def create_lightning_trainer(container: LightningContainer,
:param container: The container with model and data.
:param resume_from_checkpoint: If provided, training resumes from this checkpoint point.
:param num_nodes: The number of nodes to use in distributed training.
:param kwargs: Any additional keyword arguments will be passed to the constructor of Trainer.
:return: A tuple [Trainer object, diagnostic logger]
"""
logging.debug(f"resume_from_checkpoint: {resume_from_checkpoint}")
@@ -141,12 +139,15 @@
logging.info("Adding monitoring for GPU utilization")
callbacks.append(GPUStatsMonitor(intra_step_time=True, inter_step_time=True))
# Add the additional callbacks that were specified in get_trainer_arguments for LightningContainers
if "callbacks" in kwargs:
more_callbacks = kwargs.pop("callbacks")
additional_args = container.get_trainer_arguments()
# Callbacks can be specified via the "callbacks" argument (the legacy behaviour) or the new get_callbacks method
if "callbacks" in additional_args:
more_callbacks = additional_args.pop("callbacks")
if isinstance(more_callbacks, list):
callbacks.extend(more_callbacks) # type: ignore
else:
callbacks.append(more_callbacks) # type: ignore
callbacks.extend(container.get_callbacks())
is_azureml_run = not is_offline_run_context(RUN_CONTEXT)
progress_bar_refresh_rate = container.pl_progress_bar_refresh_rate
if is_azureml_run:
@@ -168,6 +169,7 @@
limit_train_batches=container.pl_limit_train_batches or 1.0,
limit_val_batches=container.pl_limit_val_batches or 1.0,
num_sanity_val_steps=container.pl_num_sanity_val_steps,
check_val_every_n_epoch=container.pl_check_val_every_n_epoch,
callbacks=callbacks,
logger=loggers,
progress_bar_refresh_rate=progress_bar_refresh_rate,
Expand All @@ -178,7 +180,7 @@ def create_lightning_trainer(container: LightningContainer,
terminate_on_nan=container.detect_anomaly,
profiler=container.pl_profiler,
resume_from_checkpoint=str(resume_from_checkpoint) if resume_from_checkpoint else None,
**kwargs)
**additional_args)
return trainer, storing_logger


@@ -242,8 +244,7 @@ def model_train(checkpoint_path: Optional[Path],
seed_everything(container.get_effective_random_seed())
trainer, storing_logger = create_lightning_trainer(container,
checkpoint_path,
num_nodes=num_nodes,
**container.get_trainer_arguments())
num_nodes=num_nodes)
rank_info = ", ".join(f"{env}: {os.getenv(env)}"
for env in [ENV_GLOBAL_RANK, ENV_LOCAL_RANK, ENV_NODE_RANK])
logging.info(f"Environment variables: {rank_info}. trainer.global_rank: {trainer.global_rank}")
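For clarity, the callback handling that this file now implements can be summarized in isolation: callbacks may still arrive via the legacy `"callbacks"` entry of `get_trainer_arguments()` (as a single callback or a list of them), and additionally via the new `get_callbacks()` hook; both sources end up on the trainer. A standalone sketch of that merge (the helper name is illustrative, not part of the codebase):

```python
from typing import Any, Dict, List

from pytorch_lightning import Callback


def merge_callbacks(base: List[Callback],
                    trainer_args: Dict[str, Any],
                    from_hook: List[Callback]) -> List[Callback]:
    callbacks = list(base)
    # Legacy path: "callbacks" may hold a single callback or a list of them.
    legacy = trainer_args.pop("callbacks", None)
    if isinstance(legacy, list):
        callbacks.extend(legacy)
    elif legacy is not None:
        callbacks.append(legacy)
    # New path: the typed list returned by LightningContainer.get_callbacks().
    callbacks.extend(from_hook)
    return callbacks
```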