
Commit 21fc5eb

Automatically find and run special tests (#6669)
1 parent b730a5a commit 21fc5eb
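
The discovery mechanism introduced here greps the tests/ and benchmarks/ trees for the `special=True` flag of the `RunIf` marker (tests/helpers/runif.py) and runs each hit individually under coverage. A minimal sketch of a test the script would pick up, assuming `special` can be combined with the other `RunIf` flags that appear in the diffs below; the test name is hypothetical:

import torch

from tests.helpers.runif import RunIf


@RunIf(special=True, min_gpus=2)
def test_my_special_feature(tmpdir):
    # only runs when PL_RUNNING_SPECIAL_TESTS=1, i.e. when launched via tests/special_tests.sh
    assert torch.cuda.device_count() >= 2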

File tree

7 files changed: +109 −164 lines changed

azure-pipelines.yml

Lines changed: 1 addition & 1 deletion
@@ -82,7 +82,7 @@ jobs:
     displayName: 'Testing: standard'
 
   - bash: |
-      sh tests/special_tests.sh
+      bash tests/special_tests.sh
    displayName: 'Testing: special'
 
  - bash: |
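
Note that the switch from `sh` to `bash` matches the rewritten tests/special_tests.sh below: the script now carries a `#!/bin/bash` shebang and uses bash-only constructs (arrays and `< <(...)` process substitution) that a plain POSIX `sh` would not accept.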

benchmarks/test_sharded_parity.py

Lines changed: 43 additions & 107 deletions
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
 import time
 from typing import Type
 
@@ -21,113 +20,13 @@
 
 from pytorch_lightning import seed_everything, Trainer
 from pytorch_lightning.plugins import DDPSpawnShardedPlugin
-from tests.accelerators import DDPLauncher
 from tests.helpers.boring_model import BoringModel, RandomDataset
 from tests.helpers.runif import RunIf
 
 
-@RunIf(min_gpus=1, skip_windows=True, fairscale=True)
-def test_ddp_sharded_plugin_correctness_one_gpu():
-    plugin_parity_test(
-        gpus=1,
-        model_cls=SeedTrainLoaderModel,
-    )
-
-
-@RunIf(min_gpus=1, skip_windows=True, fairscale=True, amp_native=True)
-def test_ddp_sharded_plugin_correctness_amp_one_gpu():
-    plugin_parity_test(
-        gpus=1,
-        precision=16,
-        model_cls=SeedTrainLoaderModel,
-    )
-
-
-@pytest.mark.skip(reason="Not a critical test, skip till drone CI performance improves.")
-@RunIf(min_gpus=2, skip_windows=True, fairscale=True)
-def test_ddp_sharded_plugin_correctness_multi_gpu():
-    plugin_parity_test(
-        gpus=2,
-        model_cls=SeedTrainLoaderModel,
-        max_percent_speed_diff=0.25,  # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
-    )
-
-
-@RunIf(min_gpus=2, skip_windows=True, fairscale=True, amp_native=True)
-def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
-    plugin_parity_test(
-        gpus=2,
-        precision=16,
-        model_cls=SeedTrainLoaderModel,
-        max_percent_speed_diff=0.25,  # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
-    )
-
-
-@RunIf(min_gpus=2, skip_windows=True, fairscale=True, amp_native=True)
-def test_ddp_string_sharded_plugin_correctness_amp_multi_gpu():
-    plugin_parity_test(
-        gpus=2,
-        precision=16,
-        model_cls=SeedTrainLoaderModel,
-        max_percent_speed_diff=0.25,  # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
-    )
-
-
-@RunIf(min_gpus=2, fairscale=True)
-@pytest.mark.skipif(
-    not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
-)
-@DDPLauncher.run("--accelerator ddp --gpus 2 --precision 32")
-def test_ddp_sharded_plugin_correctness_multi_gpu_ddp(tmpdir, args=None):
-    plugin_parity_test(
-        gpus=args.gpus,
-        precision=args.precision,
-        model_cls=SeedTrainLoaderModel,
-    )
-
-
-@RunIf(min_gpus=2, fairscale=True)
-@pytest.mark.skipif(
-    not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
-)
-@DDPLauncher.run("--accelerator ddp --gpus 2 --precision 16")
-def test_ddp_sharded_plugin_correctness_amp_multi_gpu_ddp(tmpdir, args=None):
-    plugin_parity_test(
-        gpus=args.gpus,
-        precision=args.precision,
-        model_cls=SeedTrainLoaderModel,
-    )
-
-
-@pytest.mark.skip(reason="Current issue with multiple optimizers and FairScale.")
-@RunIf(min_gpus=2, skip_windows=True, fairscale=True)
-def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
-    """
-    Ensures same results using multiple optimizers across multiple GPUs
-    """
-    plugin_parity_test(
-        gpus=2,
-        model_cls=SeedTrainLoaderMultipleOptimizersModel,
-        max_percent_speed_diff=0.25,  # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
-    )
-
-
-@pytest.mark.skip(reason="Current issue with multiple optimizers and FairScale.")
-@RunIf(min_gpus=2, skip_windows=True, fairscale=True)
-def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir):
-    """
-    Ensures using multiple optimizers across multiple GPUs with manual optimization
-    """
-    plugin_parity_test(
-        gpus=2,
-        model_cls=SeedTrainLoaderManualModel,
-        max_percent_speed_diff=0.25,  # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
-    )
-
-
 class SeedTrainLoaderModel(BoringModel):
     """
-        Overrides training loader to ensure we enforce the same seed for all DDP processes.
+    Overrides training loader to ensure we enforce the same seed for all DDP processes.
     """
 
     def train_dataloader(self):
@@ -177,7 +76,7 @@ class SeedTrainLoaderMultipleOptimizersModel(SeedTrainLoaderModel):
     def training_step(self, batch, batch_idx, optimizer_idx):
         output = self.layer(batch)
         loss = self.loss(batch, output)
-        return {"loss": loss}
+        return {'loss': loss}
 
     def training_epoch_end(self, outputs) -> None:
         # outputs should be an array with an entry per optimizer
@@ -279,11 +178,48 @@ def plugin_parity_test(
     # Assert speed parity by ensuring percentage difference between custom/ddp is below threshold
     percent_diff = (custom_model_time - ddp_time) / custom_model_time
 
-    assert percent_diff <= max_percent_speed_diff, \
-        f'Custom DDP plugin was too slow compared to DDP, Custom Plugin Time: {custom_model_time}, DDP Time: {ddp_time}'
+    assert (
+        percent_diff <= max_percent_speed_diff
+    ), f'Custom DDP plugin was too slow compared to DDP, Custom Plugin Time: {custom_model_time}, DDP Time: {ddp_time}'
 
     if use_cuda:
         # Assert CUDA memory parity
-        assert max_memory_custom <= max_memory_ddp, \
-            f'Custom plugin used too much memory compared to DDP,' \
+        assert max_memory_custom <= max_memory_ddp, (
+            'Custom plugin used too much memory compared to DDP, '
             f'Custom Mem: {max_memory_custom}, DDP Mem: {max_memory_ddp}'
+        )
+
+
+@RunIf(skip_windows=True, fairscale=True)
+@pytest.mark.parametrize(
+    'kwargs',
+    [
+        pytest.param(dict(gpus=1, model_cls=SeedTrainLoaderModel), marks=RunIf(min_gpus=1)),
+        pytest.param(
+            dict(gpus=1, precision=16, model_cls=SeedTrainLoaderModel), marks=RunIf(min_gpus=1, amp_native=True)
+        ),
+        pytest.param(dict(gpus=2, model_cls=SeedTrainLoaderModel), marks=RunIf(min_gpus=2)),
+        pytest.param(
+            dict(gpus=2, precision=16, model_cls=SeedTrainLoaderModel), marks=RunIf(min_gpus=2, amp_native=True)
+        ),
+        pytest.param(
+            dict(gpus=2, model_cls=SeedTrainLoaderMultipleOptimizersModel),
+            marks=[
+                RunIf(min_gpus=2),
+                pytest.mark.skip(reason='TODO: Current issue with multiple optimizers and FairScale.'),
+            ],
+        ),
+        pytest.param(
+            dict(gpus=2, model_cls=SeedTrainLoaderManualModel),
+            marks=[
+                RunIf(min_gpus=2),
+                pytest.mark.skip(reason='TODO: Current issue with multiple optimizers and FairScale.'),
+            ],
+        ),
+    ],
+)
+def test_ddp_spawn_sharded_plugin(kwargs):
+    if kwargs['gpus'] > 1:
+        # TODO: decrease speed diff since only 2 GPUs sharding 2 optimizers
+        kwargs['max_percent_speed_diff'] = 0.25
+    plugin_parity_test(**kwargs)
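
Taken together, the standalone sharded-parity tests are folded into the single parametrized `test_ddp_spawn_sharded_plugin`: per-case GPU and native-AMP requirements move into `marks=RunIf(...)`, the FairScale multiple-optimizer cases stay skipped via a stacked `pytest.mark.skip`, the multi-GPU cases keep the relaxed 0.25 speed-difference threshold, and the two `DDPLauncher`-based variants are dropped.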

tests/accelerators/__init__.py

Lines changed: 0 additions & 12 deletions
@@ -1,12 +0,0 @@
-try:
-    from dtrun.launcher import DDPLauncher
-except ImportError:
-
-    class DDPLauncher:
-
-        def run(cmd_line, **kwargs):
-
-            def inner(func):
-                pass
-
-            return inner
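
The fallback stub above existed only for the `DDPLauncher`-driven tests; with those removed from benchmarks/test_sharded_parity.py and tests/accelerators/test_ddp.py in this commit, the import (and the whole file body) can go.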

tests/accelerators/test_ddp.py

Lines changed: 1 addition & 14 deletions
@@ -20,7 +20,7 @@
 import torch
 
 from pytorch_lightning import Trainer
-from tests.accelerators import ddp_model, DDPLauncher
+from tests.accelerators import ddp_model
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.runif import RunIf
 from tests.utilities.distributed import call_training_script
@@ -71,19 +71,6 @@ def test_multi_gpu_model_ddp_fit_test(tmpdir):
     assert out['test_acc'] > 0.7
 
 
-@RunIf(min_gpus=2)
-@DDPLauncher.run(
-    "--max_epochs [max_epochs] --gpus 2 --accelerator [accelerator]",
-    max_epochs=["1"],
-    accelerator=["ddp", "ddp_spawn"]
-)
-def test_cli_to_pass(tmpdir, args=None):
-    """
-    This test verify we can call function using test_cli name
-    """
-    return '1'
-
-
 @RunIf(skip_windows=True)
 @pytest.mark.skipif(torch.cuda.is_available(), reason="test doesn't requires GPU machine")
 def test_torch_distributed_backend_env_variables(tmpdir):

tests/accelerators/test_multi_nodes_gpu.py

Lines changed: 7 additions & 0 deletions
@@ -15,6 +15,7 @@
 import sys
 from unittest import mock
 
+import pytest
 import torch
 
 from tests.helpers.runif import RunIf
@@ -28,6 +29,9 @@
 from tests.helpers.boring_model import BoringModel  # noqa: E402
 
 
+# TODO(Borda): When multi-node tests are re-enabled (.github/workflows/ci_test-mnodes.yml)
+# use an environment variable `PL_RUNNING_MULTINODE_TESTS` and set `RunIf(multinode=True)`
+@pytest.mark.skip("Multi-node testing is currently disabled")
 @RunIf(special=True)
 def test_logging_sync_dist_true_ddp(tmpdir):
     """
@@ -65,6 +69,9 @@ def validation_step(self, batch, batch_idx):
     assert trainer.logged_metrics['bar'] == fake_result
 
 
+# TODO(Borda): When multi-node tests are re-enabled (.github/workflows/ci_test-mnodes.yml)
+# use an environment variable `PL_RUNNING_MULTINODE_TESTS` and set `RunIf(multinode=True)`
+@pytest.mark.skip("Multi-node testing is currently disabled")
 @RunIf(special=True)
 @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
 def test__validation_step__log(tmpdir):

tests/special_tests.sh

File mode changed: 100644 → 100755
Lines changed: 54 additions & 27 deletions
@@ -1,3 +1,4 @@
+#!/bin/bash
 # Copyright The PyTorch Lightning team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,32 +12,58 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# Running special tests
 set -e
+
+# this environment variable allows special tests to run
 export PL_RUNNING_SPECIAL_TESTS=1
-DEFAULTS="-m coverage run --source pytorch_lightning --append -m pytest --verbose --capture=no"
-python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp
-python ${DEFAULTS} tests/models/test_sync_batchnorm.py::test_sync_batchnorm_ddp
-python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_invalid_deepspeed_defaults_no_precision
-python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_warn_deepspeed_override_backward
-python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_run_configure_optimizers
-python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_config
-python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_custom_precision_params
-python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_assert_config_zero_offload_disabled
-python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_multigpu
-python ${DEFAULTS} tests/plugins/test_rpc_plugin.py::test_rpc_function_calls_ddp
-python ${DEFAULTS} tests/plugins/test_rpc_sequential_plugin.py::test_rpc_sequential_plugin_manual
-python ${DEFAULTS} tests/plugins/test_rpc_sequential_plugin.py::test_rpc_sequential_plugin_manual_amp
-python ${DEFAULTS} tests/plugins/test_rpc_sequential_plugin.py::test_rpc_sequential_plugin_automatic
-python ${DEFAULTS} tests/plugins/test_rpc_sequential_plugin.py::test_rpc_sequential_plugin_with_wrong_balance
-python ${DEFAULTS} tests/utilities/test_all_gather_grad.py::test_all_gather_collection
-python ${DEFAULTS} tests/trainer/test_trainer.py::test_trainer_predict_ddp
-python ${DEFAULTS} tests/trainer/test_trainer.py::test_trainer_predict_dp
-python ${DEFAULTS} tests/trainer/logging_/test_train_loop_logging_1_0.py::test_logging_sync_dist_true_ddp
-python ${DEFAULTS} tests/callbacks/test_pruning.py::test_pruning_callback_ddp
-python ${DEFAULTS} tests/test_profiler.py::test_pytorch_profiler_trainer_ddp
-python ${DEFAULTS} tests/models/test_hooks.py::test_transfer_batch_hook_ddp
-python ${DEFAULTS} tests/trainer/test_data_loading.py::test_replace_distrubuted_sampler_custom_dataloader_custom_batch_sampler
-python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp_with_toggle_model
-python ${DEFAULTS} tests/checkpointing/test_checkpoint_callback_frequency.py::test_top_k_ddp
-nvprof --profile-from-start off -o trace_name.prof -- python ${DEFAULTS} tests/test_profiler.py::test_pytorch_profiler_nested_emit_nvtx
+# python arguments
+defaults='-m coverage run --source pytorch_lightning --append -m pytest --verbose --capture=no'
+
+# find tests marked as `@RunIf(special=True)`
+grep_output=$(grep --recursive --line-number --word-regexp 'tests' 'benchmarks' --regexp 'special=True')
+# file paths
+files=$(echo "$grep_output" | cut -f1 -d:)
+files_arr=($files)
+# line numbers
+linenos=$(echo "$grep_output" | cut -f2 -d:)
+linenos_arr=($linenos)
+
+# tests to skip - space separated
+blocklist='test_pytorch_profiler_nested_emit_nvtx'
+report=''
+
+for i in "${!files_arr[@]}"; do
+    file=${files_arr[$i]}
+    lineno=${linenos_arr[$i]}
+
+    # get code from `@RunIf(special=True)` line to EOF
+    test_code=$(tail -n +"$lineno" "$file")
+
+    # read line by line
+    while read -r line; do
+        # if it's a test
+        if [[ $line == def\ test_* ]]; then
+            # get the name
+            test_name=$(echo $line | cut -c 5- | cut -f1 -d\()
+
+            # check blocklist
+            if echo $blocklist | grep --word-regexp "$test_name" > /dev/null; then
+                report+="Skipped\t$file:$lineno::$test_name\n"
+                break
+            fi
+
+            # run the test
+            report+="Ran\t$file:$lineno::$test_name\n"
+            python ${defaults} "${file}::${test_name}"
+            break
+        fi
+    done < <(echo "$test_code")
+done
+
+nvprof --profile-from-start off -o trace_name.prof -- python ${defaults} tests/test_profiler.py::test_pytorch_profiler_nested_emit_nvtx
+
+# echo test report
+printf '=%.s' {1..80}
+printf "\n$report"
+printf '=%.s' {1..80}
+printf '\n'
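
For each grep hit the loop takes the first `def test_...` line at or below the `special=True` marker, derives the test name, and either records a skip (blocklist) or runs just that node under coverage. Purely for illustration, a Python equivalent of the shell name extraction (`cut -c 5- | cut -f1 -d\(`), assuming the test is declared on a single `def test_...(` line:

def extract_test_name(line: str) -> str:
    # drop the leading "def " (first 4 characters), keep everything before the "("
    return line[4:].split('(')[0]


assert extract_test_name("def test_logging_sync_dist_true_ddp(tmpdir):") == "test_logging_sync_dist_true_ddp"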

tests/utilities/test_all_gather_grad.py

Lines changed: 3 additions & 3 deletions
@@ -55,7 +55,6 @@ class TestModel(BoringModel):
         training_epoch_end_called = False
 
         def training_epoch_end(self, outputs) -> None:
-            self.training_epoch_end_called = True
             losses = torch.stack([x["loss"] for x in outputs])
             gathered_loss = self.all_gather({
                 "losses_tensor_int": torch.rand(2, 2).int().t(),
@@ -67,7 +66,7 @@ def training_epoch_end(self, outputs) -> None:
                 "losses": losses,
                 "losses_list": [losses, losses]
             })
-            assert gathered_loss["losses_tensor_int"][0].dtype == torch.int64
+            assert gathered_loss["losses_tensor_int"][0].dtype == torch.int32
             assert gathered_loss["losses_tensor_float"][0].dtype == torch.float
             assert gathered_loss["losses_np_ndarray"][0].dtype == torch.int64
             # torch.bool can't be all_gathered
@@ -76,6 +75,7 @@ def training_epoch_end(self, outputs) -> None:
             assert gathered_loss["losses_int"][0].dtype == torch.int
             assert gathered_loss["losses_list"][0].numel() == 2 * len(losses)
             assert gathered_loss["losses"].numel() == 2 * len(losses)
+            self.training_epoch_end_called = True
 
     seed_everything(42)
 
@@ -115,6 +115,6 @@ def training_step(self, batch, batch_idx):
             return loss
 
     model = TestModel()
-    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=2)
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=2, accelerator="ddp")
     trainer.fit(model)
     assert model.training_step_called
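
On the dtype fix: `torch.Tensor.int()` casts to 32-bit integers, so the gathered tensor is expected to keep `torch.int32` rather than `torch.int64`; the `training_epoch_end_called` flag is also set only after the assertions now, so it reports success only if every gather check passed. A quick standalone check of the dtype assumption:

import torch

# .int() casts to torch.int32, which is what the corrected assertion expects after all_gather
assert torch.rand(2, 2).int().dtype == torch.int32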
