
Actor hangs with ray using PL v>=1.7 #14292

@erezinman

Description

When training is invoked from within a Ray actor, the training process hangs. The reason is that Ray hangs (observed on multiple Ubuntu machines) when a process is forked the way PyTorch Lightning does when calling device_parser.num_cuda_devices. This does not happen on v1.6.5.

To sum up (before delving into the details), the hang is caused by the following change between versions:
In v1.6.5, the gpu accelerator (renamed to cuda in the newer version) calls torch.cuda.device_count() > 0 directly, while in 1.7.0 and above, torch.cuda.device_count is called in a forked process (utilities/device_parser.py, line 346, in num_cuda_devices):

with multiprocessing.get_context("fork").Pool(1) as pool:
    return pool.apply(torch.cuda.device_count)
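
For comparison, a minimal sketch of the two availability checks as described above (the function names are illustrative, not the actual PL code paths):

import multiprocessing

import torch

# v1.6.5-style check: query CUDA directly in the current process.
def is_available_old():
    return torch.cuda.device_count() > 0

# v1.7.0+-style check: run the same query in a forked child process,
# mirroring the num_cuda_devices snippet above.
def is_available_new():
    with multiprocessing.get_context("fork").Pool(1) as pool:
        return pool.apply(torch.cuda.device_count) > 0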

I deduced the problem by listening for system signals and printing the traceback when SIGUSR1 is received inside the hung process (a minimal sketch of this trick follows the traceback below). The traceback we obtained was (starting after the call to pl.Trainer(...)):

File "<__CONDA__ENV__>/lib/python3.9/site-packages/pytorch_lightning/utilities/argparse.py", line 345, in insert_env_defaults
  return fn(self, **kwargs)
File "<__CONDA__ENV__>/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 433, in __init__
  self._accelerator_connector = AcceleratorConnector(
File "<__CONDA__ENV__>/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py", line 212, in __init__
  self._set_parallel_devices_and_init_accelerator()
File "<__CONDA__ENV__>/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py", line 525, in _set_parallel_devices_and_init_accelerator
  if not self.accelerator.is_available():
File "<__CONDA__ENV__>/lib/python3.9/site-packages/pytorch_lightning/accelerators/cuda.py", line 91, in is_available
  return device_parser.num_cuda_devices() > 0
File "<__CONDA__ENV__>/lib/python3.9/site-packages/pytorch_lightning/utilities/device_parser.py", line 346, in num_cuda_devices
  return pool.apply(torch.cuda.device_count)
File "<__CONDA__ENV__>/lib/python3.9/multiprocessing/pool.py", line 736, in __exit__
  self.terminate()
File "<__CONDA__ENV__>/lib/python3.9/multiprocessing/pool.py", line 654, in terminate
  self._terminate()
File "<__CONDA__ENV__>/lib/python3.9/multiprocessing/util.py", line 224, in __call__
  res = self._callback(*self._args, **self._kwargs)
File "<__CONDA__ENV__>/lib/python3.9/multiprocessing/pool.py", line 729, in _terminate_pool
  p.join()
File "<__CONDA__ENV__>/lib/python3.9/multiprocessing/process.py", line 149, in join
  res = self._popen.wait(timeout)
File "<__CONDA__ENV__>/lib/python3.9/multiprocessing/popen_fork.py", line 43, in wait
  return self.poll(os.WNOHANG if timeout == 0.0 else 0)
File "<__CONDA__ENV__>/lib/python3.9/multiprocessing/popen_fork.py", line 27, in poll
  pid, sts = os.waitpid(self.pid, flag)
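
For reference, a minimal sketch of the SIGUSR1 trick used to obtain the traceback above (standard library only; the handler name is illustrative):

import os
import signal
import traceback

def dump_stack(signum, frame):
    # Print the Python stack of whatever the process is currently doing.
    traceback.print_stack(frame)

signal.signal(signal.SIGUSR1, dump_stack)
print(f"send SIGUSR1 with: kill -USR1 {os.getpid()}")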

The hang can be reproduced by forking from within a Ray worker:

import ray
import multiprocessing

@ray.remote
def f():
    # The pool.apply call below never returns when executed inside a Ray worker.
    with multiprocessing.get_context("fork").Pool(1) as pool:
        return pool.apply(int)

print(ray.get(f.remote()))
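
For completeness, a sketch of how the hang surfaces at the PL level from inside a Ray actor (the actor and method names are illustrative; any pl.Trainer(...) construction on PL >= 1.7 should reach the same num_cuda_devices code path shown in the traceback):

import pytorch_lightning as pl
import ray

@ray.remote
class Runner:
    def build_trainer(self):
        # On PL >= 1.7 this hangs inside AcceleratorConnector, in the
        # forked torch.cuda.device_count() call from the traceback above.
        pl.Trainer(max_epochs=1)
        return True

runner = Runner.remote()
print(ray.get(runner.build_trainer.remote()))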

Thanks!

cc @tchaton @rohitgr7 @justusschock @kaushikb11 @awaelchli @akihironitta

Labels: 3rd party (Related to a 3rd-party), bug (Something isn't working), priority: 0 (High priority task), strategy: ddp (DistributedDataParallel)
