10 changes: 4 additions & 6 deletions src/lightning/pytorch/profilers/pytorch.py
@@ -22,25 +22,23 @@
 import torch
 from torch import nn, Tensor
 from torch.autograd.profiler import EventList, record_function
+from torch.profiler import ProfilerAction, ProfilerActivity, tensorboard_trace_handler
+from torch.utils.hooks import RemovableHandle
 
 from lightning.fabric.accelerators.cuda import is_cuda_available
 from lightning.pytorch.profilers.profiler import Profiler
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning.pytorch.utilities.imports import _KINETO_AVAILABLE
 from lightning.pytorch.utilities.rank_zero import rank_zero_warn, WarningCache
 
 if TYPE_CHECKING:
-    from torch.utils.hooks import RemovableHandle
-
     from lightning.pytorch.core.module import LightningModule
 
-if _KINETO_AVAILABLE:
-    from torch.profiler import ProfilerAction, ProfilerActivity, tensorboard_trace_handler
-
 log = logging.getLogger(__name__)
 warning_cache = WarningCache()
 
 _PROFILER = Union[torch.profiler.profile, torch.autograd.profiler.profile, torch.autograd.profiler.emit_nvtx]
+_KINETO_AVAILABLE = torch.profiler.kineto_available()
 
 
 class RegisterRecordFunction:
@@ -65,7 +63,7 @@ class RegisterRecordFunction:
     def __init__(self, model: nn.Module) -> None:
         self._model = model
         self._records: Dict[str, record_function] = {}
-        self._handles: Dict[str, List["RemovableHandle"]] = {}
+        self._handles: Dict[str, List[RemovableHandle]] = {}
 
     def _start_recording_forward(self, _: nn.Module, input: Tensor, record_name: str) -> Tensor:
         # Add [pl][module] in name for pytorch profiler to recognize
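With the `if _KINETO_AVAILABLE:` guard gone, the flag is now computed once at module level where it is used, and the `torch.profiler` names are imported unconditionally. Below is a minimal sketch of how such a module-level flag is typically consumed; `export_trace` is a hypothetical helper for illustration, not part of this PR:

```python
import torch

# Evaluated once at import time, mirroring the pattern in this diff,
# rather than re-querying Kineto support at every call site.
_KINETO_AVAILABLE = torch.profiler.kineto_available()


def export_trace(prof: torch.profiler.profile, path: str) -> None:
    # Chrome-trace export relies on the Kineto backend, so guard on the flag.
    if not _KINETO_AVAILABLE:
        raise RuntimeError("Kineto is not available; cannot export a chrome trace.")
    prof.export_chrome_trace(path)
```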
2 changes: 0 additions & 2 deletions src/lightning/pytorch/utilities/imports.py
@@ -15,15 +15,13 @@
 import functools
 import sys
 
-import torch
 from lightning_utilities.core.imports import package_available, RequirementCache
 from lightning_utilities.core.rank_zero import rank_zero_warn
 
 _PYTHON_GREATER_EQUAL_3_11_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 11)
 _TORCHMETRICS_GREATER_EQUAL_0_9_1 = RequirementCache("torchmetrics>=0.9.1")
 _TORCHMETRICS_GREATER_EQUAL_0_11 = RequirementCache("torchmetrics>=0.11.0")  # using new API with task
 
-_KINETO_AVAILABLE = torch.profiler.kineto_available()
 _OMEGACONF_AVAILABLE = package_available("omegaconf")
 _TORCHVISION_AVAILABLE = RequirementCache("torchvision")
 _LIGHTNING_COLOSSALAI_AVAILABLE = RequirementCache("lightning-colossalai")
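Removing these two lines drops the only eager `torch` dependency from `utilities/imports.py`: `torch.profiler.kineto_available()` runs at import time, while the remaining flags defer their checks. A small sketch of that distinction, assuming the documented `lightning_utilities` semantics:

```python
from lightning_utilities.core.imports import RequirementCache, package_available

# RequirementCache defers the version lookup until the flag is first
# evaluated as a bool, so importing this module stays cheap.
_TORCHVISION_AVAILABLE = RequirementCache("torchvision")

if _TORCHVISION_AVAILABLE:  # the requirement check happens here, then is cached
    import torchvision  # noqa: F401

# package_available performs a simple importability check and returns a bool.
print(bool(package_available("omegaconf")))
```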
3 changes: 1 addition & 2 deletions tests/tests_pytorch/profilers/test_profiler.py
@@ -27,9 +27,8 @@
 from lightning.pytorch.demos.boring_classes import BoringModel, ManualOptimBoringModel
 from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
 from lightning.pytorch.profilers import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler
-from lightning.pytorch.profilers.pytorch import RegisterRecordFunction, warning_cache
+from lightning.pytorch.profilers.pytorch import _KINETO_AVAILABLE, RegisterRecordFunction, warning_cache
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning.pytorch.utilities.imports import _KINETO_AVAILABLE
 from tests_pytorch.helpers.runif import RunIf
 
 PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0005
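The tests only need their import of `_KINETO_AVAILABLE` repointed at its new home. A hypothetical smoke test, not part of this diff, illustrating the updated import path and the usual skip-guard pattern:

```python
import pytest

from lightning.pytorch.profilers import PyTorchProfiler
from lightning.pytorch.profilers.pytorch import _KINETO_AVAILABLE


# Skip when the Kineto backend is missing, since schedule-based
# profiling and trace export depend on it.
@pytest.mark.skipif(not _KINETO_AVAILABLE, reason="requires Kineto")
def test_pytorch_profiler_smoke(tmp_path):
    profiler = PyTorchProfiler(dirpath=tmp_path, filename="profile")
    with profiler.profile("validation_step"):
        pass  # profiled work would go here
    profiler.describe()
```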