Skip to content

Commit f530489

Browse files
authored
Deprecate BaseProfiler.profile_iterable (#12102)
1 parent 61dd5e4 commit f530489

File tree

4 files changed

+56
-27
lines changed

4 files changed

+56
-27
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
430430
- Deprecated `LightningLoggerBase.agg_and_log_metrics` in favor of `LightningLoggerBase.log_metrics` ([#11832](https://github.com/PyTorchLightning/pytorch-lightning/pull/11832))
431431

432432

433+
- Deprecated `BaseProfiler.profile_iterable` ([#12102](https://github.com/PyTorchLightning/pytorch-lightning/pull/12102))
434+
435+
433436
### Removed
434437

435438
- Removed deprecated parameter `method` in `pytorch_lightning.utilities.model_helpers.is_overridden` ([#10507](https://github.com/PyTorchLightning/pytorch-lightning/pull/10507))

pytorch_lightning/profiler/base.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from typing import Any, Callable, Dict, Generator, Iterable, Optional, TextIO, Union
2121

2222
from pytorch_lightning.utilities.cloud_io import get_filesystem
23+
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
2324

2425
log = logging.getLogger(__name__)
2526

@@ -83,6 +84,14 @@ def profile(self, action_name: str) -> Generator:
8384
self.stop(action_name)
8485

8586
def profile_iterable(self, iterable: Iterable, action_name: str) -> Generator:
87+
"""Profiles over each value of an iterable.
88+
89+
See deprecation message below.
90+
91+
.. deprecated:: v1.6
92+
`BaseProfiler.profile_iterable` is deprecated in v1.6 and will be removed in v1.8.
93+
"""
94+
rank_zero_deprecation("`BaseProfiler.profile_iterable` is deprecated in v1.6 and will be removed in v1.8.")
8695
iterator = iter(iterable)
8796
while True:
8897
try:

tests/deprecated_api/test_remove_1-8.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""Test deprecated functionality which will be removed in v1.8.0."""
15+
import time
1516
from unittest.mock import Mock
1617

1718
import numpy as np
@@ -33,6 +34,7 @@
3334
from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
3435
from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
3536
from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
37+
from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler
3638
from pytorch_lightning.trainer.states import RunningStage
3739
from pytorch_lightning.utilities.apply_func import move_data_to_device
3840
from pytorch_lightning.utilities.enums import DeviceType, DistributedType
@@ -608,3 +610,45 @@ def on_pretrain_routine_end(self, trainer, pl_module):
608610
match="The `Callback.on_pretrain_routine_end` hook has been deprecated in v1.6" " and will be removed in v1.8"
609611
):
610612
trainer.fit(model)
613+
614+
615+
@pytest.mark.flaky(reruns=3)
616+
@pytest.mark.parametrize(["action", "expected"], [("a", [3, 1]), ("b", [2]), ("c", [1])])
617+
def test_simple_profiler_iterable_durations(tmpdir, action: str, expected: list):
618+
"""Ensure the reported durations are reasonably accurate."""
619+
620+
def _sleep_generator(durations):
621+
"""the profile_iterable method needs an iterable in which we can ensure that we're properly timing how long
622+
it takes to call __next__"""
623+
for duration in durations:
624+
time.sleep(duration)
625+
yield duration
626+
627+
def _get_python_cprofile_total_duration(profile):
628+
return sum(x.inlinetime for x in profile.getstats())
629+
630+
simple_profiler = SimpleProfiler()
631+
iterable = _sleep_generator(expected)
632+
633+
with pytest.deprecated_call(
634+
match="`BaseProfiler.profile_iterable` is deprecated in v1.6 and will be removed in v1.8."
635+
):
636+
for _ in simple_profiler.profile_iterable(iterable, action):
637+
pass
638+
639+
# we exclude the last item in the recorded durations since that's when StopIteration is raised
640+
np.testing.assert_allclose(simple_profiler.recorded_durations[action][:-1], expected, rtol=0.2)
641+
642+
advanced_profiler = AdvancedProfiler(dirpath=tmpdir, filename="profiler")
643+
644+
iterable = _sleep_generator(expected)
645+
646+
with pytest.deprecated_call(
647+
match="`BaseProfiler.profile_iterable` is deprecated in v1.6 and will be removed in v1.8."
648+
):
649+
for _ in advanced_profiler.profile_iterable(iterable, action):
650+
pass
651+
652+
recorded_total_duration = _get_python_cprofile_total_duration(advanced_profiler.profiled_actions[action])
653+
expected_total_duration = np.sum(expected)
654+
np.testing.assert_allclose(recorded_total_duration, expected_total_duration, rtol=0.2)

tests/profiler/test_profiler.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -67,19 +67,6 @@ def test_simple_profiler_durations(simple_profiler, action: str, expected: list)
6767
np.testing.assert_allclose(simple_profiler.recorded_durations[action], expected, rtol=0.2)
6868

6969

70-
@pytest.mark.flaky(reruns=3)
71-
@pytest.mark.parametrize(["action", "expected"], [("a", [3, 1]), ("b", [2]), ("c", [1])])
72-
def test_simple_profiler_iterable_durations(simple_profiler, action: str, expected: list):
73-
"""Ensure the reported durations are reasonably accurate."""
74-
iterable = _sleep_generator(expected)
75-
76-
for _ in simple_profiler.profile_iterable(iterable, action):
77-
pass
78-
79-
# we exclude the last item in the recorded durations since that's when StopIteration is raised
80-
np.testing.assert_allclose(simple_profiler.recorded_durations[action][:-1], expected, rtol=0.2)
81-
82-
8370
def test_simple_profiler_overhead(simple_profiler, n_iter=5):
8471
"""Ensure that the profiler doesn't introduce too much overhead during training."""
8572
for _ in range(n_iter):
@@ -289,20 +276,6 @@ def test_advanced_profiler_durations(advanced_profiler, action: str, expected: l
289276
np.testing.assert_allclose(recorded_total_duration, expected_total_duration, rtol=0.2)
290277

291278

292-
@pytest.mark.flaky(reruns=3)
293-
@pytest.mark.parametrize(["action", "expected"], [("a", [3, 1]), ("b", [2]), ("c", [1])])
294-
def test_advanced_profiler_iterable_durations(advanced_profiler, action: str, expected: list):
295-
"""Ensure the reported durations are reasonably accurate."""
296-
iterable = _sleep_generator(expected)
297-
298-
for _ in advanced_profiler.profile_iterable(iterable, action):
299-
pass
300-
301-
recorded_total_duration = _get_python_cprofile_total_duration(advanced_profiler.profiled_actions[action])
302-
expected_total_duration = np.sum(expected)
303-
np.testing.assert_allclose(recorded_total_duration, expected_total_duration, rtol=0.2)
304-
305-
306279
@pytest.mark.flaky(reruns=3)
307280
def test_advanced_profiler_overhead(advanced_profiler, n_iter=5):
308281
"""ensure that the profiler doesn't introduce too much overhead during training."""

0 commit comments

Comments
 (0)