diff --git a/src/lightning/pytorch/profilers/pytorch.py b/src/lightning/pytorch/profilers/pytorch.py
index 0b486f1aa587d..11c30d1fb2e63 100644
--- a/src/lightning/pytorch/profilers/pytorch.py
+++ b/src/lightning/pytorch/profilers/pytorch.py
@@ -22,25 +22,23 @@
 import torch
 from torch import nn, Tensor
 from torch.autograd.profiler import EventList, record_function
+from torch.profiler import ProfilerAction, ProfilerActivity, tensorboard_trace_handler
+from torch.utils.hooks import RemovableHandle
 
 from lightning.fabric.accelerators.cuda import is_cuda_available
 from lightning.pytorch.profilers.profiler import Profiler
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning.pytorch.utilities.imports import _KINETO_AVAILABLE
 from lightning.pytorch.utilities.rank_zero import rank_zero_warn, WarningCache
 
 if TYPE_CHECKING:
-    from torch.utils.hooks import RemovableHandle
-
     from lightning.pytorch.core.module import LightningModule
 
-if _KINETO_AVAILABLE:
-    from torch.profiler import ProfilerAction, ProfilerActivity, tensorboard_trace_handler
 
 log = logging.getLogger(__name__)
 warning_cache = WarningCache()
 
 _PROFILER = Union[torch.profiler.profile, torch.autograd.profiler.profile, torch.autograd.profiler.emit_nvtx]
+_KINETO_AVAILABLE = torch.profiler.kineto_available()
 
 
 class RegisterRecordFunction:
@@ -65,7 +63,7 @@ class RegisterRecordFunction:
     def __init__(self, model: nn.Module) -> None:
         self._model = model
         self._records: Dict[str, record_function] = {}
-        self._handles: Dict[str, List["RemovableHandle"]] = {}
+        self._handles: Dict[str, List[RemovableHandle]] = {}
 
     def _start_recording_forward(self, _: nn.Module, input: Tensor, record_name: str) -> Tensor:
         # Add [pl][module] in name for pytorch profiler to recognize
diff --git a/src/lightning/pytorch/utilities/imports.py b/src/lightning/pytorch/utilities/imports.py
index 778f3ef70697b..e2b40f98ec44b 100644
--- a/src/lightning/pytorch/utilities/imports.py
+++ b/src/lightning/pytorch/utilities/imports.py
@@ -15,7 +15,6 @@
 import functools
 import sys
 
-import torch
 from lightning_utilities.core.imports import package_available, RequirementCache
 from lightning_utilities.core.rank_zero import rank_zero_warn
 
@@ -23,7 +22,6 @@
 _TORCHMETRICS_GREATER_EQUAL_0_9_1 = RequirementCache("torchmetrics>=0.9.1")
 _TORCHMETRICS_GREATER_EQUAL_0_11 = RequirementCache("torchmetrics>=0.11.0")  # using new API with task
 
-_KINETO_AVAILABLE = torch.profiler.kineto_available()
 _OMEGACONF_AVAILABLE = package_available("omegaconf")
 _TORCHVISION_AVAILABLE = RequirementCache("torchvision")
 _LIGHTNING_COLOSSALAI_AVAILABLE = RequirementCache("lightning-colossalai")
diff --git a/tests/tests_pytorch/profilers/test_profiler.py b/tests/tests_pytorch/profilers/test_profiler.py
index 00595bfd25cba..3540b86317a9e 100644
--- a/tests/tests_pytorch/profilers/test_profiler.py
+++ b/tests/tests_pytorch/profilers/test_profiler.py
@@ -27,9 +27,8 @@
 from lightning.pytorch.demos.boring_classes import BoringModel, ManualOptimBoringModel
 from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
 from lightning.pytorch.profilers import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler
-from lightning.pytorch.profilers.pytorch import RegisterRecordFunction, warning_cache
+from lightning.pytorch.profilers.pytorch import _KINETO_AVAILABLE, RegisterRecordFunction, warning_cache
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning.pytorch.utilities.imports import _KINETO_AVAILABLE
 from tests_pytorch.helpers.runif import RunIf
 
 PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0005
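
Reviewer note: the diff moves the module-level Kineto probe next to its only consumer and drops the lazy import guard around the torch.profiler names. Below is a minimal standalone sketch of that pattern, using only public torch.profiler APIs; the make_activities helper is hypothetical (not part of this patch) and only illustrates how such a flag is typically consumed.

# Sketch only; not part of the patch. Assumes a torch build where
# torch.profiler.kineto_available() and ProfilerActivity exist (torch >= 1.8).
import torch
from torch.profiler import ProfilerActivity, profile

# Same probe the diff moves into profilers/pytorch.py: a cheap, import-time
# check for Kineto support, no conditional import needed.
_KINETO_AVAILABLE = torch.profiler.kineto_available()


def make_activities() -> list:
    # Hypothetical helper: Kineto-only features such as CUDA activity
    # tracing are gated on the flag.
    activities = [ProfilerActivity.CPU]
    if _KINETO_AVAILABLE and torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)
    return activities


if __name__ == "__main__":
    with profile(activities=make_activities()) as prof:
        torch.randn(64, 64) @ torch.randn(64, 64)
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))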