@@ -117,11 +117,11 @@ def pre_step(self, current_action: str) -> None:
117117
118118     def reset(self):
119119         # handle properly `fast_dev_run`. PyTorch Profiler will fail otherwise.
120-         self._num_optimizer_step_and_closure = 0
120+         self._num_optimizer_step_with_closure = 0
121121         self._num_validation_step = 0
122122         self._num_test_step = 0
123123         self._num_predict_step = 0
124-         self._optimizer_step_and_closure_reached_end = False
124+         self._optimizer_step_with_closure_reached_end = False
125125         self._validation_step_reached_end = False
126126         self._test_step_reached_end = False
127127         self._predict_step_reached_end = False
@@ -132,13 +132,13 @@ def reset(self):
132132     @property
133133     def is_training(self) -> bool:
134134         return self._current_action is not None and (
135-             self._current_action.startswith("optimizer_step_and_closure_") or self._current_action == "training_step"
135+             self._current_action.startswith("optimizer_step_with_closure_") or self._current_action == "training_step"
136136         )
137137
138138     @property
139139     def num_step(self) -> int:
140140         if self.is_training:
141-             return self._num_optimizer_step_and_closure
141+             return self._num_optimizer_step_with_closure
142142         if self._current_action == "validation_step":
143143             return self._num_validation_step
144144         if self._current_action == "test_step":
@@ -149,10 +149,10 @@ def num_step(self) -> int:
149149
150150     def _step(self) -> None:
151151         if self.is_training:
152-             self._num_optimizer_step_and_closure += 1
152+             self._num_optimizer_step_with_closure += 1
153153         elif self._current_action == "validation_step":
154154             if self._start_action_name == "on_fit_start":
155-                 if self._num_optimizer_step_and_closure > 0:
155+                 if self._num_optimizer_step_with_closure > 0:
156156                     self._num_validation_step += 1
157157             else:
158158                 self._num_validation_step += 1
@@ -164,7 +164,7 @@ def _step(self) -> None:
164164     @property
165165     def has_finished(self) -> bool:
166166         if self.is_training:
167-             return self._optimizer_step_and_closure_reached_end
167+             return self._optimizer_step_with_closure_reached_end
168168         if self._current_action == "validation_step":
169169             return self._validation_step_reached_end
170170         if self._current_action == "test_step":
@@ -182,7 +182,7 @@ def __call__(self, num_step: int) -> "ProfilerAction":
182182         action = self._schedule(max(self.num_step, 0))
183183         if action == ProfilerAction.RECORD_AND_SAVE:
184184             if self.is_training:
185-                 self._optimizer_step_and_closure_reached_end = True
185+                 self._optimizer_step_with_closure_reached_end = True
186186             elif self._current_action == "validation_step":
187187                 self._validation_step_reached_end = True
188188             elif self._current_action == "test_step":
@@ -202,9 +202,9 @@ class PyTorchProfiler(BaseProfiler):
202202         "test_step",
203203         "predict_step",
204204     }
205-     RECORD_FUNCTION_PREFIX = "optimizer_step_and_closure_"
205+     RECORD_FUNCTION_PREFIX = "optimizer_step_with_closure_"
206206     STEP_FUNCTIONS = {"training_step", "validation_step", "test_step", "predict_step"}
207-     STEP_FUNCTION_PREFIX = "optimizer_step_and_closure_"
207+     STEP_FUNCTION_PREFIX = "optimizer_step_with_closure_"
208208     AVAILABLE_SORT_KEYS = {
209209         "cpu_time",
210210         "cuda_time",
@@ -383,8 +383,8 @@ def start(self, action_name: str) -> None:
383383             self._register.__enter__()
384384
385385         if self._lightning_module is not None:
386-             # when the model is used in automatic optimization,
387-             # we use `optimizer_step_and_closure` to step the model.
386+             # when the model is used in automatic optimization, we use `optimizer_step_with_closure` to step the model.
387+             # this profiler event is generated in the `LightningOptimizer.step` method
388388             if self._lightning_module.automatic_optimization and "training_step" in self.STEP_FUNCTIONS:
389389                 self.STEP_FUNCTIONS.remove("training_step")
390390
0 commit comments