This repository was archived by the owner on Mar 21, 2024. It is now read-only.

Commit 283fec8

Merge branch 'main' into 690-ignore-sphinx-build-folders
2 parents 6e43d69 + 6791dce commit 283fec8

2 files changed, with 12 additions and 1 deletion
CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ created.
 
 ### Added
 
+- ([#667](https://github.com/microsoft/InnerEye-DeepLearning/pull/667)) Automatically and linearly scale the learning rate of the SSL encoder to the number of GPUs.
 - ([#689](https://github.com/microsoft/InnerEye-DeepLearning/pull/689)) Show default argument values in help message.
 - ([#671](https://github.com/microsoft/InnerEye-DeepLearning/pull/671)) Remove sequence models and unused variables. Simplify README.
 - ([#693](https://github.com/microsoft/InnerEye-DeepLearning/pull/693)) Improve instructions for HelloWorld model in AzureML.
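
The changelog entry for #667 refers to the linear scaling rule. As a rough sketch, assuming each GPU keeps the same per-device batch size so that the effective batch size grows with the GPU count (an assumption, not something this commit states), the rule can be written as follows. The symbols $\eta$, $B$, $N_{\text{nodes}}$ and $G$ are notation chosen here for illustration.

```latex
\begin{align}
B_{\text{eff}} &= B \cdot N_{\text{nodes}} \cdot G, \\
\eta_{\text{scaled}} &= \eta \cdot \frac{B_{\text{eff}}}{B} = \eta \cdot N_{\text{nodes}} \cdot G .
\end{align}
```

For example, with hypothetical values $\eta = 10^{-3}$, $N_{\text{nodes}} = 2$ and $G = 4$, the scaled rate would be $8 \times 10^{-3}$.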

InnerEye/ML/SSL/lightning_containers/ssl_container.py

Lines changed: 11 additions & 1 deletion
@@ -147,13 +147,23 @@ def create_model(self) -> LightningModule:
         # For small images like CIFAR, if using a resnet encoder, switch the first conv layer to a 3x3 kernel instead
         # of a 7x7 conv layer.
         use_7x7_first_conv_in_resnet = False if self.ssl_training_dataset_name.value.startswith("CIFAR") else True
+
+        # Rescale the learning rate linearly according to the number of available GPUs, as seen in: https://arxiv.org/abs/1706.02677,
+        # to avoid a drop in performance
+        gpus_per_node = self.num_gpus_per_node()
+        num_of_total_gpus = self.num_nodes * gpus_per_node
+        if num_of_total_gpus > 1:
+            l_rate: float = self.l_rate * num_of_total_gpus
+            logging.info(f"We found {num_of_total_gpus} GPUs, SSL encoder learning rate has been adjusted from {self.l_rate} to {l_rate}")
+            self.l_rate = l_rate
+
         if self.ssl_training_type == SSLTrainingType.SimCLR:
             model: LightningModule = SimCLRInnerEye(encoder_name=self.ssl_encoder.value,
                                                     dataset_name=self.ssl_training_dataset_name.value,
                                                     use_7x7_first_conv_in_resnet=use_7x7_first_conv_in_resnet,
                                                     num_samples=self.data_module.num_train_samples,
                                                     batch_size=self.data_module.batch_size,
-                                                    gpus=self.num_gpus_per_node(),
+                                                    gpus=gpus_per_node,
                                                     num_nodes=self.num_nodes,
                                                     learning_rate=self.l_rate,
                                                     max_epochs=self.num_epochs)
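
The scaling logic added to `create_model` can be sketched as a standalone function. This is a minimal illustration, not the repository's code: `scale_learning_rate` and its parameters are hypothetical names, and returning the scaled value instead of overwriting an attribute is a choice made for this sketch.

```python
import logging


def scale_learning_rate(base_lr: float, num_nodes: int, gpus_per_node: int) -> float:
    """Return the learning rate scaled linearly by the total number of GPUs.

    Hypothetical helper for illustration; it mirrors the condition in the diff
    (scale only when more than one GPU is available) but is not part of InnerEye.
    """
    total_gpus = num_nodes * gpus_per_node
    if total_gpus <= 1:
        # CPU-only runs (0 GPUs) and single-GPU runs keep the base learning rate.
        return base_lr
    scaled_lr = base_lr * total_gpus
    logging.info("Found %d GPUs, learning rate adjusted from %s to %s",
                 total_gpus, base_lr, scaled_lr)
    return scaled_lr


if __name__ == "__main__":
    # Hypothetical setup: 2 nodes with 4 GPUs each multiplies the base rate by 8.
    print(scale_learning_rate(base_lr=1e-3, num_nodes=2, gpus_per_node=4))
```

Because the function returns a new value and leaves its input untouched, calling it repeatedly with the same base rate gives the same result each time.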
