
Using torch.load instead of fabric.load spawns a zombie process when using DDP #17737

@nikvaessen

Bug description

When torch.load() is used on a checkpoint file instead of fabric.load(), a process is spawned which takes up GPU memory but is not used.

What version are you seeing the problem on?

v2.0

How to reproduce the bug

import argparse
import os.path
import time

import lightning
import torch


class Network(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.fc1 = torch.nn.Linear(1000, 10)

    def forward(self, x):
        return self.fc1(x)


def main():
    parser = argparse.ArgumentParser()
    # parse as int so "--devices 2" is not passed on as the string "2"
    parser.add_argument("--devices", type=int, default=2)
    parser.add_argument("--bug", action="store_true")
    args = parser.parse_args()

    fabric = lightning.Fabric(accelerator="gpu", devices=args.devices, strategy="ddp")

    fabric.launch()

    network = Network()
    network, opt = fabric.setup(network, torch.optim.Adam(network.parameters()))

    produce_bug = args.bug
    if os.path.exists("network.ckpt"):
        if produce_bug:
            ckpt = torch.load("network.ckpt")
        else:
            ckpt = fabric.load("network.ckpt")

        network.load_state_dict(ckpt["network"])
    else:
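        # fabric.save is meant to be called on all processes; with the "ddp"
        # strategy only global rank 0 actually writes the file.
        # fabric.print already prints on rank 0 only.
        fabric.save("network.ckpt", {"network": network.state_dict()})
        fabric.print("network.ckpt now exists, run this script again.")
        return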

    while True:
        fabric.print("sleeping")
        time.sleep(1)


if __name__ == "__main__":
    main()

Run the script once to create network.ckpt, then run it again. With `python main.py --devices 2`, 2 processes show up in nvidia-smi. With `python main.py --devices 2 --bug`, 3 processes show up.
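A likely explanation, which is an assumption and not confirmed in this report: by default `torch.load` restores each tensor to the device it was saved from. The checkpoint is written by rank 0 from `cuda:0`, so the rank-1 process also initializes a CUDA context on GPU 0, and nvidia-smi then lists that process on both GPUs. `fabric.load` presumably avoids this because Fabric's default checkpoint I/O loads storages onto the CPU. If plain `torch.load` has to be used, passing `map_location` should prevent the extra context. The sketch below illustrates this; `checkpoint_map_location` and `load_checkpoint_for_rank` are hypothetical helpers, not Lightning APIs.

```python
from typing import Any


def checkpoint_map_location(local_rank: int, cuda_available: bool) -> str:
    """Choose the device that checkpoint tensors should be restored to.

    Assumes local rank i runs on cuda:i (true for the repro script, but not
    for every device configuration). Falls back to the CPU without CUDA.
    """
    if local_rank < 0:
        raise ValueError(f"local_rank must be non-negative, got {local_rank}")
    return f"cuda:{local_rank}" if cuda_available else "cpu"


def load_checkpoint_for_rank(path: str, local_rank: int) -> Any:
    """Load a checkpoint without touching GPUs other than this rank's own."""
    # Imported lazily so checkpoint_map_location() works without PyTorch installed.
    import torch

    location = checkpoint_map_location(local_rank, torch.cuda.is_available())
    return torch.load(path, map_location=location)
```

In the repro, `load_checkpoint_for_rank("network.ckpt", fabric.local_rank)` could stand in for the `--bug` branch's `torch.load` call. Simpler alternatives that do not rely on the rank-to-GPU assumption are `torch.load("network.ckpt", map_location=fabric.device)` or `map_location="cpu"`.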

Error messages and logs

No response

Environment

Current environment
  • CUDA:
    - GPU:
      - NVIDIA RTX A5000
      - NVIDIA RTX A5000
    - available: True
    - version: 11.7
  • Lightning:
    - lightning: 2.0.2
    - lightning-cloud: 0.5.36
    - lightning-utilities: 0.8.0
    - pytorch-lightning: 2.0.2
    - torch: 2.0.1
    - torchmetrics: 0.11.4
  • Packages:
    - aiohttp: 3.8.4
    - aiosignal: 1.3.1
    - anyio: 3.7.0
    - arrow: 1.2.3
    - async-timeout: 4.0.2
    - attrs: 23.1.0
    - beautifulsoup4: 4.12.2
    - blessed: 1.20.0
    - certifi: 2023.5.7
    - charset-normalizer: 3.1.0
    - click: 8.1.3
    - cmake: 3.26.3
    - croniter: 1.3.15
    - dateutils: 0.6.12
    - deepdiff: 6.3.0
    - exceptiongroup: 1.1.1
    - fastapi: 0.88.0
    - filelock: 3.12.0
    - frozenlist: 1.3.3
    - fsspec: 2023.5.0
    - h11: 0.14.0
    - idna: 3.4
    - inquirer: 3.1.3
    - itsdangerous: 2.1.2
    - jinja2: 3.1.2
    - lightning: 2.0.2
    - lightning-cloud: 0.5.36
    - lightning-utilities: 0.8.0
    - lit: 16.0.5
    - markdown-it-py: 2.2.0
    - markupsafe: 2.1.2
    - mdurl: 0.1.2
    - mpmath: 1.3.0
    - multidict: 6.0.4
    - networkx: 3.1
    - numpy: 1.24.3
    - nvidia-cublas-cu11: 11.10.3.66
    - nvidia-cuda-cupti-cu11: 11.7.101
    - nvidia-cuda-nvrtc-cu11: 11.7.99
    - nvidia-cuda-runtime-cu11: 11.7.99
    - nvidia-cudnn-cu11: 8.5.0.96
    - nvidia-cufft-cu11: 10.9.0.58
    - nvidia-curand-cu11: 10.2.10.91
    - nvidia-cusolver-cu11: 11.4.0.1
    - nvidia-cusparse-cu11: 11.7.4.91
    - nvidia-nccl-cu11: 2.14.3
    - nvidia-nvtx-cu11: 11.7.91
    - ordered-set: 4.1.0
    - packaging: 23.1
    - pip: 23.1.2
    - psutil: 5.9.5
    - pydantic: 1.10.8
    - pygments: 2.15.1
    - pyjwt: 2.7.0
    - python-dateutil: 2.8.2
    - python-editor: 1.0.4
    - python-multipart: 0.0.6
    - pytorch-lightning: 2.0.2
    - pytz: 2023.3
    - pyyaml: 6.0
    - readchar: 4.0.5
    - requests: 2.31.0
    - rich: 13.4.1
    - setuptools: 59.6.0
    - six: 1.16.0
    - sniffio: 1.3.0
    - soupsieve: 2.4.1
    - starlette: 0.22.0
    - starsessions: 1.3.0
    - sympy: 1.12
    - torch: 2.0.1
    - torchmetrics: 0.11.4
    - tqdm: 4.65.0
    - traitlets: 5.9.0
    - triton: 2.0.0
    - typing-extensions: 4.6.2
    - urllib3: 2.0.2
    - uvicorn: 0.22.0
    - wcwidth: 0.2.6
    - websocket-client: 1.5.2
    - websockets: 11.0.3
    - wheel: 0.40.0
    - yarl: 1.9.2
  • System:
    - OS: Linux
    - architecture:
      - 64bit
      - ELF
    - processor: x86_64
    - python: 3.10.6
    - release: 5.19.0-38-generic
    - version: #39~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Fri Mar 17 21:16:15 UTC 2

More info

It is the user's mistake to use torch.load instead of fabric.load, but it was tricky to debug. Perhaps Fabric should emit a warning here, just as it does when it detects that model.forward() is not being called. For example:

PossibleUserWarning: You are calling `torch.load` after `fabric.launch()`. This will bypass the fabric strategy and result in incorrect behavior. You should use `fabric.load` instead.
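The proposed warning could work roughly like this. This is a hypothetical sketch, not how Lightning implements (or would implement) the feature: `warn_on_unmapped_load` wraps a `torch.load`-like function and warns when it is called after launch without a `map_location`. The `is_launched` callback and the local `PossibleUserWarning` class are stand-ins for whatever state and warning category Fabric would actually use.

```python
import functools
import warnings
from typing import Any, Callable


class PossibleUserWarning(UserWarning):
    """Local stand-in for a Lightning-style warning category (hypothetical)."""


def warn_on_unmapped_load(
    load_fn: Callable[..., Any], is_launched: Callable[[], bool]
) -> Callable[..., Any]:
    """Wrap a loader so it warns when called after launch without map_location."""

    @functools.wraps(load_fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # torch.load's second positional parameter is map_location.
        map_location = kwargs.get("map_location", args[1] if len(args) > 1 else None)
        if is_launched() and map_location is None:
            warnings.warn(
                "You are calling `torch.load` after `fabric.launch()` without "
                "`map_location`. Tensors will be restored to the device they were "
                "saved from. You should use `fabric.load` instead.",
                PossibleUserWarning,
                stacklevel=2,
            )
        return load_fn(*args, **kwargs)

    return wrapper
```

This variant only warns when `map_location` is missing, since an explicit mapping avoids the extra GPU context; warning on every call after launch, as proposed, would mean dropping the `map_location` check. Wiring it up would require something like `torch.load = warn_on_unmapped_load(torch.load, is_launched=...)`, and patching `torch.load` globally is intrusive, so treat this purely as an illustration.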

cc @awaelchli @carmocca @justusschock
