1 change: 0 additions & 1 deletion pyproject.toml
@@ -52,7 +52,6 @@ warn_no_return = "False"
# mypy --no-error-summary 2>&1 | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g; s|src/||g; s|\/|\.|g' | xargs -I {} echo '"{}",'
module = [
"pytorch_lightning.callbacks.progress.rich_progress",
"pytorch_lightning.profilers.pytorch",
"pytorch_lightning.trainer.trainer",
"pytorch_lightning.tuner.batch_size_scaling",
"pytorch_lightning.utilities.data",
69 changes: 44 additions & 25 deletions src/pytorch_lightning/profilers/pytorch.py
@@ -17,7 +17,7 @@
import os
from functools import lru_cache, partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING, Union
from typing import Any, Callable, ContextManager, Dict, List, Optional, Type, TYPE_CHECKING, Union

import torch
from lightning_utilities.core.rank_zero import WarningCache
@@ -42,7 +42,7 @@
log = logging.getLogger(__name__)
warning_cache = WarningCache()

_PROFILER = Union[torch.autograd.profiler.profile, torch.cuda.profiler.profile, torch.autograd.profiler.emit_nvtx]
_PROFILER = Union[torch.profiler.profile, torch.autograd.profiler.profile, torch.autograd.profiler.emit_nvtx]


class RegisterRecordFunction:
@@ -111,13 +111,7 @@ def __init__(self, schedule: Callable) -> None:
self._schedule = schedule
self.reset()

def setup(self, start_action_name: str) -> None:
self._start_action_name = start_action_name

def pre_step(self, current_action: str) -> None:
self._current_action = current_action

def reset(self):
def reset(self) -> None:
# handle properly `fast_dev_run`. PyTorch Profiler will fail otherwise.
self._num_training_step = 0
self._num_validation_step = 0
@@ -132,20 +126,30 @@ def reset(self):
self._prev_schedule_action: Optional[ProfilerAction] = None
self._start_action_name: Optional[str] = None

def setup(self, start_action_name: str) -> None:
self._start_action_name = start_action_name

def pre_step(self, current_action: str) -> None:
self._current_action = current_action

@property
def is_training(self):
def is_training(self) -> bool:
assert self._current_action is not None
return self._current_action.endswith("training_step")

@property
def is_validating(self):
def is_validating(self) -> bool:
assert self._current_action is not None
return self._current_action.endswith("validation_step")

@property
def is_testing(self):
def is_testing(self) -> bool:
assert self._current_action is not None
return self._current_action.endswith("test_step")

@property
def is_predicting(self):
def is_predicting(self) -> bool:
assert self._current_action is not None
return self._current_action.endswith("predict_step")

@property
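The repeated `assert self._current_action is not None` lines exist purely for type checking: the attribute starts out as an `Optional[str]` (set to `None` in `reset()`), so mypy needs proof that it is a `str` before `.endswith` is called. A minimal sketch of the pattern, using a hypothetical `_ActionTracker` class rather than the real profiler:

```python
from typing import Optional


class _ActionTracker:
    """Hypothetical stand-in illustrating the narrowing pattern, not part of the PR."""

    def __init__(self) -> None:
        self._current_action: Optional[str] = None

    def pre_step(self, current_action: str) -> None:
        self._current_action = current_action

    @property
    def is_training(self) -> bool:
        # Without the assert, mypy complains that item "None" of "Optional[str]"
        # has no attribute "endswith"; the assert narrows the type to "str".
        assert self._current_action is not None
        return self._current_action.endswith("training_step")
```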
@@ -164,6 +168,7 @@ def _step(self) -> None:
if self.is_training:
self._num_training_step += 1
elif self.is_validating:
assert self._start_action_name is not None
if self._start_action_name.endswith("on_fit_start"):
if self._num_training_step > 0:
self._num_validation_step += 1
@@ -238,7 +243,7 @@ def __init__(
record_module_names: bool = True,
**profiler_kwargs: Any,
) -> None:
"""This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of.
r"""This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of.

different operators inside your model - both on the CPU and GPU

@@ -276,7 +281,7 @@ def __init__(

record_module_names: Whether to add module names while recording autograd operation.

profiler_kwargs: Keyword arguments for the PyTorch profiler. This depends on your PyTorch version
\**profiler_kwargs: Keyword arguments for the PyTorch profiler. This depends on your PyTorch version

Raises:
MisconfigurationException:
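For readers unfamiliar with the class being documented here, a rough usage sketch of the constructor described above; the argument values are illustrative, while `dirpath`, `filename`, `export_to_chrome`, and `record_module_names` are real parameters of this profiler:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.profilers import PyTorchProfiler

# Rough sketch: profile each step function and write results under ./profiler_logs.
profiler = PyTorchProfiler(
    dirpath="profiler_logs",   # where traces and the text summary are written
    filename="perf",           # base name for the output files
    export_to_chrome=True,     # emit Chrome trace files (the default)
    record_module_names=True,  # prefix recorded ops with the owning module name
)
trainer = Trainer(max_epochs=1, profiler=profiler)
# trainer.fit(model, datamodule) would then run the whole loop under the profiler.
```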
@@ -298,7 +303,7 @@ def __init__(
self.function_events: Optional["EventList"] = None
self._lightning_module: Optional["LightningModule"] = None # set by ProfilerConnector
self._register: Optional[RegisterRecordFunction] = None
self._parent_profiler: Optional[_PROFILER] = None
self._parent_profiler: Optional[ContextManager] = None
self._recording_map: Dict[str, record_function] = {}
self._start_action_name: Optional[str] = None
self._schedule: Optional[ScheduleWrapper] = None
@@ -317,7 +322,7 @@ def _init_kineto(self, profiler_kwargs: Any) -> None:

schedule = profiler_kwargs.get("schedule", None)
if schedule is not None:
if not isinstance(schedule, Callable):
if not callable(schedule):
raise MisconfigurationException(f"Schedule should be a callable. Found: {schedule}")
action = schedule(0)
if not isinstance(action, ProfilerAction):
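The `isinstance(schedule, Callable)` check becomes the builtin `callable()`, which is the idiomatic runtime check and sidesteps passing a `typing` special form to `isinstance`. What the validation above expects from a user-supplied schedule, sketched with the stock `torch.profiler.schedule` helper:

```python
from torch.profiler import ProfilerAction, schedule

# A schedule is just a callable mapping a step index to a ProfilerAction.
my_schedule = schedule(wait=1, warmup=1, active=3)
assert callable(my_schedule)
assert isinstance(my_schedule(0), ProfilerAction)

# It would be forwarded through **profiler_kwargs, e.g. PyTorchProfiler(schedule=my_schedule).
```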
@@ -337,7 +342,9 @@ def _init_kineto(self, profiler_kwargs: Any) -> None:
self._profiler_kwargs["with_stack"] = with_stack

@property
def _total_steps(self) -> int:
def _total_steps(self) -> Union[int, float]:
assert self._schedule is not None
assert self._lightning_module is not None
trainer = self._lightning_module.trainer
if self._schedule.is_training:
return trainer.num_training_batches
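The return annotation widens to `Union[int, float]` because the `Trainer` batch counts it forwards are not always integers; Lightning reports `float("inf")` for dataloaders without a length. A small illustration with a hypothetical `clip_steps` helper (not part of the PR):

```python
from typing import Union


def clip_steps(total: Union[int, float], limit: int) -> int:
    # `total` mirrors trainer.num_training_batches: usually an int, but
    # float("inf") when the dataloader has no __len__.
    return limit if total == float("inf") else min(int(total), limit)


print(clip_steps(float("inf"), 3))  # 3
print(clip_steps(10, 3))            # 3
```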
@@ -358,13 +365,13 @@ def _should_override_schedule(self) -> bool:

@staticmethod
@lru_cache(1)
def _default_schedule() -> Optional[callable]:
def _default_schedule() -> Optional[Callable]:
if _KINETO_AVAILABLE:
# Those schedule defaults allow the profiling overhead to be negligible over training time.
return torch.profiler.schedule(wait=1, warmup=1, active=3)

def _default_activities(self) -> List["ProfilerActivity"]:
activities = []
activities: List["ProfilerActivity"] = []
if not _KINETO_AVAILABLE:
return activities
if self._profiler_kwargs.get("use_cpu", True):
@@ -411,6 +418,7 @@ def stop(self, action_name: str) -> None:
return

if self.profiler is not None and any(action_name.endswith(func) for func in self.STEP_FUNCTIONS):
assert isinstance(self.profiler, torch.profiler.profile)
if self._schedule is not None:
self._schedule.pre_step(action_name)

@@ -424,18 +432,19 @@ def stop(self, action_name: str) -> None:
self._schedule = None
self.profiler.schedule = torch.profiler.profiler._default_schedule_fn

def on_trace_ready(profiler):
def on_trace_ready(profiler: _PROFILER) -> None:
if self.dirpath is not None:
if self._export_to_chrome:
handler = tensorboard_trace_handler(
self.dirpath, self._prepare_filename(action_name=action_name, extension="")
str(self.dirpath), self._prepare_filename(action_name=action_name, extension="")
)
handler(profiler)

if self._export_to_flame_graph:
path = os.path.join(
self.dirpath, self._prepare_filename(action_name=action_name, extension=".stack")
)
assert isinstance(profiler, torch.autograd.profiler.profile)
profiler.export_stacks(path, metric=self._metric)
else:
rank_zero_warn("The PyTorchProfiler failed to export trace as `dirpath` is None")
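A hedged sketch of the exports that `on_trace_ready` performs, written against `torch.profiler` directly: `tensorboard_trace_handler` expects a plain `str` directory (which is why `self.dirpath` is now wrapped in `str(...)`), and `export_stacks` writes the flame-graph input and only has data when stacks are recorded:

```python
import torch
from torch.profiler import profile, tensorboard_trace_handler

# Hedged sketch: write a Chrome/TensorBoard trace plus a flame-graph stack file.
# tensorboard_trace_handler creates the output directory itself; export_stacks
# needs with_stack=True to have anything to write.
with profile(with_stack=True, on_trace_ready=tensorboard_trace_handler("traces")) as prof:
    torch.randn(64, 64) @ torch.randn(64, 64)

prof.export_stacks("traces/profiler.stack", metric="self_cpu_time_total")
```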
@@ -469,8 +478,12 @@ def summary(self) -> str:
return self._stats_to_str(recorded_stats)

def _create_profilers(self) -> None:
if self.profiler is not None:
return

if self._emit_nvtx:
self._parent_profiler = self._create_profiler(torch.cuda.profiler.profile)
if self._parent_profiler is None:
self._parent_profiler = torch.cuda.profiler.profile()
self.profiler = self._create_profiler(torch.autograd.profiler.emit_nvtx)
else:
self._parent_profiler = None
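The `emit_nvtx` branch now builds `torch.cuda.profiler.profile()` directly (guarded by `if self._parent_profiler is None`), since that class was dropped from the `_PROFILER` union and the parent is typed as a plain `ContextManager`. The runtime arrangement it reproduces is the standard NVTX recipe, sketched below; this requires a CUDA device and an external profiler such as nsys/nvprof:

```python
import torch

# Hedged sketch of the nesting set up when emit_nvtx=True:
# the CUDA profiler is the outer (parent) context, emit_nvtx the inner one.
model = torch.nn.Linear(8, 8).cuda()
with torch.cuda.profiler.profile():
    with torch.autograd.profiler.emit_nvtx():
        model(torch.randn(4, 8, device="cuda"))
```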
@@ -486,7 +499,13 @@ def _create_profiler(self, profiler: Type[_PROFILER]) -> _PROFILER:
def _cache_functions_events(self) -> None:
if self._emit_nvtx:
return
self.function_events = self.profiler.events() if _KINETO_AVAILABLE else self.profiler.function_events

if _KINETO_AVAILABLE:
assert isinstance(self.profiler, torch.profiler.profile)
self.function_events = self.profiler.events()
else:
assert isinstance(self.profiler, torch.autograd.profiler.profile)
self.function_events = self.profiler.function_events

def _delete_profilers(self) -> None:
if self.profiler is not None:
@@ -505,7 +524,7 @@ def _delete_profilers(self) -> None:
self._register.__exit__(None, None, None)
self._register = None

def teardown(self, stage: str) -> None:
def teardown(self, stage: Optional[str]) -> None:
self._delete_profilers()

for k in list(self._recording_map):
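Most of the remaining additions in this file follow a single pattern: `self.profiler` is typed as the `_PROFILER` union, so call sites that rely on kineto-only or legacy-autograd-only attributes first narrow it with an `isinstance` assert. A minimal sketch of that pattern with hypothetical names:

```python
from typing import Optional, Union

import torch

_PROFILER = Union[torch.profiler.profile, torch.autograd.profiler.profile]


def cache_events(profiler: Optional[_PROFILER], kineto_available: bool):
    assert profiler is not None
    if kineto_available:
        # Narrow the union: .events() exists on torch.profiler.profile.
        assert isinstance(profiler, torch.profiler.profile)
        return profiler.events()
    # .function_events exists on the legacy torch.autograd.profiler.profile.
    assert isinstance(profiler, torch.autograd.profiler.profile)
    return profiler.function_events
```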