@@ -17,7 +17,7 @@
import os
from functools import lru_cache, partial
from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING, Union
+from typing import Any, Callable, ContextManager, Dict, List, Optional, Type, TYPE_CHECKING, Union

import torch
from lightning_utilities.core.rank_zero import WarningCache
@@ -42,7 +42,7 @@
log = logging.getLogger(__name__)
warning_cache = WarningCache()

-_PROFILER = Union[torch.autograd.profiler.profile, torch.cuda.profiler.profile, torch.autograd.profiler.emit_nvtx]
+_PROFILER = Union[torch.profiler.profile, torch.autograd.profiler.profile, torch.autograd.profiler.emit_nvtx]


class RegisterRecordFunction:
@@ -111,13 +111,7 @@ def __init__(self, schedule: Callable) -> None:
        self._schedule = schedule
        self.reset()

-    def setup(self, start_action_name: str) -> None:
-        self._start_action_name = start_action_name
-
-    def pre_step(self, current_action: str) -> None:
-        self._current_action = current_action
-
-    def reset(self):
+    def reset(self) -> None:
        # handle properly `fast_dev_run`. PyTorch Profiler will fail otherwise.
        self._num_training_step = 0
        self._num_validation_step = 0
@@ -132,20 +126,30 @@ def reset(self):
        self._prev_schedule_action: Optional[ProfilerAction] = None
        self._start_action_name: Optional[str] = None

+    def setup(self, start_action_name: str) -> None:
+        self._start_action_name = start_action_name
+
+    def pre_step(self, current_action: str) -> None:
+        self._current_action = current_action
+
    @property
-    def is_training(self):
+    def is_training(self) -> bool:
+        assert self._current_action is not None
        return self._current_action.endswith("training_step")

    @property
-    def is_validating(self):
+    def is_validating(self) -> bool:
+        assert self._current_action is not None
        return self._current_action.endswith("validation_step")

    @property
-    def is_testing(self):
+    def is_testing(self) -> bool:
+        assert self._current_action is not None
        return self._current_action.endswith("test_step")

    @property
-    def is_predicting(self):
+    def is_predicting(self) -> bool:
+        assert self._current_action is not None
        return self._current_action.endswith("predict_step")

    @property
@@ -164,6 +168,7 @@ def _step(self) -> None:
        if self.is_training:
            self._num_training_step += 1
        elif self.is_validating:
+            assert self._start_action_name is not None
            if self._start_action_name.endswith("on_fit_start"):
                if self._num_training_step > 0:
                    self._num_validation_step += 1
@@ -238,7 +243,7 @@ def __init__(
        record_module_names: bool = True,
        **profiler_kwargs: Any,
    ) -> None:
-        """This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of.
+        r"""This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of.

        different operators inside your model - both on the CPU and GPU

@@ -276,7 +281,7 @@ def __init__(

            record_module_names: Whether to add module names while recording autograd operation.

-            profiler_kwargs: Keyword arguments for the PyTorch profiler. This depends on your PyTorch version
+            \**profiler_kwargs: Keyword arguments for the PyTorch profiler. This depends on your PyTorch version

        Raises:
            MisconfigurationException:
@@ -298,7 +303,7 @@ def __init__(
        self.function_events: Optional["EventList"] = None
        self._lightning_module: Optional["LightningModule"] = None  # set by ProfilerConnector
        self._register: Optional[RegisterRecordFunction] = None
-        self._parent_profiler: Optional[_PROFILER] = None
+        self._parent_profiler: Optional[ContextManager] = None
        self._recording_map: Dict[str, record_function] = {}
        self._start_action_name: Optional[str] = None
        self._schedule: Optional[ScheduleWrapper] = None
@@ -317,7 +322,7 @@ def _init_kineto(self, profiler_kwargs: Any) -> None:

        schedule = profiler_kwargs.get("schedule", None)
        if schedule is not None:
-            if not isinstance(schedule, Callable):
+            if not callable(schedule):
                raise MisconfigurationException(f"Schedule should be a callable. Found: {schedule}")
            action = schedule(0)
            if not isinstance(action, ProfilerAction):
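For context (not part of the diff): the `callable(schedule)` check above accepts exactly the kind of schedule that `torch.profiler` produces, i.e. a callable mapping a step index to a `ProfilerAction`, which is why the code probes it with `schedule(0)`. A minimal sketch of a valid schedule, assuming a recent PyTorch build with kineto support:

```python
import torch
from torch.profiler import ProfilerAction, schedule

# torch.profiler.schedule returns a callable that maps a step index to a
# ProfilerAction (NONE, WARMUP, RECORD, RECORD_AND_SAVE).
my_schedule = schedule(wait=1, warmup=1, active=3)
assert callable(my_schedule)
assert isinstance(my_schedule(0), ProfilerAction)

# A hand-written schedule is just a plain function with the same contract.
def always_record(step: int) -> ProfilerAction:
    return ProfilerAction.RECORD
```

Either callable would reach this code through `profiler_kwargs` (for example `schedule=my_schedule`), which is where `profiler_kwargs.get("schedule", None)` above picks it up.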
@@ -337,7 +342,9 @@ def _init_kineto(self, profiler_kwargs: Any) -> None:
        self._profiler_kwargs["with_stack"] = with_stack

    @property
-    def _total_steps(self) -> int:
+    def _total_steps(self) -> Union[int, float]:
+        assert self._schedule is not None
+        assert self._lightning_module is not None
        trainer = self._lightning_module.trainer
        if self._schedule.is_training:
            return trainer.num_training_batches
@@ -358,13 +365,13 @@ def _should_override_schedule(self) -> bool:

    @staticmethod
    @lru_cache(1)
-    def _default_schedule() -> Optional[callable]:
+    def _default_schedule() -> Optional[Callable]:
        if _KINETO_AVAILABLE:
            # Those schedule defaults allow the profiling overhead to be negligible over training time.
            return torch.profiler.schedule(wait=1, warmup=1, active=3)

    def _default_activities(self) -> List["ProfilerActivity"]:
-        activities = []
+        activities: List["ProfilerActivity"] = []
        if not _KINETO_AVAILABLE:
            return activities
        if self._profiler_kwargs.get("use_cpu", True):
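Also for reference only: the activities collected above are `torch.profiler.ProfilerActivity` values, which the kineto profiler takes directly. A rough sketch of an equivalent default list, where the CUDA availability check stands in for the `use_cuda` keyword the real code reads from `profiler_kwargs`:

```python
import torch
from torch.profiler import ProfilerActivity

# CPU is always profiled; CUDA only when a GPU is usable.
activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

with torch.profiler.profile(activities=activities) as prof:
    torch.randn(64, 64) @ torch.randn(64, 64)

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```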
@@ -411,6 +418,7 @@ def stop(self, action_name: str) -> None:
            return

        if self.profiler is not None and any(action_name.endswith(func) for func in self.STEP_FUNCTIONS):
+            assert isinstance(self.profiler, torch.profiler.profile)
            if self._schedule is not None:
                self._schedule.pre_step(action_name)

@@ -424,18 +432,19 @@ def stop(self, action_name: str) -> None:
                self._schedule = None
                self.profiler.schedule = torch.profiler.profiler._default_schedule_fn

-            def on_trace_ready(profiler):
+            def on_trace_ready(profiler: _PROFILER) -> None:
                if self.dirpath is not None:
                    if self._export_to_chrome:
                        handler = tensorboard_trace_handler(
-                            self.dirpath, self._prepare_filename(action_name=action_name, extension="")
+                            str(self.dirpath), self._prepare_filename(action_name=action_name, extension="")
                        )
                        handler(profiler)

                    if self._export_to_flame_graph:
                        path = os.path.join(
                            self.dirpath, self._prepare_filename(action_name=action_name, extension=".stack")
                        )
+                        assert isinstance(profiler, torch.autograd.profiler.profile)
                        profiler.export_stacks(path, metric=self._metric)
                else:
                    rank_zero_warn("The PyTorchProfiler failed to export trace as `dirpath` is None")
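For reference (not part of the change): `tensorboard_trace_handler` is the stock `torch.profiler` helper used in `on_trace_ready` above; it takes a directory path as a string and writes Chrome trace files that TensorBoard's profiler plugin can load, which is presumably why the diff wraps `self.dirpath` in `str(...)`. A standalone sketch of the usual pattern, assuming `./profiler_logs` is a writable directory:

```python
import torch
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler

prof = profile(
    activities=[ProfilerActivity.CPU],
    schedule=schedule(wait=1, warmup=1, active=3),
    # Invoked whenever a scheduled recording window completes; the handler
    # writes a trace file into the given directory.
    on_trace_ready=tensorboard_trace_handler("./profiler_logs"),
)
with prof:
    for _ in range(8):
        torch.randn(128, 128) @ torch.randn(128, 128)
        prof.step()  # advance the profiling schedule
```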
@@ -469,8 +478,12 @@ def summary(self) -> str:
        return self._stats_to_str(recorded_stats)

    def _create_profilers(self) -> None:
+        if self.profiler is not None:
+            return
+
        if self._emit_nvtx:
-            self._parent_profiler = self._create_profiler(torch.cuda.profiler.profile)
+            if self._parent_profiler is None:
+                self._parent_profiler = torch.cuda.profiler.profile()
            self.profiler = self._create_profiler(torch.autograd.profiler.emit_nvtx)
        else:
            self._parent_profiler = None
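Context for the nvtx branch (not part of the diff): `_parent_profiler` now holds a plain `torch.cuda.profiler.profile()` context manager, matching its relaxed `Optional[ContextManager]` annotation earlier in the diff. The underlying PyTorch pattern is an NVTX-annotating context nested inside the CUDA profiler context, with the actual timeline captured externally by nvprof or Nsight. A rough sketch, assuming a CUDA-capable machine with an external profiler attached:

```python
import torch

x = torch.randn(256, 256, device="cuda", requires_grad=True)

# Outer context starts/stops the CUDA profiler (what `_parent_profiler` holds);
# the inner context annotates autograd ops with NVTX ranges.
with torch.cuda.profiler.profile():
    with torch.autograd.profiler.emit_nvtx():
        y = (x @ x).sum()
        y.backward()
```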
@@ -486,7 +499,13 @@ def _create_profiler(self, profiler: Type[_PROFILER]) -> _PROFILER:
    def _cache_functions_events(self) -> None:
        if self._emit_nvtx:
            return
-        self.function_events = self.profiler.events() if _KINETO_AVAILABLE else self.profiler.function_events
+
+        if _KINETO_AVAILABLE:
+            assert isinstance(self.profiler, torch.profiler.profile)
+            self.function_events = self.profiler.events()
+        else:
+            assert isinstance(self.profiler, torch.autograd.profiler.profile)
+            self.function_events = self.profiler.function_events

    def _delete_profilers(self) -> None:
        if self.profiler is not None:
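For context (not part of the diff): the two branches in `_cache_functions_events` reflect the two profiler flavors: the kineto-based `torch.profiler.profile` exposes events through an `events()` method, while the legacy `torch.autograd.profiler.profile` stores them on a `function_events` attribute. A small sketch of both, assuming a PyTorch version where both APIs are available:

```python
import torch

# Kineto-based profiler: events are retrieved via a method.
with torch.profiler.profile() as kineto_prof:
    torch.randn(64, 64) @ torch.randn(64, 64)
kineto_events = kineto_prof.events()

# Legacy autograd profiler: events live on the `function_events` attribute.
with torch.autograd.profiler.profile() as legacy_prof:
    torch.randn(64, 64) @ torch.randn(64, 64)
legacy_events = legacy_prof.function_events

print(len(kineto_events), len(legacy_events))
```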
@@ -505,7 +524,7 @@ def _delete_profilers(self) -> None:
            self._register.__exit__(None, None, None)
            self._register = None

-    def teardown(self, stage: str) -> None:
+    def teardown(self, stage: Optional[str]) -> None:
        self._delete_profilers()

        for k in list(self._recording_map):