This repository was archived by the owner on Mar 21, 2024. It is now read-only.
Merged
48 commits
cf51128
instructions
ant0nsc Sep 22, 2021
b6201b5
Moving loading time out into a callback
ant0nsc Oct 13, 2021
7b5997a
fixing timing callback
ant0nsc Oct 13, 2021
9d55e20
docu
ant0nsc Oct 13, 2021
4254ff2
docu
ant0nsc Oct 13, 2021
0a1fc26
progress bar
ant0nsc Oct 14, 2021
f190669
docu and cleanup
ant0nsc Oct 14, 2021
f90912d
tests
ant0nsc Oct 15, 2021
9bdbd7a
test cleanup
ant0nsc Oct 18, 2021
d02ba0b
test for timers
ant0nsc Oct 19, 2021
8cbc5f1
cleanup
ant0nsc Oct 19, 2021
144698a
tests for callback
ant0nsc Oct 19, 2021
6674b70
hyperparams logging
ant0nsc Oct 19, 2021
ebf8b25
flags
ant0nsc Oct 19, 2021
fd96667
Merge branch 'antonsc/submodule_doc' into antonsc/diagnostics
ant0nsc Oct 19, 2021
b11f6dd
submodule
ant0nsc Oct 19, 2021
8ac90f2
update all usage
ant0nsc Oct 19, 2021
547626e
fix
ant0nsc Oct 20, 2021
986fba8
cleanup
ant0nsc Oct 20, 2021
a823af0
callback save and load
ant0nsc Nov 2, 2021
17564f1
find_unused
ant0nsc Nov 2, 2021
367662f
Merge remote-tracking branch 'origin/main' into antonsc/diagnostics
ant0nsc Nov 2, 2021
6064bc5
remove submodule
ant0nsc Nov 2, 2021
88ed46c
storinglogger update
ant0nsc Nov 2, 2021
857026c
head_batchsize
ant0nsc Nov 2, 2021
a71a476
using submodule
ant0nsc Nov 2, 2021
69fe247
import fix
ant0nsc Nov 2, 2021
3ef5d5d
log_on_epoch
ant0nsc Nov 3, 2021
40b6d07
cleanup of metrics
ant0nsc Nov 3, 2021
8080544
changelog
ant0nsc Nov 3, 2021
8cbafe7
removing submodule
ant0nsc Nov 3, 2021
70771d3
fix import
ant0nsc Nov 3, 2021
c964d84
changelog
ant0nsc Nov 3, 2021
a432fe2
flake fix
ant0nsc Nov 3, 2021
4ee9e8e
Merge remote-tracking branch 'origin/antonsc/diagnostics' into antons…
ant0nsc Nov 3, 2021
d605335
fixed logging
ant0nsc Nov 3, 2021
e086757
Merge remote-tracking branch 'origin/main' into antonsc/recovery2
ant0nsc Nov 12, 2021
7f75ec3
changelog
ant0nsc Nov 12, 2021
c22362e
mypy
ant0nsc Nov 12, 2021
2da000b
test fix
ant0nsc Nov 13, 2021
16f9793
Merge branch 'main' into antonsc/recovery2
ant0nsc Nov 17, 2021
588dd01
PR comments
ant0nsc Nov 17, 2021
44b8d0e
fix
ant0nsc Nov 17, 2021
ea77b47
fix
ant0nsc Nov 17, 2021
a14b301
Merge remote-tracking branch 'origin/antonsc/pathfix' into antonsc/re…
ant0nsc Nov 18, 2021
3123ed4
Merge remote-tracking branch 'origin/main' into antonsc/recovery2
ant0nsc Nov 18, 2021
5703c1c
fix
ant0nsc Nov 18, 2021
4c4f4eb
PR comments
ant0nsc Nov 18, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -51,6 +51,7 @@ gets uploaded to AzureML, by skipping all test folders.
`ScalarModelBase`.
- ([#554](https://github.com/microsoft/InnerEye-DeepLearning/pull/554)) Updated report in CovidModel. Set parameters
in the config to run inference on both the validation and test sets by default.
- ([#584](https://github.com/microsoft/InnerEye-DeepLearning/pull/584)) SSL models write the optimizer state for the linear head to the checkpoint now.
- ([#566](https://github.com/microsoft/InnerEye-DeepLearning/pull/566)) Update `hi-ml` dependency to `hi-ml-azure`.
- ([#572](https://github.com/microsoft/InnerEye-DeepLearning/pull/572)) Updated to new version of hi-ml package

3 changes: 3 additions & 0 deletions InnerEye/ML/SSL/datamodules_and_datasets/datamodules.py
@@ -136,6 +136,9 @@ def train_dataloader(self, *args: Any, **kwargs: Any) -> Dict[SSLDataModuleType,
"""
The train dataloaders
"""
# This code may be superseded in current versions of PL. Using this dictionary syntax will effectively
# use a CombinedLoader(dataloaders, mode="max_size_cycle"), similar to what we need to do explicitly for
# the validation data loader.
dataloaders = {
SSLDataModuleType.ENCODER: self.encoder_module.train_dataloader(),
SSLDataModuleType.LINEAR_HEAD: self.linear_head_module.train_dataloader()}
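Note on the comment above: returning a dictionary of loaders is roughly equivalent to wrapping them in PyTorch Lightning's CombinedLoader. A hedged sketch of that equivalence, assuming a PL 1.x install where CombinedLoader lives in pytorch_lightning.trainer.supporters (the keys and dataset shapes below are made up for illustration):

```python
import torch
from pytorch_lightning.trainer.supporters import CombinedLoader
from torch.utils.data import DataLoader, TensorDataset

# Two loaders of different length, standing in for the encoder and linear-head modules.
encoder_loader = DataLoader(TensorDataset(torch.randn(100, 3, 32, 32)), batch_size=10)
linear_head_loader = DataLoader(TensorDataset(torch.randn(40, 3, 32, 32)), batch_size=10)

dataloaders = {"encoder": encoder_loader, "linear_head": linear_head_loader}
# "max_size_cycle" cycles the shorter loader so that one pass covers the longest loader.
combined = CombinedLoader(dataloaders, mode="max_size_cycle")

for batch in combined:
    # Each batch is a dict with one entry per loader, mirroring the dict returned above.
    print({key: value[0].shape for key, value in batch.items()})
    break
```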
29 changes: 16 additions & 13 deletions InnerEye/ML/SSL/lightning_containers/ssl_container.py
@@ -75,9 +75,9 @@ class SSLContainer(LightningContainer):
"augmentations. Ignored for CIFAR10 example")
ssl_training_dataset_name = param.ClassSelector(class_=SSLDatasetName, doc="The name of the dataset")
ssl_training_batch_size = param.Integer(
doc="Total training batch size, will be divided across the number of gpus used for training. For example: if "
"you specify ssl_training_batch_size=1600 and use 4 nodes with 4 gpus each (i.e. total of 16 GPUs), "
"the code will provide a per-gpu batch size of 100")
doc="Training batch size per GPU. The effective batch size will be the number of GPUs times this number. "
"For example, if you specify ssl_training_batch_size=100 and use 4 nodes with 4 gpus each, "
"the effective batch size will be 1600.")
ssl_training_type = param.ClassSelector(class_=SSLTrainingType, doc="Which algorithm to use for SSL training")
ssl_encoder = param.ClassSelector(class_=EncoderName, doc="Which encoder to use for SSL")
use_balanced_binary_loss_for_linear_head = param.Boolean(default=False,
@@ -92,14 +92,18 @@ class SSLContainer(LightningContainer):
"augmentations")
linear_head_dataset_name = param.ClassSelector(class_=SSLDatasetName,
doc="Name of the dataset to use for the linear head training")
linear_head_batch_size = param.Integer(default=256, doc="Batch size for linear head tuning")
linear_head_batch_size = param.Integer(default=16, doc="Batch size for linear head tuning")
learning_rate_linear_head_during_ssl_training = param.Number(default=1e-4,
doc="Learning rate for linear head training during "
"SSL training.")
drop_last = param.Boolean(default=True, doc="If True drops the last incomplete batch")

def setup(self) -> None:
from InnerEye.ML.SSL.lightning_containers.ssl_image_classifier import SSLClassifierContainer
if self.is_debug_model:
self.pl_limit_train_batches = 1
self.pl_limit_val_batches = 1
self.pl_find_unused_parameters = True
self.total_num_gpus = self.num_gpus_per_node() * self.num_nodes
self._load_config()
# If you're using the same data for training and linear head, allow the user to specify the dataset only
@@ -199,16 +203,17 @@ def _create_ssl_data_modules(self, is_ssl_encoder_module: bool) -> InnerEyeVisio
train_transforms, val_transforms = self._get_transforms(datamodule_args.augmentation_params,
datamodule_args.dataset_name,
is_ssl_encoder_module)
batch_size_per_gpu = datamodule_args.batch_size // self.total_num_gpus if self.total_num_gpus > 0 else \
datamodule_args.batch_size
logging.info(f"Batch size per gpu: {batch_size_per_gpu}")
batch_multiplier = self.total_num_gpus if self.total_num_gpus > 0 else 1
effective_batch_size = datamodule_args.batch_size * batch_multiplier
logging.info(f"Batch size per GPU: {datamodule_args.batch_size}")
logging.info(f"Effective batch size on {batch_multiplier} GPUs: {effective_batch_size}")
dm = InnerEyeVisionDataModule(dataset_cls=self._SSLDataClassMappings[datamodule_args.dataset_name],
return_index=not is_ssl_encoder_module, # index is only needed for linear head
train_transforms=train_transforms,
val_split=0.1,
val_transforms=val_transforms,
data_dir=str(datamodule_args.dataset_path),
batch_size=batch_size_per_gpu,
batch_size=datamodule_args.batch_size,
num_workers=self.num_workers,
seed=self.random_seed,
drop_last=self.drop_last)
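With this change the configured batch size is interpreted per GPU, so the effective batch size is the per-GPU value times the total number of GPUs. A minimal sketch of that bookkeeping (the function name and the CPU-only fallback are illustrative, not the container's exact code):

```python
def effective_batch_size(batch_size_per_gpu: int, num_gpus_per_node: int, num_nodes: int) -> int:
    """Effective batch size under DDP: every GPU processes its own batch of the configured size."""
    total_num_gpus = num_gpus_per_node * num_nodes
    batch_multiplier = total_num_gpus if total_num_gpus > 0 else 1  # CPU-only runs
    return batch_size_per_gpu * batch_multiplier

# The example from the ssl_training_batch_size docstring: 100 per GPU on 4 nodes with 4 GPUs each.
assert effective_batch_size(100, num_gpus_per_node=4, num_nodes=4) == 1600
```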
@@ -226,9 +231,9 @@ def _get_transforms(self, augmentation_config: Optional[CfgNode],
:param dataset_name: name of the dataset, value has to be in SSLDatasetName, determines which transformation
pipeline to return.
:param is_ssl_encoder_module: if True the transformation pipeline will yield two versions of the image it is
applied on and it applies the training transformations also at validation time. Note that if your transformation
does not contain any randomness, the pipeline will return two identical copies. If False, it will return only one
transformation.
applied on and it applies the training transformations also at validation time. Note that if your
transformation does not contain any randomness, the pipeline will return two identical copies. If False, it
will return only one transformation.
:return: training transformation pipeline and validation transformation pipeline.
"""
if dataset_name in [SSLDatasetName.RSNAKaggleCXR.value,
@@ -269,6 +274,4 @@ def get_trainer_arguments(self) -> Dict[str, Any]:
drop_p=0.2,
learning_rate=self.learning_rate_linear_head_during_ssl_training)
trainer_kwargs: Dict[str, Any] = {"callbacks": self.online_eval}
if self.is_debug_model:
trainer_kwargs.update({"limit_train_batches": 1, "limit_val_batches": 1})
return trainer_kwargs
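The limit_train_batches / limit_val_batches overrides that used to be injected into the trainer kwargs here for debug models are now set in setup() via the pl_limit_train_batches, pl_limit_val_batches and pl_find_unused_parameters attributes. A hedged sketch of the Trainer settings those attributes ultimately translate to (the exact plumbing lives elsewhere in the framework):

```python
from pytorch_lightning import Trainer

# Rough equivalent of a debug-model run: a single train and a single validation batch per epoch.
# pl_find_unused_parameters is assumed to map to find_unused_parameters=True on the DDP
# plugin/strategy, which tolerates model parts that receive no gradient in a given step.
trainer = Trainer(
    limit_train_batches=1,
    limit_val_batches=1,
)
```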
7 changes: 3 additions & 4 deletions InnerEye/ML/SSL/lightning_containers/ssl_image_classifier.py
@@ -64,7 +64,6 @@ def get_data_module(self) -> InnerEyeDataModuleTypes:
return self.data_module

def get_trainer_arguments(self) -> Dict[str, Any]:
trained_kwargs = {}
if self.is_debug_model:
trained_kwargs.update({"limit_train_batches": 1, "limit_val_batches": 1})
return trained_kwargs
# This class inherits from SSLContainer, where the get_trainer_arguments adds the online evaluator callback.
# We don't need that for the classifier, hence need to return an empty set of trainer arguments.
return {}
20 changes: 12 additions & 8 deletions InnerEye/ML/SSL/lightning_modules/byol/byol_module.py
@@ -9,14 +9,15 @@
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from health_ml.utils import log_learning_rate, log_on_epoch
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from pytorch_lightning import Trainer
from torch import Tensor as T
from torch.optim import Adam

from InnerEye.ML.SSL.lightning_modules.byol.byol_models import SiameseArm
from InnerEye.ML.SSL.lightning_modules.byol.byol_moving_average import ByolMovingAverageWeightUpdate
from InnerEye.ML.SSL.utils import SSLDataModuleType
from pytorch_lightning import Trainer

SingleBatchType = Tuple[List, T]
BatchType = Union[Dict[SSLDataModuleType, SingleBatchType], SingleBatchType]
@@ -98,14 +99,15 @@ def shared_step(self, batch: BatchType, batch_idx: int) -> T:

return loss

def training_step(self, batch: BatchType, batch_idx: int, **kwargs: Any) -> T: # type: ignore
def training_step(self, batch: BatchType, batch_idx: int, **kwargs: Any) -> torch.Tensor: # type: ignore
loss = self.shared_step(batch, batch_idx)
self.log_dict({'byol/train/loss': loss, 'byol/tau': self.weight_callback.current_tau})
log_on_epoch(self, metrics={'byol/train/loss': loss, 'byol/tau': self.weight_callback.current_tau})
log_learning_rate(self, name="byol/learning_rate")
return loss

def validation_step(self, batch: BatchType, batch_idx: int, **kwargs: Any) -> T: # type: ignore
loss = self.shared_step(batch, batch_idx)
self.log_dict({'byol/val/loss': loss})
log_on_epoch(self, 'byol/val/loss', loss)
return loss

def setup(self, *args: Any, **kwargs: Any) -> None:
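The byol/tau value logged above is the exponential-moving-average coefficient that ByolMovingAverageWeightUpdate uses to pull the target network towards the online network. A simplified sketch of that update rule (the real callback also anneals tau over training; this is not its exact code):

```python
import torch

@torch.no_grad()
def update_target_network(online: torch.nn.Module, target: torch.nn.Module, tau: float) -> None:
    # Exponential moving average: the target weights trail the online weights.
    for online_p, target_p in zip(online.parameters(), target.parameters()):
        target_p.mul_(tau).add_(online_p, alpha=1.0 - tau)
```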
@@ -116,9 +118,12 @@ def configure_optimizers(self) -> Any:
# exclude certain parameters
parameters = self.exclude_from_wt_decay(self.online_network.named_parameters(),
weight_decay=self.hparams.weight_decay) # type: ignore
optimizer = Adam(parameters, lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay) # type: ignore
scheduler = LinearWarmupCosineAnnealingLR(
optimizer, warmup_epochs=self.hparams.warmup_epochs, max_epochs=self.hparams.max_epochs) # type: ignore
optimizer = Adam(parameters,
lr=self.hparams.learning_rate, # type: ignore
weight_decay=self.hparams.weight_decay) # type: ignore
scheduler = LinearWarmupCosineAnnealingLR(optimizer,
warmup_epochs=self.hparams.warmup_epochs, # type: ignore
max_epochs=self.hparams.max_epochs) # type: ignore
return [optimizer], [scheduler]

def exclude_from_wt_decay(self,
Expand All @@ -144,4 +149,3 @@ def exclude_from_wt_decay(self,
{'params': params, 'weight_decay': weight_decay},
{'params': excluded_params, 'weight_decay': 0.}
]

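exclude_from_wt_decay hands Adam two parameter groups, so that bias and normalisation parameters are trained without weight decay while everything else keeps the configured value. A self-contained sketch of that pattern (the skip list and the toy model are placeholders, not the module's actual configuration):

```python
import torch
from torch.optim import Adam

def split_params_for_weight_decay(named_parameters, weight_decay, skip_list=("bias", "bn")):
    decay, no_decay = [], []
    for name, param in named_parameters:
        if not param.requires_grad:
            continue
        if any(skip in name for skip in skip_list):
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.BatchNorm1d(8))
optimizer = Adam(split_params_for_weight_decay(model.named_parameters(), weight_decay=1e-6), lr=1e-4)
```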
20 changes: 14 additions & 6 deletions InnerEye/ML/SSL/lightning_modules/simclr_module.py
@@ -8,13 +8,13 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from health_ml.utils import log_learning_rate, log_on_epoch
from pl_bolts.models.self_supervised.simclr.simclr_module import SimCLR
from torch import Tensor as T

from InnerEye.ML.SSL.encoders import SSLEncoder
from InnerEye.ML.SSL.utils import SSLDataModuleType

SingleBatchType = Tuple[List, T]
SingleBatchType = Tuple[List, torch.Tensor]
BatchType = Union[Dict[SSLDataModuleType, SingleBatchType], SingleBatchType]


@@ -57,7 +57,18 @@ def __init__(self, encoder_name: str, dataset_name: str, use_7x7_first_conv_in_r
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.encoder(x)

def shared_step(self, batch: BatchType) -> T:
def training_step(self, batch: BatchType, batch_idx: int) -> torch.Tensor:
loss = self.shared_step(batch)
log_on_epoch(self, "simclr/train/loss", loss, sync_dist=False)
log_learning_rate(self, name="simclr/learning_rate")
return loss

def validation_step(self, batch: BatchType, batch_idx: int) -> torch.Tensor: # type: ignore
loss = self.shared_step(batch)
log_on_epoch(self, "simclr/val/loss", loss, sync_dist=False)
return loss

def shared_step(self, batch: BatchType) -> torch.Tensor:
batch = batch[SSLDataModuleType.ENCODER] if isinstance(batch, dict) else batch

(img1, img2), y = batch
Expand All @@ -72,6 +83,3 @@ def shared_step(self, batch: BatchType) -> T:
loss = self.nt_xent_loss(z1, z2, self.temperature)

return loss



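shared_step delegates the contrastive objective to self.nt_xent_loss from the pl_bolts SimCLR base class. For orientation, a simplified single-process sketch of an NT-Xent loss (illustrative only, not the pl_bolts implementation, which also handles distributed negatives):

```python
import torch
import torch.nn.functional as F

def nt_xent_loss(z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    """Normalized temperature-scaled cross entropy over two augmented views of the same batch."""
    n = z1.shape[0]
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)   # (2N, d) unit-norm projections
    sim = z @ z.t() / temperature                         # pairwise cosine similarities
    sim.fill_diagonal_(float("-inf"))                     # a view is never its own negative
    # The positive for sample i in the first half is sample i in the second half, and vice versa.
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)])
    return F.cross_entropy(sim, targets)

loss = nt_xent_loss(torch.randn(4, 16), torch.randn(4, 16))
```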
9 changes: 5 additions & 4 deletions InnerEye/ML/SSL/lightning_modules/ssl_classifier_module.py
@@ -7,6 +7,7 @@
import torch
from torchmetrics import Metric
from pl_bolts.models.self_supervised import SSLEvaluator
from health_ml.utils import log_on_epoch
from torch.nn import functional as F

from InnerEye.ML.SSL.encoders import get_encoder_output_dim
@@ -79,16 +80,16 @@ def shared_step(self, batch: Any, is_training: bool) -> Any:

def training_step(self, batch: Any, batch_id: int, *args: Any, **kwargs: Any) -> Any: # type: ignore
loss = self.shared_step(batch, True)
self.log("train/loss", loss, on_step=False, on_epoch=True)
log_on_epoch(self, "train/loss", loss)
for metric in self.train_metrics:
self.log(f"train/{metric.name}", metric, on_epoch=True, on_step=False)
log_on_epoch(self, f"train/{metric.name}", metric)
return loss

def validation_step(self, batch: Any, batch_id: int, *args: Any, **kwargs: Any) -> None: # type: ignore
loss = self.shared_step(batch, is_training=False)
self.log('val/loss', loss, on_step=False, on_epoch=True, sync_dist=True)
log_on_epoch(self, 'val/loss', loss)
for metric in self.val_metrics:
self.log(f"val/{metric.name}", metric, on_epoch=True, on_step=False)
log_on_epoch(self, f"val/{metric.name}", metric)

def get_input_tensors(self, item: ScalarItem) -> List[torch.Tensor]:
"""
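The replaced self.log calls above show what log_on_epoch from health_ml.utils stands in for: epoch-level aggregation with DDP-aware syncing. A hedged sketch of a roughly equivalent helper (the actual hi-ml implementation may differ in signature and defaults):

```python
import torch
from pytorch_lightning import LightningModule

def log_on_epoch_sketch(module: LightningModule, name: str, value: torch.Tensor) -> None:
    # Epoch-level logging, synced across ranks when more than one process is running.
    sync_dist = torch.distributed.is_available() and torch.distributed.is_initialized()
    module.log(name, value, on_step=False, on_epoch=True, sync_dist=sync_dist)
```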