 import copy
 import inspect
 import logging
+import numbers
 import os
 import tempfile
 import types
 from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
+from pytorch_lightning.utilities.distributed import sync_ddp_if_available, tpu_distributed
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, save_hyperparameters
 from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
-from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
+from pytorch_lightning.utilities.types import _METRIC, EPOCH_OUTPUT, STEP_OUTPUT
 from pytorch_lightning.utilities.warnings import WarningCache

 warning_cache = WarningCache()
@@ -336,6 +338,15 @@ def log(
336338 f"Logged key: { name } should not contain information about dataloader_idx."
337339 )
338340
341+ value = self .__sync (
342+ value ,
343+ sync_fn = self .trainer .training_type_plugin .reduce ,
344+ sync_dist = sync_dist ,
345+ sync_dist_op = sync_dist_op ,
346+ sync_dist_group = sync_dist_group ,
347+ device = self .device ,
348+ )
349+
339350 self ._results .log (
340351 name ,
341352 value ,
@@ -345,12 +356,7 @@ def log(
             on_epoch=on_epoch,
             reduce_fx=reduce_fx,
             enable_graph=enable_graph,
-            sync_dist=sync_dist,
-            sync_dist_op=sync_dist_op,
-            sync_dist_group=sync_dist_group,
-            sync_fn=self.trainer.training_type_plugin.reduce,
             dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None),
-            device=self.device,
         )

     def log_dict(
@@ -410,6 +416,31 @@ def log_dict(
                 add_dataloader_idx=add_dataloader_idx
             )

+    @staticmethod
+    def __sync(
+        value: _METRIC,
+        sync_fn: Optional[Callable] = None,
+        sync_dist: bool = False,
+        sync_dist_op: Union[Any, str] = 'mean',
+        sync_dist_group: Optional[Any] = None,
+        device: torch.device = None,
+    ) -> _METRIC:
+        """Sync across workers when using distributed training"""
+        if not isinstance(value, (torch.Tensor, numbers.Number)):
+            return value
+
+        sync_fn = sync_fn or sync_ddp_if_available
+        dist_available = torch.distributed.is_available() and torch.distributed.is_initialized() or tpu_distributed()
+        if not sync_dist or not dist_available:
+            return value
+
+        # TODO: Find a way to make the reduction only once, so we don't need to clone.
+        if isinstance(value, torch.Tensor):
+            value = value.clone()
+        else:
+            value = torch.tensor(value, device=device, dtype=torch.float)
+        return sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op)
+
     def write_prediction(
         self, name: str, value: Union[torch.Tensor, List[torch.Tensor]], filename: str = 'predictions.pt'
     ):
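
To make the behavioural change easier to follow, here is a minimal standalone sketch of the gate that the new `__sync` helper applies inside `log()` before the value reaches `self._results.log(...)`. The `_sync_value` function and the lambda reducer below are hypothetical stand-ins, not part of the pytorch_lightning API; the real method uses `self.trainer.training_type_plugin.reduce` as `sync_fn` and also checks `tpu_distributed()`, which is omitted here for brevity.

```python
# Hypothetical sketch mirroring the logic of LightningModule.__sync above.
import numbers
from typing import Any, Callable, Optional, Union

import torch


def _sync_value(
    value: Any,
    sync_fn: Callable,
    sync_dist: bool = False,
    sync_dist_op: Union[Any, str] = 'mean',
    sync_dist_group: Optional[Any] = None,
    device: Optional[torch.device] = None,
) -> Any:
    # Anything that is not a tensor or a plain number (e.g. a torchmetrics.Metric)
    # is passed through untouched and reduced elsewhere.
    if not isinstance(value, (torch.Tensor, numbers.Number)):
        return value

    # Reduce only when the caller asked for it and a process group is actually live.
    dist_available = torch.distributed.is_available() and torch.distributed.is_initialized()
    if not sync_dist or not dist_available:
        return value

    # Clone tensors so the reduction does not mutate the caller's value;
    # promote plain numbers to a float tensor on the logging device first.
    value = value.clone() if isinstance(value, torch.Tensor) else torch.tensor(value, device=device, dtype=torch.float)
    return sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op)


# Single-process call: no process group is initialized, so the value passes
# through unchanged and sync_fn is never invoked.
print(_sync_value(0.5, sync_fn=lambda v, group, reduce_op: v, sync_dist=True))  # -> 0.5
```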