
Commit fff62f0

Authored by awaelchli, carmocca, pre-commit-ci[bot], Borda, and kaushikb11
Fix TPU testing and collect all tests (#11098)
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Kaushik B <[email protected]>
1 parent 95f5f17 commit fff62f0

File tree: 23 files changed, +213 −203 lines


.azure/gpu-tests.yml

Lines changed: 9 additions & 0 deletions
@@ -116,6 +116,15 @@ jobs:
     timeoutInMinutes: "35"
     condition: eq(variables['continue'], '1')

+  - bash: bash run_standalone_tasks.sh
+    workingDirectory: tests/tests_pytorch
+    env:
+      PL_USE_MOCKED_MNIST: "1"
+      PL_RUN_CUDA_TESTS: "1"
+    displayName: 'Testing: PyTorch standalone tasks'
+    timeoutInMinutes: "10"
+    condition: eq(variables['continue'], '1')
+
   - bash: |
       python -m coverage report
       python -m coverage xml

.circleci/config.yml

Lines changed: 3 additions & 7 deletions
@@ -81,6 +81,8 @@ references:
     job_name=$(jsonnet -J ml-testing-accelerators/ dockers/tpu-tests/tpu_test_cases.jsonnet | kubectl create -f -) && \
     job_name=${job_name#job.batch/}
     job_name=${job_name% created}
+    pod_name=$(kubectl get po -l controller-uid=`kubectl get job $job_name -o "jsonpath={.metadata.labels.controller-uid}"` | awk 'match($0,!/NAME/) {print $1}')
+    echo "GKE pod name: $pod_name"
     echo "Waiting on kubernetes job: $job_name"
     i=0 && \
     # N checks spaced 30s apart = 900s total.
@@ -92,8 +94,6 @@ references:
     printf "Waiting for job to finish: " && \
     while [ $i -lt $MAX_CHECKS ]; do ((i++)); if kubectl get jobs $job_name -o jsonpath='Failed:{.status.failed}' | grep "Failed:1"; then status_code=1 && break; elif kubectl get jobs $job_name -o jsonpath='Succeeded:{.status.succeeded}' | grep "Succeeded:1" ; then status_code=0 && break; else printf "."; fi; sleep $CHECK_SPEEP; done && \
     echo "Done waiting. Job status code: $status_code" && \
-    pod_name=$(kubectl get po -l controller-uid=`kubectl get job $job_name -o "jsonpath={.metadata.labels.controller-uid}"` | awk 'match($0,!/NAME/) {print $1}') && \
-    echo "GKE pod name: $pod_name" && \
     kubectl logs -f $pod_name --container=train > /tmp/full_output.txt
     if grep -q '<?xml version="1.0" ?>' /tmp/full_output.txt ; then csplit /tmp/full_output.txt '/<?xml version="1.0" ?>/'; else mv /tmp/full_output.txt xx00; fi && \
     # First portion is the test logs. Print these to Github Action stdout.
@@ -106,10 +106,6 @@ references:
   name: Statistics
   command: |
     mv ./xx01 coverage.xml
-    # TODO: add human readable report
-    cat coverage.xml
-    sudo pip install pycobertura
-    pycobertura show coverage.xml

 jobs:

@@ -119,7 +115,7 @@ jobs:
   environment:
     - XLA_VER: 1.9
     - PYTHON_VER: 3.7
-    - MAX_CHECKS: 240
+    - MAX_CHECKS: 1000
     - CHECK_SPEEP: 5
   steps:
     - checkout
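A note on the wait budget above (not part of the diff): with MAX_CHECKS raised from 240 to 1000 and CHECK_SPEEP left at 5, the loop polls the GKE job for up to 1000 × 5 s = 5000 s, roughly 83 minutes, bringing it in line with the job timeout raised to 100 minutes in the jsonnet config below. The inline comment about checks spaced 30 s apart totalling 900 s is stale relative to these values.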

dockers/tpu-tests/tpu_test_cases.jsonnet

Lines changed: 5 additions & 10 deletions
@@ -8,7 +8,7 @@ local tputests = base.BaseTest {
   mode: 'postsubmit',
   configMaps: [],

-  timeout: 1200, # 20 minutes, in seconds.
+  timeout: 6000, # 100 minutes, in seconds.

   image: 'pytorchlightning/pytorch_lightning',
   imageTag: 'base-xla-py{PYTHON_VERSION}-torch{PYTORCH_VERSION}',
@@ -34,16 +34,11 @@ local tputests = base.BaseTest {
     pip install -e .[test]
     echo $KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS
     export XRT_TPU_CONFIG="tpu_worker;0;${KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS:7}"
+    export PL_RUN_TPU_TESTS=1
     cd tests/tests_pytorch
-    echo $PWD
-    # TODO (@kaushikb11): Add device stats tests here
-    coverage run --source pytorch_lightning -m pytest -v --capture=no \
-      strategies/test_tpu_spawn.py \
-      profilers/test_xla_profiler.py \
-      accelerators/test_tpu.py \
-      models/test_tpu.py \
-      plugins/environments/test_xla_environment.py \
-      utilities/test_xla_device_utils.py
+    coverage run --source=pytorch_lightning -m pytest -vv --durations=0 ./
+    echo "\n||| Running standalone tests |||\n"
+    bash run_standalone_tests.sh -b 1
     test_exit_code=$?
     echo "\n||| END PYTEST LOGS |||\n"
     coverage xml
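The PL_RUN_TPU_TESTS=1 export is what lets the generic `pytest ./` invocation above collect every test while TPU-only tests stay skipped in other CI jobs. Below is a rough sketch of how such an environment flag can back a marker like `RunIf(tpu=True)`; the repository's real RunIf helper is more general, and the `require_tpu` name is an assumption for illustration:

import os

import pytest


def require_tpu(fn):
    # Hypothetical stand-in for RunIf(tpu=True): skip unless the TPU CI job
    # opted in by exporting PL_RUN_TPU_TESTS=1.
    gate = pytest.mark.skipif(
        os.getenv("PL_RUN_TPU_TESTS", "0") != "1",
        reason="TPU tests not requested; set PL_RUN_TPU_TESTS=1 to run them",
    )
    return gate(fn)


@require_tpu
def test_only_runs_on_the_tpu_ci_job():
    assert True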

src/pytorch_lightning/plugins/training_type/single_tpu.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@
 class SingleTPUPlugin(SingleTPUStrategy):
     def __init__(self, *args, **kwargs) -> None:  # type: ignore[no-untyped-def]
         rank_zero_deprecation(
-            "The `pl.plugins.training_type.single_tpu.SingleTPUPlugin` is deprecated in v1.6 and will be removed in."
+            "The `pl.plugins.training_type.single_tpu.SingleTPUPlugin` is deprecated in v1.6 and will be removed in"
             " v1.8. Use `pl.strategies.single_tpu.SingleTPUStrategy` instead."
         )
         super().__init__(*args, **kwargs)

src/pytorch_lightning/strategies/launchers/xla.py

Lines changed: 33 additions & 11 deletions
@@ -13,10 +13,12 @@
 # limitations under the License.
 import os
 import time
+from functools import wraps
 from multiprocessing.queues import SimpleQueue
-from typing import Any, Callable, Optional, TYPE_CHECKING
+from typing import Any, Callable, Optional, Tuple, TYPE_CHECKING

 import torch.multiprocessing as mp
+from torch.multiprocessing import ProcessContext

 import pytorch_lightning as pl
 from pytorch_lightning.strategies.launchers.multiprocessing import _FakeQueue, _MultiProcessingLauncher, _WorkerOutput
@@ -26,9 +28,10 @@
 from pytorch_lightning.utilities.rank_zero import rank_zero_debug

 if _TPU_AVAILABLE:
+    import torch_xla.core.xla_model as xm
     import torch_xla.distributed.xla_multiprocessing as xmp
 else:
-    xm, xmp, MpDeviceLoader, rendezvous = [None] * 4
+    xm, xmp = None, None

 if TYPE_CHECKING:
     from pytorch_lightning.strategies import Strategy
@@ -72,7 +75,7 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
         """
         context = mp.get_context(self._start_method)
         return_queue = context.SimpleQueue()
-        xmp.spawn(
+        _save_spawn(
             self._wrapping_function,
             args=(trainer, function, args, kwargs, return_queue),
             nprocs=len(self._strategy.parallel_devices),
@@ -103,14 +106,6 @@ def _wrapping_function(
         if self._strategy.local_rank == 0:
             return_queue.put(move_data_to_device(results, "cpu"))

-        # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
-        self._strategy.barrier("end-process")
-
-        # Ensure that the rank 0 process is the one exiting last
-        # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
-        if self._strategy.local_rank == 0:
-            time.sleep(2)
-
     def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_WorkerOutput"]:
         rank_zero_debug("Collecting results from rank 0 process.")
         checkpoint_callback = trainer.checkpoint_callback
@@ -138,3 +133,30 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt
         self.add_to_queue(trainer, extra)

         return _WorkerOutput(best_model_path, weights_path, trainer.state, results, extra)
+
+
+def _save_spawn(
+    fn: Callable,
+    args: Tuple = (),
+    nprocs: Optional[int] = None,
+    join: bool = True,
+    daemon: bool = False,
+    start_method: str = "spawn",
+) -> Optional[ProcessContext]:
+    """Wraps the :func:`torch_xla.distributed.xla_multiprocessing.spawn` with added teardown logic for the worker
+    processes."""
+
+    @wraps(fn)
+    def wrapped(rank: int, *_args: Any) -> None:
+        fn(rank, *_args)
+
+        # Make all processes wait for each other before joining
+        # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
+        xm.rendezvous("end-process")
+
+        # Ensure that the rank 0 process is the one exiting last
+        # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
+        if rank == 0:
+            time.sleep(1)
+
+    return xmp.spawn(wrapped, args=args, nprocs=nprocs, join=join, daemon=daemon, start_method=start_method)
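For intuition, here is a CPU-only analogue of the `_save_spawn` teardown above, substituting a multiprocessing barrier for `xm.rendezvous`. This is purely illustrative, not code from the commit; it assumes only that torch is installed:

import time

import torch.multiprocessing as mp


def _worker(rank: int, barrier) -> None:
    # ... the wrapped user function would run here ...
    barrier.wait()  # all ranks meet before joining, like xm.rendezvous("end-process")
    if rank == 0:
        time.sleep(1)  # keep rank 0 alive last, mirroring the wrapped teardown


if __name__ == "__main__":
    nprocs = 4
    barrier = mp.get_context("spawn").Barrier(nprocs)
    mp.spawn(_worker, args=(barrier,), nprocs=nprocs)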

src/pytorch_lightning/strategies/tpu_spawn.py

Lines changed: 10 additions & 5 deletions
@@ -74,6 +74,7 @@ def __init__(
             start_method="fork",
         )
         self.debug = debug
+        self._launched = False

     @property
     def checkpoint_io(self) -> CheckpointIO:
@@ -90,6 +91,8 @@ def checkpoint_io(self, io: Optional[CheckpointIO]) -> None:

     @property
     def root_device(self) -> torch.device:
+        if not self._launched:
+            raise RuntimeError("Accessing the XLA device before processes have spawned is not allowed.")
         return xm.xla_device()

     @staticmethod
@@ -130,7 +133,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
         self.accelerator.setup(trainer)

         if self.debug:
-            os.environ["PT_XLA_DEBUG"] = str(1)
+            os.environ["PT_XLA_DEBUG"] = "1"

         shared_params = find_shared_parameters(self.model)
         self.model_to_device()
@@ -150,8 +153,8 @@ def distributed_sampler_kwargs(self) -> Dict[str, int]:

     @property
     def is_distributed(self) -> bool:
-        # HOST_WORLD_SIZE is None outside the xmp.spawn process
-        return os.getenv(xenv.HOST_WORLD_SIZE, None) and self.world_size != 1
+        # HOST_WORLD_SIZE is not set outside the xmp.spawn process
+        return (xenv.HOST_WORLD_SIZE in os.environ) and self.world_size != 1

     def process_dataloader(self, dataloader: DataLoader) -> MpDeviceLoader:
         TPUSpawnStrategy._validate_dataloader(dataloader)
@@ -189,8 +192,9 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[
         invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
         invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
         if invalid_reduce_op or invalid_reduce_op_str:
-            raise MisconfigurationException(
-                "Currently, TPUSpawn Strategy only support `sum`, `mean`, `avg` reduce operation."
+            raise ValueError(
+                "Currently, the TPUSpawnStrategy only supports `sum`, `mean`, `avg` for the reduce operation, got:"
+                f" {reduce_op}"
             )

         output = xm.mesh_reduce("reduce", output, sum)
@@ -201,6 +205,7 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[
         return output

     def _worker_setup(self, process_idx: int):
+        self._launched = True
         reset_seed()
         self.set_world_ranks(process_idx)
         rank_zero_only.rank = self.global_rank
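With the new `_launched` flag, `root_device` cannot be read before a worker has run `_worker_setup`. A minimal sketch of what a caller now observes, assuming an environment where `TPUSpawnStrategy` is importable without TPU hardware (the mocked test further below relies on the same property):

from pytorch_lightning.strategies import TPUSpawnStrategy

strategy = TPUSpawnStrategy()
try:
    strategy.root_device  # the guard fires: no worker has called _worker_setup() yet
except RuntimeError as err:
    print(err)  # "Accessing the XLA device before processes have spawned is not allowed."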

tests/tests_pytorch/accelerators/test_accelerator_connector.py

Lines changed: 1 addition & 1 deletion
@@ -671,7 +671,7 @@ def test_devices_auto_choice_mps():

 @pytest.mark.parametrize(
     ["parallel_devices", "accelerator"],
-    [([torch.device("cpu")], "cuda"), ([torch.device("cuda", i) for i in range(8)], ("tpu"))],
+    [([torch.device("cpu")], "cuda"), ([torch.device("cuda", i) for i in range(8)], "tpu")],
 )
 def test_parallel_devices_in_strategy_confilict_with_accelerator(parallel_devices, accelerator):
     with pytest.raises(MisconfigurationException, match=r"parallel_devices set through"):

tests/tests_pytorch/accelerators/test_ipu.py

Lines changed: 1 addition & 1 deletion
@@ -602,7 +602,7 @@ def test_strategy_choice_ipu_plugin(tmpdir):


 @RunIf(ipu=True)
-def test_device_type_when_training_plugin_ipu_passed(tmpdir):
+def test_device_type_when_ipu_strategy_passed(tmpdir):
     trainer = Trainer(strategy=IPUStrategy(), accelerator="ipu", devices=8)
     assert isinstance(trainer.strategy, IPUStrategy)
     assert isinstance(trainer.accelerator, IPUAccelerator)

tests/tests_pytorch/accelerators/test_tpu.py

Lines changed: 9 additions & 12 deletions
@@ -28,7 +28,6 @@
 from pytorch_lightning.strategies import DDPStrategy, TPUSpawnStrategy
 from pytorch_lightning.utilities import find_shared_parameters
 from tests_pytorch.helpers.runif import RunIf
-from tests_pytorch.helpers.utils import pl_multi_process_test


 class WeightSharingModule(BoringModel):
@@ -46,8 +45,7 @@ def forward(self, x):
         return x


-@RunIf(tpu=True)
-@pl_multi_process_test
+@RunIf(tpu=True, standalone=True)
 def test_resume_training_on_cpu(tmpdir):
     """Checks if training can be resumed from a saved checkpoint on CPU."""
     # Train a model on TPU
@@ -65,11 +63,9 @@ def test_resume_training_on_cpu(tmpdir):
     # Verify that training is resumed on CPU
     trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
     trainer.fit(model, ckpt_path=model_path)
-    assert trainer.state.finished, f"Training failed with {trainer.state}"


 @RunIf(tpu=True)
-@pl_multi_process_test
 def test_if_test_works_after_train(tmpdir):
     """Ensure that .test() works after .fit()"""

@@ -293,12 +289,14 @@ def test_xla_checkpoint_plugin_being_default():
     assert isinstance(trainer.strategy.checkpoint_io, XLACheckpointIO)


-@RunIf(tpu=True)
-@patch("pytorch_lightning.strategies.tpu_spawn.xm")
-def test_mp_device_dataloader_attribute(_):
+@patch("pytorch_lightning.strategies.tpu_spawn.MpDeviceLoader")
+@patch("pytorch_lightning.strategies.tpu_spawn.TPUSpawnStrategy.root_device")
+def test_mp_device_dataloader_attribute(root_device_mock, mp_loader_mock):
     dataset = RandomDataset(32, 64)
-    dataloader = TPUSpawnStrategy().process_dataloader(DataLoader(dataset))
-    assert dataloader.dataset == dataset
+    dataloader = DataLoader(dataset)
+    processed_dataloader = TPUSpawnStrategy().process_dataloader(dataloader)
+    mp_loader_mock.assert_called_with(dataloader, root_device_mock)
+    assert processed_dataloader.dataset == processed_dataloader._loader.dataset


 @RunIf(tpu=True)
@@ -307,8 +305,7 @@ def test_warning_if_tpus_not_used():
     Trainer()


-@RunIf(tpu=True)
-@pl_multi_process_test
+@RunIf(tpu=True, standalone=True)
 @pytest.mark.parametrize(
     ["devices", "expected_device_ids"],
     [
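Note how `test_mp_device_dataloader_attribute` now patches `TPUSpawnStrategy.root_device` instead of the `xm` module: with the `_launched` guard introduced in `tpu_spawn.py` above, the real property would raise a RuntimeError outside of spawned workers, so the test mocks the device and asserts the `MpDeviceLoader` wiring instead.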

tests/tests_pytorch/callbacks/test_device_stats_monitor.py

Lines changed: 5 additions & 7 deletions
@@ -96,7 +96,6 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
     assert cpu_stats_mock.call_count == expected


-@pytest.mark.skipif(True, reason="TODO (@kaushikb11): fix this test, timeout")
 @RunIf(tpu=True)
 def test_device_stats_monitor_tpu(tmpdir):
     """Test TPU stats are logged using a logger."""
@@ -106,24 +105,23 @@ def test_device_stats_monitor_tpu(tmpdir):

     class DebugLogger(CSVLogger):
         @rank_zero_only
-        def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
+        def log_metrics(self, metrics, step=None) -> None:
             fields = ["avg. free memory (MB)", "avg. peak memory (MB)"]
             for f in fields:
                 assert any(f in h for h in metrics)

     trainer = Trainer(
         default_root_dir=tmpdir,
-        max_epochs=1,
-        limit_train_batches=2,
+        max_epochs=2,
+        limit_train_batches=5,
         accelerator="tpu",
-        devices=1,
+        devices=8,
         log_every_n_steps=1,
         callbacks=[device_stats],
         logger=DebugLogger(tmpdir),
         enable_checkpointing=False,
         enable_progress_bar=False,
     )
-
     trainer.fit(model)


@@ -146,7 +144,7 @@ def test_device_stats_monitor_no_logger(tmpdir):
     trainer.fit(model)


-def test_prefix_metric_keys(tmpdir):
+def test_prefix_metric_keys():
     """Test that metric key names are converted correctly."""
     metrics = {"1": 1.0, "2": 2.0, "3": 3.0}
     prefix = "foo"