|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | """Test deprecated functionality which will be removed in v1.8.0.""" |
| 15 | +import time |
15 | 16 | from unittest.mock import Mock |
16 | 17 |
|
17 | 18 | import numpy as np |
|
33 | 34 | from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin |
34 | 35 | from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin |
35 | 36 | from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin |
| 37 | +from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler |
36 | 38 | from pytorch_lightning.trainer.states import RunningStage |
37 | 39 | from pytorch_lightning.utilities.apply_func import move_data_to_device |
38 | 40 | from pytorch_lightning.utilities.enums import DeviceType, DistributedType |
@@ -608,3 +610,45 @@ def on_pretrain_routine_end(self, trainer, pl_module): |
608 | 610 | match="The `Callback.on_pretrain_routine_end` hook has been deprecated in v1.6" " and will be removed in v1.8" |
609 | 611 | ): |
610 | 612 | trainer.fit(model) |
| 613 | + |
| 614 | + |
| 615 | +@pytest.mark.flaky(reruns=3) |
| 616 | +@pytest.mark.parametrize(["action", "expected"], [("a", [3, 1]), ("b", [2]), ("c", [1])]) |
| 617 | +def test_simple_profiler_iterable_durations(tmpdir, action: str, expected: list): |
| 618 | + """Ensure the reported durations are reasonably accurate.""" |
| 619 | + |
| 620 | + def _sleep_generator(durations): |
| 621 | + """the profile_iterable method needs an iterable in which we can ensure that we're properly timing how long |
| 622 | + it takes to call __next__""" |
| 623 | + for duration in durations: |
| 624 | + time.sleep(duration) |
| 625 | + yield duration |
| 626 | + |
| 627 | + def _get_python_cprofile_total_duration(profile): |
| 628 | + return sum(x.inlinetime for x in profile.getstats()) |
| 629 | + |
| 630 | + simple_profiler = SimpleProfiler() |
| 631 | + iterable = _sleep_generator(expected) |
| 632 | + |
| 633 | + with pytest.deprecated_call( |
| 634 | + match="`BaseProfiler.profile_iterable` is deprecated in v1.6 and will be removed in v1.8." |
| 635 | + ): |
| 636 | + for _ in simple_profiler.profile_iterable(iterable, action): |
| 637 | + pass |
| 638 | + |
| 639 | + # we exclude the last item in the recorded durations since that's when StopIteration is raised |
| 640 | + np.testing.assert_allclose(simple_profiler.recorded_durations[action][:-1], expected, rtol=0.2) |
| 641 | + |
| 642 | + advanced_profiler = AdvancedProfiler(dirpath=tmpdir, filename="profiler") |
| 643 | + |
| 644 | + iterable = _sleep_generator(expected) |
| 645 | + |
| 646 | + with pytest.deprecated_call( |
| 647 | + match="`BaseProfiler.profile_iterable` is deprecated in v1.6 and will be removed in v1.8." |
| 648 | + ): |
| 649 | + for _ in advanced_profiler.profile_iterable(iterable, action): |
| 650 | + pass |
| 651 | + |
| 652 | + recorded_total_duration = _get_python_cprofile_total_duration(advanced_profiler.profiled_actions[action]) |
| 653 | + expected_total_duration = np.sum(expected) |
| 654 | + np.testing.assert_allclose(recorded_total_duration, expected_total_duration, rtol=0.2) |
0 commit comments