Bug description
Graceful termination is important because sometimes we want to stop a lengthy training process that no longer seems to make progress, while still keeping all the results. If the data is not saved on each epoch or every few steps, then we have to rely on graceful termination working reliably and make sure that on_train_end is executed properly.
When I use a DataLoader with num_workers > 0, Ctrl+C sometimes does not lead to correct graceful termination. My guess is that Ctrl+C is not caught in the main thread and instead triggers abrupt process termination.
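A possible mitigation, until the root cause is fixed: install a SIGINT handler in the main process that asks the Trainer to stop instead of letting the exception unwind through the DataLoader machinery. This is only a rough sketch; install_graceful_sigint is my own helper name, not a Lightning API, though trainer.should_stop is the supported way to request a clean stop.

import signal

import pytorch_lightning as pl


def install_graceful_sigint(trainer: pl.Trainer) -> None:
    # Turn Ctrl+C into a graceful-stop request: setting `should_stop`
    # asks the fit loop to exit through its normal teardown path, so
    # the `on_train_end` hooks still run.
    def _handler(signum, frame):
        trainer.should_stop = True

    # Must be installed from the main thread. Forked DataLoader workers
    # inherit the handler, which also keeps them from dying on SIGINT.
    signal.signal(signal.SIGINT, _handler)

Installed right after constructing the Trainer and before fit, this turns Ctrl+C into a stop request rather than a KeyboardInterrupt racing the worker queue reads.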
What version are you seeing the problem on?
v2.0
How to reproduce the bug
# Thanks to @rain1024 for majority of the code
# https://gist.github.com/rain1024/8ea4c2f56aa4c9ba0e1cbf35edb68eca
import pytorch_lightning as pl
import numpy as np
import torch
from torch.nn import MSELoss
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
from pytorch_lightning.callbacks import Callback


class SimpleDataset(Dataset):
    def __init__(self):
        X = np.arange(10000)
        y = X * 2
        X = [[_] for _ in X]
        y = [[_] for _ in y]
        self.X = torch.Tensor(X)
        self.y = torch.Tensor(y)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return {"X": self.X[idx], "y": self.y[idx]}


class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(1, 1)
        self.criterion = MSELoss()

    def forward(self, inputs_id, labels=None):
        outputs = self.fc(inputs_id)
        loss = 0
        if labels is not None:
            loss = self.criterion(outputs, labels)
        return loss, outputs

    def train_dataloader(self):
        dataset = SimpleDataset()
        return DataLoader(dataset, batch_size=100, num_workers=10)

    def training_step(self, batch, batch_idx):
        input_ids = batch["X"]
        labels = batch["y"]
        loss, outputs = self(input_ids, labels)
        return {"loss": loss}

    def configure_optimizers(self):
        optimizer = Adam(self.parameters())
        return optimizer


class MyTearDownCallback(Callback):
    def on_train_end(self, trainer, pl_module):
        print('Tear down procedure here...')


if __name__ == '__main__':
    model = MyModel()
    trainer = pl.Trainer(max_epochs=10, callbacks=[MyTearDownCallback()])
    trainer.fit(model)

    X = torch.Tensor([[1.0], [51.0], [89.0]])
    _, y = model(X)
    print(y)

Run the code above and press Ctrl+C during training.
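One way to check where the interrupt actually lands is to install a logging SIGINT handler before trainer.fit (a diagnostic sketch of my own, not part of the original repro; Ctrl+C sends SIGINT to the whole foreground process group, so the main process and every worker should report in):

import os
import signal


def log_sigint(signum, frame):
    # Forked DataLoader workers inherit this handler, so the printed
    # pids show which processes the SIGINT was delivered to.
    print(f"SIGINT received in pid {os.getpid()}", flush=True)
    raise KeyboardInterrupt


signal.signal(signal.SIGINT, log_sigint)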
The bug is not reproducible 100% of the time. Sometimes Ctrl+C seems to get caught on the main thread; other times, when it is caught in the wrong one, a second Ctrl+C may help and the process exits correctly. It is still, however, not reliable enough.
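Until this is fixed, the teardown can be made more defensive by duplicating it in a finally block around fit, so it runs even when on_train_end is skipped (a sketch; it repeats what the callback should already guarantee):

if __name__ == '__main__':
    model = MyModel()
    trainer = pl.Trainer(max_epochs=10, callbacks=[MyTearDownCallback()])
    try:
        trainer.fit(model)
    finally:
        # Runs even when fit() dies inside a DataLoader worker read,
        # e.g. with the RuntimeError shown in the logs below.
        print('Fallback tear down procedure here...')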
Error messages and logs
(.env) ➜ graceful-termination python graceful-termination-on-background-threads.py
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/mzurad/code/pl-bugs/graceful-termination/.env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: UserWarning: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
  warning_cache.warn(
You are using a CUDA device ('NVIDIA GeForce RTX 3070 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params
--------------------------------------
0 | fc | Linear | 2
1 | criterion | MSELoss | 0
--------------------------------------
2 Trainable params
0 Non-trainable params
2 Total params
0.000 Total estimated model params size (MB)
Epoch 1: 0%| | 0/100 [00:00<?, ?it/s, v_num=11]
^CException ignored in: <function _releaseLock at 0x7fdb522629e0>
Traceback (most recent call last):
File "/usr/lib/python3.10/logging/__init__.py", line 228, in _releaseLock
def _releaseLock():
KeyboardInterrupt:
Traceback (most recent call last):
File "/home/mzurad/code/pl-bugs/graceful-termination/.env/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1132, in _try_get_data
data = self._data_queue.get(timeout=timeout)
File "/usr/lib/python3.10/multiprocessing/queues.py", line 114, in get
raise Empty
_queue.Empty
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/mzurad/code/pl-bugs/graceful-termination/graceful-termination-on-background-threads.py", line 62, in <module>
trainer.fit(model)
File "/home/mzurad/code/pl-bugs/graceful-termination/.env/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 531, in fit
call._call_and_handle_interrupt(
File "/home/mzurad/code/pl-bugs/graceful-termination/.env/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 42, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/home/mzurad/code/pl-bugs/graceful-termination/.env/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 570, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/home/mzurad/code/pl-bugs/graceful-termination/.env/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 975, in _run
results = self._run_stage()
File "/home/mzurad/code/pl-bugs/graceful-termination/.env/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1018, in _run_stage
self.fit_loop.run()
File "/home/mzurad/code/pl-bugs/graceful-termination/.env/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 201, in run
self.advance()
File "/home/mzurad/code/pl-bugs/graceful-termination/.env/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 354, in advance
self.epoch_loop.run(self._data_fetcher)
File "/home/mzurad/code/pl-bugs/graceful-termination/.env/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 133, in run
self.advance(data_fetcher)
File "/home/mzurad/code/pl-bugs/graceful-termination/.env/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 189, in advance
batch = next(data_fetcher)
File "/home/mzurad/code/pl-bugs/graceful-termination/.env/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py", line 136, in __next__
self._fetch_next_batch(self.dataloader_iter)
File "/home/mzurad/code/pl-bugs/graceful-termination/.env/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py", line 150, in _fetch_next_batch
batch = next(iterator)
File "/home/mzurad/code/pl-bugs/graceful-termination/.env/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py", line 284, in __next__
out = next(self._iterator)
File "/home/mzurad/code/pl-bugs/graceful-termination/.env/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py", line 65, in __next__
out[i] = next(self.iterators[i])
File "/home/mzurad/code/pl-bugs/graceful-termination/.env/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 633, in __next__
data = self._next_data()
File "/home/mzurad/code/pl-bugs/graceful-termination/.env/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1328, in _next_data
idx, data = self._get_data()
File "/home/mzurad/code/pl-bugs/graceful-termination/.env/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1294, in _get_data
success, data = self._try_get_data()
File "/home/mzurad/code/pl-bugs/graceful-termination/.env/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1145, in _try_get_data
raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str)) from e
RuntimeError: DataLoader worker (pid(s) 3130238, 3130254, 3130270) exited unexpectedly
Epoch 1: 0%| | 0/100 [00:05<?, ?it/s, v_num=11]
Environment
Current environment
- CUDA:
- GPU:
- NVIDIA GeForce RTX 3070 Laptop GPU
- available: True
- version: 11.7
- Lightning:
- lightning-utilities: 0.8.0
- pytorch-lightning: 2.0.3
- torch: 2.0.1
- torchmetrics: 0.11.4
- Packages:
- aiohttp: 3.8.4
- aiosignal: 1.3.1
- async-timeout: 4.0.2
- attrs: 23.1.0
- certifi: 2023.5.7
- charset-normalizer: 3.1.0
- cmake: 3.26.4
- filelock: 3.12.2
- frozenlist: 1.3.3
- fsspec: 2023.6.0
- idna: 3.4
- jinja2: 3.1.2
- lightning-utilities: 0.8.0
- lit: 16.0.6
- markupsafe: 2.1.3
- mpmath: 1.3.0
- multidict: 6.0.4
- networkx: 3.1
- numpy: 1.25.0
- nvidia-cublas-cu11: 11.10.3.66
- nvidia-cuda-cupti-cu11: 11.7.101
- nvidia-cuda-nvrtc-cu11: 11.7.99
- nvidia-cuda-runtime-cu11: 11.7.99
- nvidia-cudnn-cu11: 8.5.0.96
- nvidia-cufft-cu11: 10.9.0.58
- nvidia-curand-cu11: 10.2.10.91
- nvidia-cusolver-cu11: 11.4.0.1
- nvidia-cusparse-cu11: 11.7.4.91
- nvidia-nccl-cu11: 2.14.3
- nvidia-nvtx-cu11: 11.7.91
- packaging: 23.1
- pip: 22.0.2
- pytorch-lightning: 2.0.3
- pyyaml: 6.0
- requests: 2.31.0
- setuptools: 59.6.0
- sympy: 1.12
- torch: 2.0.1
- torchmetrics: 0.11.4
- tqdm: 4.65.0
- triton: 2.0.0
- typing-extensions: 4.6.3
- urllib3: 2.0.3
- wheel: 0.40.0
- yarl: 1.9.2
- System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.10.6
- release: 5.19.0-43-generic
- version: #44~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Mon May 22 13:39:36 UTC 2
More info
No response