@@ -383,12 +383,18 @@ def barrier(self, name: Optional[str] = None) -> None:
         self.training_type_plugin.barrier(name=name)
 
     def broadcast(self, obj: object, src: int = 0) -> object:
-        """Broadcasts an object to all processes"""
+        """Broadcasts an object to all processes, such that the src object is broadcast to all other ranks if needed.
+
+        Args:
+            obj: Object to broadcast to all processes, usually a tensor or collection of tensors.
+            src: The source rank from which the object will be broadcast.
+        """
         return self.training_type_plugin.broadcast(obj, src)
 
     def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
         """
-        Function to gather a tensor from several distributed processes
+        Function to gather a tensor from several distributed processes.
+
         Args:
             tensor: tensor of shape (batch, ...)
             group: the process group to gather results from. Defaults to all processes (world)
@@ -409,8 +415,7 @@ def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[I
     @property
     def results(self) -> Any:
         """
-        The results of the last training/testing run will be cached here.
+        The results of the last training/testing run will be cached within the training type plugin.
         In distributed training, we make sure to transfer the results to the appropriate master process.
         """
-        # TODO: improve these docs
        return self.training_type_plugin.results
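
As a rough illustration of the semantics documented above (not Lightning's own implementation, which delegates to the training type plugin), the following standalone sketch mirrors `broadcast` and `all_gather` with raw `torch.distributed` collectives. The gloo backend, the port, the two-process world size, and the tensor shapes are illustrative assumptions.

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank: int, world_size: int) -> None:
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # broadcast: every rank ends up with the object provided by the src rank.
    payload = [{"step": 100} if rank == 0 else None]
    dist.broadcast_object_list(payload, src=0)
    assert payload[0] == {"step": 100}

    # all_gather: each rank contributes a (batch, ...) tensor and receives the
    # tensors from every rank, which can be stacked to (world_size, batch, ...).
    local = torch.full((2, 3), float(rank))
    gathered = [torch.empty_like(local) for _ in range(world_size)]
    dist.all_gather(gathered, local)
    stacked = torch.stack(gathered)  # shape: (world_size, 2, 3)
    print(rank, stacked.shape)

    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(run, args=(world_size,), nprocs=world_size)
```

The point matching the docstrings: after the broadcast every rank holds the `src` rank's object, and gathering collects one tensor per rank so the caller can combine them across the world.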