 # See the License for the specific language governing permissions and
 # limitations under the License.

-import subprocess
+from contextlib import contextmanager
 from copy import copy, deepcopy

 import numpy as np
@@ -656,6 +656,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
         # checks if backward or backward + optimizer step (via closure)
         accumulation_done = self._accumulated_batches_reached()
         is_final_batch = self._num_training_batches_reached()
+        should_accumulate = not (accumulation_done or is_final_batch)

         # lightning module hook
         splits = self.tbptt_split_batch(batch)
@@ -676,13 +677,17 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
                 model = self.trainer.get_model()
                 model.toggle_optimizer(optimizer, opt_idx)

-                if not (accumulation_done or is_final_batch):
+                if should_accumulate:
                     # For gradient accumulation

                     # -------------------
                     # calculate loss (train step + train step end)
                     # -------------------
-                    self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens)
+
+                    # perform ddp sync only when performing optimizer_step
+                    with self.block_ddp_sync_behaviour():
+                        self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens)
+
                     batch_outputs = self._process_closure_result(
                         batch_callback_metrics=batch_callback_metrics,
                         batch_log_metrics=batch_log_metrics,
@@ -696,7 +701,6 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
                     # gradient update with accumulated gradients

                 else:
-
                     if self.automatic_optimization:

                         def train_step_and_backward_closure():
@@ -761,6 +765,13 @@ def train_step_and_backward_closure():
         )
         return result

+    @contextmanager
+    def block_ddp_sync_behaviour(self):
+        if isinstance(self.trainer.model, torch.nn.parallel.DistributedDataParallel):
+            yield from self.trainer.model.no_sync()
+        else:
+            yield
+
     def _process_closure_result(
         self, batch_callback_metrics: list, batch_log_metrics: list, batch_outputs: list, opt_idx: int
     ) -> list:
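For context, a minimal standalone sketch of the pattern this diff introduces: skip DDP's gradient all-reduce on accumulation steps and synchronize only on the step that actually calls optimizer.step(). The helper and function names below (block_ddp_sync, accumulation_step, accumulate_grad_batches) are illustrative, not Lightning API; the sketch re-exposes DDP's real no_sync() context manager via the standard "with ...: yield" contextmanager idiom rather than mirroring the diff line for line.

from contextlib import contextmanager

import torch


@contextmanager
def block_ddp_sync(model):
    # Skip gradient all-reduce for backward passes run inside this context;
    # gradients still accumulate locally on each rank.
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        with model.no_sync():
            yield
    else:
        # plain (non-DDP) model: nothing to block
        yield


def accumulation_step(model, optimizer, loss, batch_idx, accumulate_grad_batches=4):
    # Only every accumulate_grad_batches-th batch performs the optimizer step,
    # so only that batch needs the DDP gradient sync.
    should_accumulate = (batch_idx + 1) % accumulate_grad_batches != 0

    if should_accumulate:
        with block_ddp_sync(model):
            loss.backward()
    else:
        loss.backward()  # this backward triggers the all-reduce
        optimizer.step()
        optimizer.zero_grad()

Wrapping only the accumulation-step backwards in no_sync() saves one all-reduce per skipped step; per the DDP documentation, the first backward after leaving no_sync() reduces the full locally accumulated gradient, so the final update should match the sync-every-step behaviour while cutting communication.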