|
1 | 1 | import contextlib |
| 2 | +import os |
2 | 3 | import random |
| 4 | +from unittest import mock |
3 | 5 | from unittest.mock import Mock |
4 | 6 |
|
5 | 7 | import numpy as np |
6 | 8 | import pytest |
7 | 9 | import torch |
| 10 | +from lightning_utilities.test.warning import no_warning_call |
8 | 11 | from torch import Tensor |
9 | 12 | from torch.utils.data import BatchSampler, DataLoader, RandomSampler |
10 | 13 |
|
| 14 | +import lightning.fabric |
11 | 15 | from lightning.fabric.utilities.data import ( |
12 | 16 | _get_dataloader_init_args_and_kwargs, |
13 | 17 | _replace_dunder_methods, |
|
17 | 21 | _WrapAttrTag, |
18 | 22 | has_iterable_dataset, |
19 | 23 | has_len, |
| 24 | + suggested_max_num_workers, |
20 | 25 | ) |
21 | 26 | from lightning.fabric.utilities.exceptions import MisconfigurationException |
22 | 27 | from tests_fabric.helpers.models import RandomDataset, RandomIterableDataset |
@@ -575,3 +580,63 @@ def test_set_sampler_epoch(): |
575 | 580 | _set_sampler_epoch(dataloader, 55) |
576 | 581 | dataloader.sampler.set_epoch.assert_called_once_with(55) |
577 | 582 | dataloader.batch_sampler.sampler.set_epoch.assert_called_once_with(55) |
| 583 | + |
| 584 | + |
@pytest.mark.parametrize(
    ("cpu_count", "local_world_size", "expected"),
    [
        (0, 1, 1),
        (1, 1, 1),
        (2, 1, 2),
        (1, 2, 1),
        (2, 2, 1),
        (3, 2, 1),
        (4, 2, 2),
        (4, 3, 1),
        (4, 1, 4),
    ],
)
@pytest.mark.parametrize(
    "affinity",
    [
        False,
        pytest.param(
            True,
            marks=pytest.mark.skipif(
                not hasattr(os, "sched_getaffinity"), reason="OS does not support restricting CPU cores"
            ),
        ),
    ],
)
@mock.patch("lightning.fabric.utilities.data.os.cpu_count")
def test_suggested_max_num_workers(cpu_count_mock, affinity, cpu_count, local_world_size, expected, monkeypatch):
    """``suggested_max_num_workers`` splits the available CPUs among the local processes (min. 1 worker each).

    Exercised both with and without ``os.sched_getaffinity`` available, since the implementation is expected to
    prefer the affinity-restricted core count when the OS supports it.
    """
    if affinity:
        # Simulate an affinity mask restricting this process to `cpu_count` cores.
        monkeypatch.setattr(lightning.fabric.utilities.data.os, "sched_getaffinity", lambda _: list(range(cpu_count)))
    else:
        # Remove the affinity API so the implementation falls back to `os.cpu_count()` (mocked below).
        monkeypatch.delattr(lightning.fabric.utilities.data.os, "sched_getaffinity", raising=False)
    cpu_count_mock.return_value = cpu_count

    assert suggested_max_num_workers(local_world_size) == expected
| 621 | + |
| 622 | + |
@pytest.mark.parametrize("invalid", [-1, 0])
def test_suggested_max_num_workers_input_validation(invalid):
    """``suggested_max_num_workers`` rejects a non-positive ``local_world_size`` with a ``ValueError``."""
    expected_message = "should be >= 1"
    with pytest.raises(ValueError, match=expected_message):
        suggested_max_num_workers(invalid)
| 627 | + |
| 628 | + |
@pytest.mark.parametrize("cpu_count", [1, 2, 3])
@pytest.mark.parametrize("local_world_size", [1, 2, 3])
def test_suggested_max_num_workers_not_triggering_torch_warning(local_world_size, cpu_count, monkeypatch):
    """Test that our suggestion for num workers doesn't trigger a warning in the DataLoader for too many workers."""
    # Patch BOTH the `os` module as seen by lightning's helper AND the one seen by torch's dataloader module,
    # so the two sides agree on the CPU count: remove `sched_getaffinity` to force the `os.cpu_count()` fallback,
    # then pin `cpu_count()` to the parametrized value.
    monkeypatch.delattr(lightning.fabric.utilities.data.os, "sched_getaffinity", raising=False)
    monkeypatch.delattr(torch.utils.data.dataloader.os, "sched_getaffinity", raising=False)
    monkeypatch.setattr(lightning.fabric.utilities.data.os, "cpu_count", lambda: cpu_count)
    monkeypatch.setattr(torch.utils.data.dataloader.os, "cpu_count", lambda: cpu_count)

    # The dataloader runs a check in `DataLoader.check_worker_number_rationality`
    # Sanity check: exceeding the CPU count must produce torch's "too many workers" warning ...
    with pytest.warns(UserWarning, match="This DataLoader will create"):
        DataLoader(range(2), num_workers=(cpu_count + 1))
    # ... while our suggested value must stay below the threshold and emit no warning at all.
    with no_warning_call():
        DataLoader(range(2), num_workers=suggested_max_num_workers(local_world_size))
0 commit comments