Merged
19 changes: 15 additions & 4 deletions pytorch_lightning/trainer/training_loop.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import subprocess
+from contextlib import contextmanager
 from copy import copy, deepcopy

 import numpy as np
@@ -655,6 +655,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
 # checks if backward or backward + optimizer step (via closure)
 accumulation_done = self._accumulated_batches_reached()
 is_final_batch = self._num_training_batches_reached()
+should_accumulate = not (accumulation_done or is_final_batch)

 # lightning module hook
 splits = self.tbptt_split_batch(batch)
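
For context, here is a minimal standalone sketch of the bookkeeping that should_accumulate captures, written in plain PyTorch with illustrative names (this is not Lightning's internal code): gradients accumulate across micro-batches, and the optimizer only steps when the accumulation window closes or the final batch of the epoch is reached.

import torch

def train_epoch(model, optimizer, dataloader, accumulate_grad_batches=4):
    num_batches = len(dataloader)
    for batch_idx, (x, y) in enumerate(dataloader):
        loss = torch.nn.functional.mse_loss(model(x), y)
        # scale so the accumulated gradient matches the large-batch gradient
        (loss / accumulate_grad_batches).backward()

        accumulation_done = (batch_idx + 1) % accumulate_grad_batches == 0
        is_final_batch = (batch_idx + 1) == num_batches
        should_accumulate = not (accumulation_done or is_final_batch)

        if not should_accumulate:
            optimizer.step()
            optimizer.zero_grad()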
@@ -675,13 +676,17 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
 model = self.trainer.get_model()
 model.toggle_optimizer(optimizer, opt_idx)

-if not (accumulation_done or is_final_batch):
+if should_accumulate:
     # For gradient accumulation

     # -------------------
     # calculate loss (train step + train step end)
     # -------------------
-    self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens)
+
+    # perform ddp sync only when performing optimizer_step
+    with self.block_ddp_sync_behaviour():
+        self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens)

Comment on lines +686 to +689
ananthsub (Contributor) commented on Oct 28, 2020:

do we need block_ddp_sync_behaviour?

Suggested change:

-# perform ddp sync only when performing optimizer_step
-with self.block_ddp_sync_behaviour():
-    self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens)
+# perform ddp sync only when performing optimizer_step
+if isinstance(self.trainer.model, torch.nn.parallel.DistributedDataParallel):
+    with self.trainer.model.no_sync():
+        self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens)
+else:
+    self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens)

Contributor:

Just to make it more explicit for our readers :)

justusschock (Member, Author) commented on Oct 29, 2020:

@ananthsub @tchaton requested this to make it more readable and to hide the conditions inside the context manager.

tchaton (Contributor) commented on Oct 29, 2020:

@ananthsub
The training and evaluation loops should look absolutely perfect and simple to understand.
As a new coder, I should recognise my training loop. As a new coder, I have no knowledge about ddp, and your suggested change would have confused me :)

One great example from this PR: not (accumulation_done or is_final_batch) -> should_accumulate. It is pretty simple, but it makes a clear statement about what is happening. We should try to enforce those patterns as much as possible :)

I hope it makes sense :)
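
As a standalone illustration of the pattern the thread settles on (simplified and hypothetical, not Lightning's exact code): the DDP check lives inside the context manager, so the loop body reads identically with and without DDP.

from contextlib import contextmanager

import torch

@contextmanager
def block_ddp_sync_behaviour(model):
    # skip inter-process gradient sync for backward passes run inside
    # this block; outside DDP this is a no-op
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        with model.no_sync():
            yield
    else:
        yield

# the training loop body stays the same either way:
#     with block_ddp_sync_behaviour(model):
#         training_step_and_backward(...)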

 batch_outputs = self._process_closure_result(
     batch_callback_metrics=batch_callback_metrics,
     batch_log_metrics=batch_log_metrics,
@@ -695,7 +700,6 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
 # gradient update with accumulated gradients

 else:
-
     if self.automatic_optimization:

         def train_step_and_backward_closure():
@@ -760,6 +764,13 @@ def train_step_and_backward_closure():
     )
     return result

+@contextmanager
+def block_ddp_sync_behaviour(self):
+    if isinstance(self.trainer.model, torch.nn.parallel.DistributedDataParallel):
+        with self.trainer.model.no_sync():
+            yield
+    else:
+        yield
+
 def _process_closure_result(
     self, batch_callback_metrics: list, batch_log_metrics: list, batch_outputs: list, opt_idx: int
 ) -> list:
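
To make the performance rationale concrete, here is a hedged end-to-end sketch of DDP gradient accumulation with no_sync() (process-group setup and device placement elided; names are illustrative): the gradient all-reduce runs only on the micro-batch where the optimizer actually steps, instead of after every backward.

from contextlib import nullcontext

import torch

def ddp_accumulation_loop(ddp_model, optimizer, batches, accumulate=4):
    for i, (x, y) in enumerate(batches):
        stepping = (i + 1) % accumulate == 0
        # skip gradient all-reduce on all but the stepping micro-batch;
        # grads accumulated under no_sync() are reduced on the next
        # synchronized backward
        sync_ctx = nullcontext() if stepping else ddp_model.no_sync()
        with sync_ctx:
            loss = torch.nn.functional.mse_loss(ddp_model(x), y) / accumulate
            loss.backward()
        if stepping:
            optimizer.step()
            optimizer.zero_grad()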