Commit 9e6fde2

unify torch.Tensor >> Tensor
1 parent ed07339 commit 9e6fde2


50 files changed: 173 additions, 154 deletions
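The change applied across all 50 files follows one pattern: instead of importing the `torch` module and spelling out `torch.Tensor` in every annotation, the class is imported once with `from torch import Tensor` and referenced by its short name. A minimal sketch of the convention (the `scale_loss` function below is illustrative only, not part of this commit):

import torch
from torch import Tensor  # import the class once, use the short name in annotations


# previously annotated as `def scale_loss(loss: torch.Tensor, factor: float) -> torch.Tensor:`
def scale_loss(loss: Tensor, factor: float) -> Tensor:
    """Scale a loss tensor by a constant factor."""
    return loss * factor


print(scale_loss(torch.tensor(2.0), 0.5))  # tensor(1.)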

src/pytorch_lightning/callbacks/callback.py

Lines changed: 2 additions & 2 deletions
@@ -18,7 +18,7 @@

from typing import Any, Dict, List, Optional, Type

-import torch
+from torch import Tensor
from torch.optim import Optimizer

import pytorch_lightning as pl

@@ -342,7 +342,7 @@ def on_load_checkpoint(
            checkpoint dictionary instead of only the callback state from the checkpoint.
        """

-    def on_before_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", loss: torch.Tensor) -> None:
+    def on_before_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", loss: Tensor) -> None:
        """Called before ``loss.backward()``."""

    def on_after_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:

src/pytorch_lightning/callbacks/early_stopping.py

Lines changed: 3 additions & 2 deletions
@@ -23,6 +23,7 @@

import numpy as np
import torch
+from torch import Tensor

import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback

@@ -203,7 +204,7 @@ def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None:
        if reason and self.verbose:
            self._log_info(trainer, reason)

-    def _evaluate_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, Optional[str]]:
+    def _evaluate_stopping_criteria(self, current: Tensor) -> Tuple[bool, Optional[str]]:
        should_stop = False
        reason = None
        if self.check_finite and not torch.isfinite(current):

@@ -242,7 +243,7 @@ def _evaluate_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, Opti

        return should_stop, reason

-    def _improvement_message(self, current: torch.Tensor) -> str:
+    def _improvement_message(self, current: Tensor) -> str:
        """Formats a log message that informs the user about an improvement in the monitored score."""
        if torch.isfinite(self.best_score):
            msg = (

src/pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 6 additions & 7 deletions
@@ -31,6 +31,7 @@
import numpy as np
import torch
import yaml
+from torch import Tensor

import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback

@@ -477,7 +478,7 @@ def __init_triggers(
    def every_n_epochs(self) -> Optional[int]:
        return self._every_n_epochs

-    def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[torch.Tensor] = None) -> bool:
+    def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[Tensor] = None) -> bool:
        if current is None:
            return False

@@ -628,11 +629,9 @@ def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, _METRIC]:
        # cast to int if necessary because `self.log("epoch", 123)` will convert it to float. if it's not a tensor
        # or does not exist we overwrite it as it's likely an error
        epoch = monitor_candidates.get("epoch")
-        monitor_candidates["epoch"] = (
-            epoch.int() if isinstance(epoch, torch.Tensor) else torch.tensor(trainer.current_epoch)
-        )
+        monitor_candidates["epoch"] = epoch.int() if isinstance(epoch, Tensor) else torch.tensor(trainer.current_epoch)
        step = monitor_candidates.get("step")
-        monitor_candidates["step"] = step.int() if isinstance(step, torch.Tensor) else torch.tensor(trainer.global_step)
+        monitor_candidates["step"] = step.int() if isinstance(step, Tensor) else torch.tensor(trainer.global_step)
        return monitor_candidates

    def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:

@@ -670,7 +669,7 @@ def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidate
        trainer.strategy.remove_checkpoint(previous)

    def _update_best_and_save(
-        self, current: torch.Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]
+        self, current: Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]
    ) -> None:
        k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k

@@ -680,7 +679,7 @@ def _update_best_and_save(
        self.best_k_models.pop(del_filepath)

        # do not save nan, replace with +/- inf
-        if isinstance(current, torch.Tensor) and torch.isnan(current):
+        if isinstance(current, Tensor) and torch.isnan(current):
            current = torch.tensor(float("inf" if self.mode == "min" else "-inf"), device=current.device)

        filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer, del_filepath)

src/pytorch_lightning/callbacks/pruning.py

Lines changed: 4 additions & 5 deletions
@@ -21,9 +21,8 @@
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

-import torch
import torch.nn.utils.prune as pytorch_prune
-from torch import nn
+from torch import nn, Tensor
from typing_extensions import TypedDict

import pytorch_lightning as pl

@@ -275,7 +274,7 @@ def make_pruning_permanent(self, module: nn.Module) -> None:
    def _copy_param(new: nn.Module, old: nn.Module, name: str) -> None:
        dst = getattr(new, name)
        src = getattr(old, name)
-        if dst is None or src is None or not isinstance(dst, torch.Tensor) or not isinstance(src, torch.Tensor):
+        if dst is None or src is None or not isinstance(dst, Tensor) or not isinstance(src, Tensor):
            return
        dst.data = src.data.to(dst.device)

@@ -418,11 +417,11 @@ def _make_pruning_permanent_on_state_dict(self, pl_module: LightningModule) -> D
            # make weights permanent
            state_dict[tensor_name] = mask.to(dtype=orig.dtype) * orig

-        def move_to_cpu(tensor: torch.Tensor) -> torch.Tensor:
+        def move_to_cpu(tensor: Tensor) -> Tensor:
            # each tensor and move them on cpu
            return tensor.cpu()

-        return apply_to_collection(state_dict, torch.Tensor, move_to_cpu)
+        return apply_to_collection(state_dict, Tensor, move_to_cpu)

    def on_save_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]

src/pytorch_lightning/callbacks/stochastic_weight_avg.py

Lines changed: 4 additions & 4 deletions
@@ -19,7 +19,7 @@
from typing import Callable, List, Optional, Union

import torch
-from torch import nn
+from torch import FloatTensor, nn, Tensor
from torch.optim.swa_utils import SWALR

import pytorch_lightning as pl

@@ -28,7 +28,7 @@
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.types import LRSchedulerConfig

-_AVG_FN = Callable[[torch.Tensor, torch.Tensor, torch.LongTensor], torch.FloatTensor]
+_AVG_FN = Callable[[Tensor, Tensor, torch.LongTensor], FloatTensor]


class StochasticWeightAveraging(Callback):

@@ -269,7 +269,7 @@ def update_parameters(

    @staticmethod
    def avg_fn(
-        averaged_model_parameter: torch.Tensor, model_parameter: torch.Tensor, num_averaged: torch.LongTensor
-    ) -> torch.FloatTensor:
+        averaged_model_parameter: Tensor, model_parameter: Tensor, num_averaged: torch.LongTensor
+    ) -> FloatTensor:
        """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97."""
        return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1)

src/pytorch_lightning/core/hooks.py

Lines changed: 2 additions & 1 deletion
@@ -16,6 +16,7 @@
from typing import Any, Dict, List, Optional

import torch
+from torch import Tensor
from torch.optim.optimizer import Optimizer

from pytorch_lightning.utilities import move_data_to_device

@@ -254,7 +255,7 @@ def on_before_zero_grad(self, optimizer: Optimizer) -> None:
            optimizer: The optimizer for which grads should be zeroed.
        """

-    def on_before_backward(self, loss: torch.Tensor) -> None:
+    def on_before_backward(self, loss: Tensor) -> None:
        """Called before ``loss.backward()``.

        Args:

src/pytorch_lightning/core/mixins/device_dtype_mixin.py

Lines changed: 2 additions & 1 deletion
@@ -89,8 +89,9 @@ def to(self, *args: Any, **kwargs: Any) -> Self:
            Module: self

        Example::
+            >>> from torch import Tensor
            >>> class ExampleModule(DeviceDtypeModuleMixin):
-            ...     def __init__(self, weight: torch.Tensor):
+            ...     def __init__(self, weight: Tensor):
            ...         super().__init__()
            ...         self.register_buffer('weight', weight)
            >>> _ = torch.manual_seed(0)

src/pytorch_lightning/core/module.py

Lines changed: 5 additions & 7 deletions
@@ -528,7 +528,7 @@ def __check_not_nested(value: dict, name: str) -> dict:
    def __check_allowed(v: Any, name: str, value: Any) -> None:
        raise ValueError(f"`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged")

-    def __to_tensor(self, value: numbers.Number) -> torch.Tensor:
+    def __to_tensor(self, value: numbers.Number) -> Tensor:
        return torch.tensor(value, device=self.device)

    def log_grad_norm(self, grad_norm_dict: Dict[str, float]) -> None:

@@ -547,9 +547,7 @@ def log_grad_norm(self, grad_norm_dict):
        """
        self.log_dict(grad_norm_dict, on_step=True, on_epoch=True, prog_bar=False, logger=True)

-    def all_gather(
-        self, data: Union[torch.Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False
-    ):
+    def all_gather(self, data: Union[Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False):
        r"""
        Allows users to call ``self.all_gather()`` from the LightningModule, thus making the ``all_gather`` operation
        accelerator agnostic. ``all_gather`` is a function provided by accelerators to gather a tensor from several

@@ -567,7 +565,7 @@ def all_gather(
        group = group if group is not None else torch.distributed.group.WORLD
        all_gather = self.trainer.strategy.all_gather
        data = convert_to_tensors(data, device=self.device)
-        return apply_to_collection(data, torch.Tensor, all_gather, group=group, sync_grads=sync_grads)
+        return apply_to_collection(data, Tensor, all_gather, group=group, sync_grads=sync_grads)

    def forward(self, *args, **kwargs) -> Any:
        r"""

@@ -1701,15 +1699,15 @@ def tbptt_split_batch(self, batch, split_size):
        if :paramref:`~pytorch_lightning.core.module.LightningModule.truncated_bptt_steps` > 0.
        Each returned batch split is passed separately to :meth:`training_step`.
        """
-        time_dims = [len(x[0]) for x in batch if isinstance(x, (torch.Tensor, collections.Sequence))]
+        time_dims = [len(x[0]) for x in batch if isinstance(x, (Tensor, collections.Sequence))]
        assert len(time_dims) >= 1, "Unable to determine batch time dimension"
        assert all(x == time_dims[0] for x in time_dims), "Batch time dimension length is ambiguous"

        splits = []
        for t in range(0, time_dims[0], split_size):
            batch_split = []
            for i, x in enumerate(batch):
-                if isinstance(x, torch.Tensor):
+                if isinstance(x, Tensor):
                    split_x = x[:, t : t + split_size]
                elif isinstance(x, collections.Sequence):
                    split_x = [None] * len(x)

src/pytorch_lightning/lite/lite.py

Lines changed: 3 additions & 3 deletions
@@ -342,8 +342,8 @@ def barrier(self, name: Optional[str] = None) -> None:
        self._strategy.barrier(name=name)

    def all_gather(
-        self, data: Union[torch.Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False
-    ) -> Union[torch.Tensor, Dict, List, Tuple]:
+        self, data: Union[Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False
+    ) -> Union[Tensor, Dict, List, Tuple]:
        r"""
        Gather tensors or collections of tensors from multiple processes.

@@ -358,7 +358,7 @@ def all_gather(
        """
        group = group if group is not None else torch.distributed.group.WORLD
        data = convert_to_tensors(data, device=self.device)
-        return apply_to_collection(data, torch.Tensor, self._strategy.all_gather, group=group, sync_grads=sync_grads)
+        return apply_to_collection(data, Tensor, self._strategy.all_gather, group=group, sync_grads=sync_grads)

    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
        return self._strategy.broadcast(obj, src=src)

src/pytorch_lightning/loggers/comet.py

Lines changed: 2 additions & 3 deletions
@@ -21,8 +21,7 @@
from argparse import Namespace
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Union

-import torch
-from torch import is_tensor
+from torch import is_tensor, Tensor

import pytorch_lightning as pl
from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment

@@ -241,7 +240,7 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
        self.experiment.log_parameters(params)

    @rank_zero_only
-    def log_metrics(self, metrics: Dict[str, Union[torch.Tensor, float]], step: Optional[int] = None) -> None:
+    def log_metrics(self, metrics: Dict[str, Union[Tensor, float]], step: Optional[int] = None) -> None:
        assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
        # Comet.ml expects metrics to be a dictionary of detached tensors on CPU
        metrics_without_epoch = metrics.copy()
