Skip to content

Commit f68c090

Browse files
Fix mypy errors attributed to pytorch_lightning.profilers.pytorch (#14405)
* remove toml ref * fix conflicts * small fix * move assertion Co-authored-by: rohitgr7 <[email protected]>
1 parent c81a71c commit f68c090

File tree

2 files changed

+44
-26
lines changed

2 files changed

+44
-26
lines changed

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ warn_no_return = "False"
5252
# mypy --no-error-summary 2>&1 | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g; s|src/||g; s|\/|\.|g' | xargs -I {} echo '"{}",'
5353
module = [
5454
"pytorch_lightning.callbacks.progress.rich_progress",
55-
"pytorch_lightning.profilers.pytorch",
5655
"pytorch_lightning.trainer.trainer",
5756
"pytorch_lightning.tuner.batch_size_scaling",
5857
"pytorch_lightning.utilities.data",

src/pytorch_lightning/profilers/pytorch.py

Lines changed: 44 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import os
1818
from functools import lru_cache, partial
1919
from pathlib import Path
20-
from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING, Union
20+
from typing import Any, Callable, ContextManager, Dict, List, Optional, Type, TYPE_CHECKING, Union
2121

2222
import torch
2323
from lightning_utilities.core.rank_zero import WarningCache
@@ -42,7 +42,7 @@
4242
log = logging.getLogger(__name__)
4343
warning_cache = WarningCache()
4444

45-
_PROFILER = Union[torch.autograd.profiler.profile, torch.cuda.profiler.profile, torch.autograd.profiler.emit_nvtx]
45+
_PROFILER = Union[torch.profiler.profile, torch.autograd.profiler.profile, torch.autograd.profiler.emit_nvtx]
4646

4747

4848
class RegisterRecordFunction:
@@ -111,13 +111,7 @@ def __init__(self, schedule: Callable) -> None:
111111
self._schedule = schedule
112112
self.reset()
113113

114-
def setup(self, start_action_name: str) -> None:
115-
self._start_action_name = start_action_name
116-
117-
def pre_step(self, current_action: str) -> None:
118-
self._current_action = current_action
119-
120-
def reset(self):
114+
def reset(self) -> None:
121115
# handle properly `fast_dev_run`. PyTorch Profiler will fail otherwise.
122116
self._num_training_step = 0
123117
self._num_validation_step = 0
@@ -132,20 +126,30 @@ def reset(self):
132126
self._prev_schedule_action: Optional[ProfilerAction] = None
133127
self._start_action_name: Optional[str] = None
134128

129+
def setup(self, start_action_name: str) -> None:
130+
self._start_action_name = start_action_name
131+
132+
def pre_step(self, current_action: str) -> None:
133+
self._current_action = current_action
134+
135135
@property
136-
def is_training(self):
136+
def is_training(self) -> bool:
137+
assert self._current_action is not None
137138
return self._current_action.endswith("training_step")
138139

139140
@property
140-
def is_validating(self):
141+
def is_validating(self) -> bool:
142+
assert self._current_action is not None
141143
return self._current_action.endswith("validation_step")
142144

143145
@property
144-
def is_testing(self):
146+
def is_testing(self) -> bool:
147+
assert self._current_action is not None
145148
return self._current_action.endswith("test_step")
146149

147150
@property
148-
def is_predicting(self):
151+
def is_predicting(self) -> bool:
152+
assert self._current_action is not None
149153
return self._current_action.endswith("predict_step")
150154

151155
@property
@@ -164,6 +168,7 @@ def _step(self) -> None:
164168
if self.is_training:
165169
self._num_training_step += 1
166170
elif self.is_validating:
171+
assert self._start_action_name is not None
167172
if self._start_action_name.endswith("on_fit_start"):
168173
if self._num_training_step > 0:
169174
self._num_validation_step += 1
@@ -238,7 +243,7 @@ def __init__(
238243
record_module_names: bool = True,
239244
**profiler_kwargs: Any,
240245
) -> None:
241-
"""This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of.
246+
r"""This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of.
242247
243248
different operators inside your model - both on the CPU and GPU
244249
@@ -276,7 +281,7 @@ def __init__(
276281
277282
record_module_names: Whether to add module names while recording autograd operation.
278283
279-
profiler_kwargs: Keyword arguments for the PyTorch profiler. This depends on your PyTorch version
284+
\**profiler_kwargs: Keyword arguments for the PyTorch profiler. This depends on your PyTorch version
280285
281286
Raises:
282287
MisconfigurationException:
@@ -298,7 +303,7 @@ def __init__(
298303
self.function_events: Optional["EventList"] = None
299304
self._lightning_module: Optional["LightningModule"] = None # set by ProfilerConnector
300305
self._register: Optional[RegisterRecordFunction] = None
301-
self._parent_profiler: Optional[_PROFILER] = None
306+
self._parent_profiler: Optional[ContextManager] = None
302307
self._recording_map: Dict[str, record_function] = {}
303308
self._start_action_name: Optional[str] = None
304309
self._schedule: Optional[ScheduleWrapper] = None
@@ -317,7 +322,7 @@ def _init_kineto(self, profiler_kwargs: Any) -> None:
317322

318323
schedule = profiler_kwargs.get("schedule", None)
319324
if schedule is not None:
320-
if not isinstance(schedule, Callable):
325+
if not callable(schedule):
321326
raise MisconfigurationException(f"Schedule should be a callable. Found: {schedule}")
322327
action = schedule(0)
323328
if not isinstance(action, ProfilerAction):
@@ -337,7 +342,9 @@ def _init_kineto(self, profiler_kwargs: Any) -> None:
337342
self._profiler_kwargs["with_stack"] = with_stack
338343

339344
@property
340-
def _total_steps(self) -> int:
345+
def _total_steps(self) -> Union[int, float]:
346+
assert self._schedule is not None
347+
assert self._lightning_module is not None
341348
trainer = self._lightning_module.trainer
342349
if self._schedule.is_training:
343350
return trainer.num_training_batches
@@ -358,13 +365,13 @@ def _should_override_schedule(self) -> bool:
358365

359366
@staticmethod
360367
@lru_cache(1)
361-
def _default_schedule() -> Optional[callable]:
368+
def _default_schedule() -> Optional[Callable]:
362369
if _KINETO_AVAILABLE:
363370
# Those schedule defaults allow the profiling overhead to be negligible over training time.
364371
return torch.profiler.schedule(wait=1, warmup=1, active=3)
365372

366373
def _default_activities(self) -> List["ProfilerActivity"]:
367-
activities = []
374+
activities: List["ProfilerActivity"] = []
368375
if not _KINETO_AVAILABLE:
369376
return activities
370377
if self._profiler_kwargs.get("use_cpu", True):
@@ -411,6 +418,7 @@ def stop(self, action_name: str) -> None:
411418
return
412419

413420
if self.profiler is not None and any(action_name.endswith(func) for func in self.STEP_FUNCTIONS):
421+
assert isinstance(self.profiler, torch.profiler.profile)
414422
if self._schedule is not None:
415423
self._schedule.pre_step(action_name)
416424

@@ -424,18 +432,19 @@ def stop(self, action_name: str) -> None:
424432
self._schedule = None
425433
self.profiler.schedule = torch.profiler.profiler._default_schedule_fn
426434

427-
def on_trace_ready(profiler):
435+
def on_trace_ready(profiler: _PROFILER) -> None:
428436
if self.dirpath is not None:
429437
if self._export_to_chrome:
430438
handler = tensorboard_trace_handler(
431-
self.dirpath, self._prepare_filename(action_name=action_name, extension="")
439+
str(self.dirpath), self._prepare_filename(action_name=action_name, extension="")
432440
)
433441
handler(profiler)
434442

435443
if self._export_to_flame_graph:
436444
path = os.path.join(
437445
self.dirpath, self._prepare_filename(action_name=action_name, extension=".stack")
438446
)
447+
assert isinstance(profiler, torch.autograd.profiler.profile)
439448
profiler.export_stacks(path, metric=self._metric)
440449
else:
441450
rank_zero_warn("The PyTorchProfiler failed to export trace as `dirpath` is None")
@@ -469,8 +478,12 @@ def summary(self) -> str:
469478
return self._stats_to_str(recorded_stats)
470479

471480
def _create_profilers(self) -> None:
481+
if self.profiler is not None:
482+
return
483+
472484
if self._emit_nvtx:
473-
self._parent_profiler = self._create_profiler(torch.cuda.profiler.profile)
485+
if self._parent_profiler is None:
486+
self._parent_profiler = torch.cuda.profiler.profile()
474487
self.profiler = self._create_profiler(torch.autograd.profiler.emit_nvtx)
475488
else:
476489
self._parent_profiler = None
@@ -486,7 +499,13 @@ def _create_profiler(self, profiler: Type[_PROFILER]) -> _PROFILER:
486499
def _cache_functions_events(self) -> None:
487500
if self._emit_nvtx:
488501
return
489-
self.function_events = self.profiler.events() if _KINETO_AVAILABLE else self.profiler.function_events
502+
503+
if _KINETO_AVAILABLE:
504+
assert isinstance(self.profiler, torch.profiler.profile)
505+
self.function_events = self.profiler.events()
506+
else:
507+
assert isinstance(self.profiler, torch.autograd.profiler.profile)
508+
self.function_events = self.profiler.function_events
490509

491510
def _delete_profilers(self) -> None:
492511
if self.profiler is not None:
@@ -505,7 +524,7 @@ def _delete_profilers(self) -> None:
505524
self._register.__exit__(None, None, None)
506525
self._register = None
507526

508-
def teardown(self, stage: str) -> None:
527+
def teardown(self, stage: Optional[str]) -> None:
509528
self._delete_profilers()
510529

511530
for k in list(self._recording_map):

0 commit comments

Comments
 (0)