Skip to content

Commit 2ec67a4

Browse files
tchaton, carmocca, and Borda
authored
[bug] Fix Pytorch profiler with emit_nvtx (#6260)
* resolve bug
* update changelog
* Update tests/trainer/test_trainer.py
* Update pytorch_lightning/profiler/profilers.py

Co-authored-by: Jirka Borovec <[email protected]>

* resolve comments
* resolve flake8

Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent e848542 commit 2ec67a4

File tree

5 files changed

+34
-3
lines changed

5 files changed

+34
-3
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
8383
- Fixed DP reduction with collection ([#6324](https://github.com/PyTorchLightning/pytorch-lightning/pull/6324))
8484

8585

86+
- Fixed PyTorch Profiler with `emit_nvtx` ([#6260](https://github.com/PyTorchLightning/pytorch-lightning/pull/6260))
87+
88+
89+
- Fixed `trainer.test` from `best_path` hangs after calling `trainer.fit` ([#6272](https://github.com/PyTorchLightning/pytorch-lightning/pull/6272))
90+
91+
8692
## [1.2.2] - 2021-03-02
8793

8894
### Added

pytorch_lightning/profiler/profilers.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""Profiler to check if there are any bottlenecks in your code."""
15-
1615
import cProfile
1716
import io
1817
import logging

pytorch_lightning/profiler/pytorch.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def start(self, action_name: str) -> None:
205205

206206
def _start(self, action_name: str) -> None:
207207
if self.emit_nvtx:
208-
self._create_profiler(action_name, torch.cuda.profiler.profile, enter=False)
208+
self._parent_profiler = self._create_profiler(action_name, torch.cuda.profiler.profile, enter=True)
209209
self._create_profiler(action_name, torch.autograd.profiler.emit_nvtx)
210210
else:
211211
self._create_profiler(action_name, torch.autograd.profiler.profile)
@@ -215,15 +215,24 @@ def _create_profiler(self, action_name, profiler, enter=True):
215215
profiler_args = {k: v for k, v in vars(self).items() if k in init_args}
216216
pr = profiler(**profiler_args)
217217
if enter:
218-
pr = pr.__enter__()
218+
out_pr = pr.__enter__()
219+
if out_pr is not None:
220+
pr = out_pr
219221
self.profiler = pr
222+
return self.profiler
220223

221224
def _stop(self, action_name: str) -> None:
222225
if self.profiler is None:
223226
return
224227

225228
self.profiler.__exit__(exc_type=None, exc_val=None, exc_tb=None)
226229

230+
if isinstance(self.profiler, torch.autograd.profiler.emit_nvtx):
231+
# when running ``emit_nvtx``, PyTorch requires 2 context manager.
232+
# The parent_profiler is being closed too.
233+
self._parent_profiler.__exit__(None, None, None)
234+
return
235+
227236
function_events = self.profiler.function_events
228237
self.profiler = None
229238
for name in self.running_stack:

tests/special_tests.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,4 @@ python ${DEFAULTS} tests/trainer/test_trainer.py::test_pytorch_profiler_trainer_
3232
python ${DEFAULTS} tests/models/test_hooks.py::test_transfer_batch_hook_ddp
3333
python ${DEFAULTS} tests/trainer/test_data_loading.py::test_replace_distrubuted_sampler_custom_dataloader_custom_batch_sampler
3434
python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp_with_toggle_model
35+
nvprof --profile-from-start off -o trace_name.prof -- python ${DEFAULTS} tests/trainer/test_trainer.py::test_pytorch_profiler_nested_emit_nvtx

tests/trainer/test_trainer.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1554,6 +1554,22 @@ def test_pytorch_profiler_nested(tmpdir):
15541554
assert pa[n] == expected_[n]
15551555

15561556

1557+
@RunIf(min_gpus=1, special=True)
1558+
def test_pytorch_profiler_nested_emit_nvtx(tmpdir):
1559+
"""
1560+
This test check emit_nvtx is correctly supported
1561+
"""
1562+
profiler = PyTorchProfiler(use_cuda=True, emit_nvtx=True)
1563+
1564+
model = BoringModel()
1565+
trainer = Trainer(
1566+
fast_dev_run=True,
1567+
profiler=profiler,
1568+
gpus=1,
1569+
)
1570+
trainer.fit(model)
1571+
1572+
15571573
@pytest.mark.parametrize(
15581574
["limit_train_batches", "global_step", "num_training_batches", "current_epoch", "should_train"],
15591575
[(0.2, 0, 0, 0, False), (0.5, 10, 2, 4, True)],

0 commit comments

Comments (0)