Skips DDP parameter sync #4301
Changes from all commits: 5b03f1d, 43e6a14, eccb8e3, b0ab9a7, 30eed75, 338a9a7, 59a5e5a, 0149d5c, 36d5c88
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import subprocess
 from contextlib import contextmanager
 from copy import copy, deepcopy

 import numpy as np
@@ -655,6 +655,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
 # checks if backward or backward + optimizer step (via closure)
 accumulation_done = self._accumulated_batches_reached()
 is_final_batch = self._num_training_batches_reached()
+should_accumulate = not (accumulation_done or is_final_batch)

 # lightning module hook
 splits = self.tbptt_split_batch(batch)
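For readers outside the codebase: the two flags folded into should_accumulate typically come from the batch counter and the accumulate_grad_batches setting. The following standalone sketch is only an illustration with made-up function and argument names, not code from this diff:

    def should_accumulate(batch_idx: int, accumulate_grad_batches: int, num_training_batches: int) -> bool:
        # an optimizer step is due once enough batches have been accumulated ...
        accumulation_done = (batch_idx + 1) % accumulate_grad_batches == 0
        # ... or when the epoch runs out of batches
        is_final_batch = (batch_idx + 1) == num_training_batches
        return not (accumulation_done or is_final_batch)

    # with accumulate_grad_batches=4 and 10 batches per epoch,
    # batches 3, 7 and 9 take an optimizer step, all others accumulate
    print([should_accumulate(i, 4, 10) for i in range(10)])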
@@ -675,13 +676,17 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
 model = self.trainer.get_model()
 model.toggle_optimizer(optimizer, opt_idx)

-if not (accumulation_done or is_final_batch):
+if should_accumulate:
     # For gradient accumulation

     # -------------------
     # calculate loss (train step + train step end)
     # -------------------
-    self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens)
+
+    # perform ddp sync only when performing optimizer_step
+    with self.block_ddp_sync_behaviour():
+        self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens)
Comment on lines +686 to +689

Contributor: do we need
Suggested change

Contributor: Just to make it more explicit for our readers :)

Member (PR author): @ananthsub @tchaton requested this to make it more readable and to hide the conditions for it in the context manager.

Contributor: @ananthsub One great example from this PR is not (accumulation_done or is_final_batch) -> should_accumulate. It is pretty simple, but it makes a clear statement about what is happening. We should try to enforce those patterns as much as possible :) I hope it makes sense :)
     batch_outputs = self._process_closure_result(
         batch_callback_metrics=batch_callback_metrics,
         batch_log_metrics=batch_log_metrics,
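Why gate the sync at all: by default DistributedDataParallel all-reduces gradients on every backward() call, so during gradient accumulation each micro-batch would pay the communication cost even though only the last one takes an optimizer step. DDP's no_sync() context manager defers that all-reduce. Below is a plain-PyTorch sketch of the pattern, not Lightning code; it assumes a process group is already initialized and the model is already wrapped in DDP:

    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP

    def accumulation_loop(model: DDP, optimizer, batches, accumulate_grad_batches: int):
        for i, (x, y) in enumerate(batches):
            loss = torch.nn.functional.mse_loss(model(x), y) / accumulate_grad_batches
            if (i + 1) % accumulate_grad_batches != 0:
                # accumulation step: skip the gradient all-reduce for this backward
                with model.no_sync():
                    loss.backward()
            else:
                # optimizer step: this backward all-reduces the accumulated gradients
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()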
@@ -695,7 +700,6 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
 # gradient update with accumulated gradients

 else:
     if self.automatic_optimization:

         def train_step_and_backward_closure():
@@ -760,6 +764,13 @@ def train_step_and_backward_closure():
     )
     return result

+@contextmanager
+def block_ddp_sync_behaviour(self):
+    if isinstance(self.trainer.model, torch.nn.parallel.DistributedDataParallel):
+        yield from self.trainer.model.no_sync()
+    else:
+        yield
+
 def _process_closure_result(
     self, batch_callback_metrics: list, batch_log_metrics: list, batch_outputs: list, opt_idx: int
 ) -> list:
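Since no_sync() is itself a @contextmanager on DistributedDataParallel, the usual way to delegate to it from a wrapper context manager is a with block around a bare yield rather than yield from. A minimal standalone sketch of that pattern follows; the function name echoes the method above, but this is an illustration, not the PR's code:

    from contextlib import contextmanager

    import torch

    @contextmanager
    def block_ddp_sync(model: torch.nn.Module):
        # skip DDP gradient synchronization inside the block; no-op for plain modules
        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
            with model.no_sync():
                yield
        else:
            yield

    # usage:
    #     with block_ddp_sync(model):
    #         loss.backward()  # gradients accumulate locally, no all-reduce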