Fix recovery of SSL training, scale SSL training to multiple nodes #565
Conversation
This reverts commit e03f0a9.
@@ -1,8 +1,7 @@
 # ------------------------------------------------------------------------------------------
 # Copyright (c) Microsoft Corporation. All rights reserved.
I think it would be good to attach SSL training run results for future reference - before and after this manual optimisation change (both for SimCLR and BYOL).
 super().__init__(ssl_training_dataset_name=SSLDatasetName.CIFAR10,
                  linear_head_dataset_name=SSLDatasetName.CIFAR10,
-                 ssl_training_batch_size=512,
+                 ssl_training_batch_size=64,
Weren't we training these models on machines with 4 GPUs? Should this be reduced to 128?
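For reference, a hypothetical worked example of the arithmetic behind that suggestion; the numbers come from the diff and the comment above, not from the PR code:

```python
# If 512 was meant as the effective batch size and training runs on 4 GPUs,
# the per-GPU batch size that preserves it would be 512 / 4 = 128.
effective_batch_size = 512
num_gpus = 4
per_gpu_batch_size = effective_batch_size // num_gpus
print(per_gpu_batch_size)  # 128
```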
 recovery_checkpoint_save_interval=200,
 num_epochs=1000,
-ssl_training_batch_size=1200,
+ssl_training_batch_size=75,
Same here.
 param.Boolean(default=False,
               doc="Controls the PyTorch Lightning flag 'find_unused_parameters' for the DDP plugin. "
                   "Setting it to True comes with a performance hit.")
 pl_limit_train_batches: Optional[int] = \
What happens if the user specifies zero? Would it skip training/validation automatically, and is that tested?
0 is valid. We pass this value straight through to PL, which then behaves as it normally would - yes, it would skip training/validation.
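For illustration, a minimal sketch of that pass-through, assuming a hypothetical create_trainer helper and config field names; only Trainer's limit_train_batches / limit_val_batches / max_epochs are actual PyTorch Lightning arguments:

```python
from typing import Any, Dict
from pytorch_lightning import Trainer

def create_trainer(config) -> Trainer:
    # Hypothetical helper: forward the optional limits unchanged to PL.
    kwargs: Dict[str, Any] = {}
    if config.pl_limit_train_batches is not None:
        # 0 is forwarded as-is; PL then runs zero training batches per epoch.
        kwargs["limit_train_batches"] = config.pl_limit_train_batches
    if config.pl_limit_val_batches is not None:
        kwargs["limit_val_batches"] = config.pl_limit_val_batches
    return Trainer(max_epochs=config.num_epochs, **kwargs)
```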
Superseded by #584.
When SSL training gets interrupted on low-priority nodes, the metrics for the linear head currently show odd glitches after the run resumes. We suspect that these come from the fact that the optimizer for the linear head is not saved to the checkpoint, and hence has to re-learn all of its statistics.
This PR makes the linear head optimizer accessible to PL, so that it is included in the checkpoint.
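A minimal sketch of the idea, assuming hypothetical module and attribute names (SSLModule, encoder, linear_head); the key point is that any optimizer returned from configure_optimizers has its state saved into PL checkpoints:

```python
from typing import List
import torch
from pytorch_lightning import LightningModule

class SSLModule(LightningModule):
    # Hypothetical module: encoder and linear_head are assumed submodules.
    def configure_optimizers(self) -> List[torch.optim.Optimizer]:
        ssl_optimizer = torch.optim.Adam(self.encoder.parameters(), lr=1e-3)
        # Returning the linear head optimizer here means PL stores its state_dict
        # in the checkpoint, so it no longer has to re-learn its statistics after
        # a low-priority interruption and resume.
        linear_head_optimizer = torch.optim.Adam(self.linear_head.parameters(), lr=1e-4)
        return [ssl_optimizer, linear_head_optimizer]
```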
In addition, the semantics of `batch_size` in SSL training have changed: previously it was the effective batch size, taking multiple nodes into account, which meant the code was effectively hardcoding 16 GPUs. New behaviour: the batch size is now the batch size on a single GPU. As a consequence, we can scale to any number of GPUs without code changes.
There are also new flags `pl_limit_train_batches` and `pl_limit_val_batches` to speed up training by reducing the number of batches processed.
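A small sketch of the new batch-size semantics, with assumed node and GPU counts; the effective batch size is now derived from the hardware instead of being hardcoded:

```python
# Per-GPU batch size as now specified in the config (e.g. the CIFAR10 value above).
per_gpu_batch_size = 64
num_gpus_per_node = 4
num_nodes = 2
# Effective batch size scales automatically with the number of GPUs and nodes.
effective_batch_size = per_gpu_batch_size * num_gpus_per_node * num_nodes
print(effective_batch_size)  # 512
```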