This repository was archived by the owner on Mar 21, 2024. It is now read-only.

Conversation

ant0nsc (Contributor) commented Sep 20, 2021

When SSL training is interrupted on low-priority nodes, the metrics for the linear head currently show strange glitches after the job resumes. We suspect that these come from the fact that the optimizer for the linear head is not saved to the checkpoint, and hence has to re-learn all of its statistics after recovery.
This PR exposes the linear head optimizer to PL, so that its state is included in the checkpoint.
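
As a rough illustration of the idea (the class and parameter names below are made up for this sketch and are not the actual InnerEye code): any optimizer returned from a LightningModule's configure_optimizers is tracked by PL, so its state (e.g. Adam moments) is written to and restored from checkpoints along with the main SSL optimizer.

```python
import torch
from pytorch_lightning import LightningModule


class SSLModuleSketch(LightningModule):
    """Illustrative only: an SSL encoder plus a linear evaluation head."""

    def __init__(self, encoder: torch.nn.Module, linear_head: torch.nn.Module) -> None:
        super().__init__()
        self.encoder = encoder
        self.linear_head = linear_head

    def configure_optimizers(self):
        # Both optimizers are returned to Lightning, so both optimizer states
        # are saved to and restored from checkpoints, not just the SSL one.
        ssl_optimizer = torch.optim.Adam(self.encoder.parameters(), lr=1e-3)
        linear_head_optimizer = torch.optim.Adam(self.linear_head.parameters(), lr=1e-4)
        return [ssl_optimizer, linear_head_optimizer]
```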

In addition, the semantics of batch_size in SSL training have changed: previously it was the effective batch size, taking multiple nodes into account, which meant the code was effectively hardcoding 16 GPUs. New behaviour: batch size is now the batch size on a single GPU. As a consequence, we can scale to any number of GPUs without code changes.
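
To illustrate the new semantics with made-up numbers: the effective (global) batch size is now the per-GPU value multiplied by the number of GPUs per node and the number of nodes.

```python
# Hypothetical numbers, purely to illustrate the new per-GPU semantics.
batch_size_per_gpu = 64   # what the config now specifies
gpus_per_node = 4
num_nodes = 2

effective_batch_size = batch_size_per_gpu * gpus_per_node * num_nodes
print(effective_batch_size)  # 512 samples contribute to each optimizer step
```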

There are also new flags pl_limit_train_batches and pl_limit_val_batches to speed up training by reducing the number of batches processed per epoch.
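
These flags are forwarded to the corresponding PyTorch Lightning Trainer arguments. A minimal sketch, assuming a plain Trainer construction (the real argument plumbing in the codebase is more involved):

```python
from pytorch_lightning import Trainer


def build_trainer(pl_limit_train_batches: int = 10,
                  pl_limit_val_batches: int = 5) -> Trainer:
    # Lightning caps the number of batches per epoch at these values.
    return Trainer(max_epochs=1,
                   limit_train_batches=pl_limit_train_batches,
                   limit_val_batches=pl_limit_val_batches)
```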

ant0nsc changed the title from "Fix recovery of SSL training" to "Fix recovery of SSL training, scale SSL training to multiple nodes" on Oct 6, 2021
ant0nsc requested a review from Shruthi42 on October 6, 2021 14:41
@@ -1,8 +1,7 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
ozan-oktay (Contributor) commented Oct 6, 2021

I think it would be good to attach SSL training run results for future reference, before and after this change (for both SimCLR and BYOL).

super().__init__(ssl_training_dataset_name=SSLDatasetName.CIFAR10,
                 linear_head_dataset_name=SSLDatasetName.CIFAR10,
-                ssl_training_batch_size=512,
+                ssl_training_batch_size=64,
Contributor

Weren't we training these models on machines with 4 GPUs? Should this be 128 instead?

recovery_checkpoint_save_interval=200,
num_epochs=1000,
-ssl_training_batch_size=1200,
+ssl_training_batch_size=75,
Contributor

Same question here.

param.Boolean(default=False,
doc="Controls the PyTorch Lightning flag 'find_unused_parameters' for the DDP plugin. "
"Setting it to True comes with a performance hit.")
pl_limit_train_batches: Optional[int] = \
Contributor

What happens if the user specifies zero? Would it skip training/validation automatically, and is that tested?

ant0nsc (Contributor, Author)

0 is valid. We pass this value straight through to PL, and it does whatever it would normally do: yes, it would skip training/validation.
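
For example (a quick sketch, not part of this PR): constructing the Trainer with a limit of 0 should result in no training or validation batches being run.

```python
from pytorch_lightning import Trainer

# With limits of 0, Lightning runs zero training/validation batches,
# so the corresponding loops are effectively skipped.
trainer = Trainer(max_epochs=1, limit_train_batches=0, limit_val_batches=0)
```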

ant0nsc (Contributor, Author) commented Nov 15, 2021

Superseded by #584.

ant0nsc closed this on Nov 15, 2021
ant0nsc deleted the antonsc/recovery branch on January 31, 2022