@@ -528,7 +528,7 @@ def __check_not_nested(value: dict, name: str) -> dict:
     def __check_allowed(v: Any, name: str, value: Any) -> None:
         raise ValueError(f"`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged")
 
-    def __to_tensor(self, value: numbers.Number) -> torch.Tensor:
+    def __to_tensor(self, value: numbers.Number) -> Tensor:
         return torch.tensor(value, device=self.device)
 
     def log_grad_norm(self, grad_norm_dict: Dict[str, float]) -> None:
@@ -547,9 +547,7 @@ def log_grad_norm(self, grad_norm_dict):
         """
         self.log_dict(grad_norm_dict, on_step=True, on_epoch=True, prog_bar=False, logger=True)
 
-    def all_gather(
-        self, data: Union[torch.Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False
-    ):
+    def all_gather(self, data: Union[Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False):
         r"""
         Allows users to call ``self.all_gather()`` from the LightningModule, thus making the ``all_gather`` operation
         accelerator agnostic. ``all_gather`` is a function provided by accelerators to gather a tensor from several
@@ -567,7 +565,7 @@ def all_gather(
         group = group if group is not None else torch.distributed.group.WORLD
         all_gather = self.trainer.strategy.all_gather
         data = convert_to_tensors(data, device=self.device)
-        return apply_to_collection(data, torch.Tensor, all_gather, group=group, sync_grads=sync_grads)
+        return apply_to_collection(data, Tensor, all_gather, group=group, sync_grads=sync_grads)
 
     def forward(self, *args, **kwargs) -> Any:
         r"""
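As a brief aside (not part of this commit): a minimal usage sketch of the all_gather API described in the docstring above, assuming a LightningModule run under a distributed Trainer; the class and metric names are illustrative only.

import torch
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        # Per-process scalar metric; in practice it would be computed from the batch.
        local_acc = torch.tensor(0.5, device=self.device)
        # self.all_gather collects the value from every process; when running
        # distributed, the result gains a leading dimension of size world_size.
        gathered = self.all_gather(local_acc, sync_grads=False)
        self.log("val_acc_mean", gathered.float().mean())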
@@ -1701,15 +1699,15 @@ def tbptt_split_batch(self, batch, split_size):
             if :paramref:`~pytorch_lightning.core.module.LightningModule.truncated_bptt_steps` > 0.
             Each returned batch split is passed separately to :meth:`training_step`.
         """
-        time_dims = [len(x[0]) for x in batch if isinstance(x, (torch.Tensor, collections.Sequence))]
+        time_dims = [len(x[0]) for x in batch if isinstance(x, (Tensor, collections.Sequence))]
         assert len(time_dims) >= 1, "Unable to determine batch time dimension"
         assert all(x == time_dims[0] for x in time_dims), "Batch time dimension length is ambiguous"
 
         splits = []
         for t in range(0, time_dims[0], split_size):
             batch_split = []
             for i, x in enumerate(batch):
-                if isinstance(x, torch.Tensor):
+                if isinstance(x, Tensor):
                     split_x = x[:, t : t + split_size]
                 elif isinstance(x, collections.Sequence):
                     split_x = [None] * len(x)
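For context on the hunk above (again, not part of this commit): a standalone sketch of the splitting logic with hypothetical tensor shapes, showing how each root-level tensor is sliced along dim 1 (the time dimension) in chunks of split_size.

import torch

x, y = torch.randn(4, 10, 8), torch.randn(4, 10)  # (batch, time, features), (batch, time)
split_size = 3
splits = [
    (x[:, t : t + split_size], y[:, t : t + split_size])
    for t in range(0, x.size(1), split_size)
]
# Four splits with time lengths 3, 3, 3 and 1; each would be passed to a separate training_step.
print([s[0].shape[1] for s in splits])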