Closed
Labels
checkpointing (Related to checkpointing), fabric (lightning.fabric.Fabric), question (Further information is requested), ver: 2.0.x, working as intended
Description
Bug description
When `torch.load()` is used on a checkpoint file instead of `fabric.load()`, an extra process appears in `nvidia-smi` that occupies GPU memory but is never used.
What version are you seeing the problem on?
v2.0
How to reproduce the bug
```python
import argparse
import os.path
import time

import lightning
import torch


class Network(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(1000, 10)

    def forward(self, x):
        return self.fc1(x)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--devices", default=2)
    parser.add_argument("--bug", action="store_true")
    args = parser.parse_args()

    fabric = lightning.Fabric(accelerator="gpu", devices=args.devices, strategy="ddp")
    fabric.launch()

    network = Network()
    network, opt = fabric.setup(network, torch.optim.Adam(network.parameters()))

    produce_bug = args.bug
    if os.path.exists("network.ckpt"):
        if produce_bug:
            ckpt = torch.load("network.ckpt")
        else:
            ckpt = fabric.load("network.ckpt")
        network.load_state_dict(ckpt["network"])
    else:
        if fabric.is_global_zero:
            fabric.save("network.ckpt", {"network": network.state_dict()})
        fabric.print("network.ckpt now exists, run this script again.")
        exit()

    while True:
        fabric.print("sleeping")
        time.sleep(1)


if __name__ == "__main__":
    main()
```

When you run this script with `python main.py --devices 2`, two processes show up in `nvidia-smi`; when you run it again with the `--bug` flag, there are three.
Error messages and logs
No response
Environment
Current environment
- CUDA:
- GPU:
- NVIDIA RTX A5000
- NVIDIA RTX A5000
- available: True
- version: 11.7
- Lightning:
- lightning: 2.0.2
- lightning-cloud: 0.5.36
- lightning-utilities: 0.8.0
- pytorch-lightning: 2.0.2
- torch: 2.0.1
- torchmetrics: 0.11.4
- Packages:
- aiohttp: 3.8.4
- aiosignal: 1.3.1
- anyio: 3.7.0
- arrow: 1.2.3
- async-timeout: 4.0.2
- attrs: 23.1.0
- beautifulsoup4: 4.12.2
- blessed: 1.20.0
- certifi: 2023.5.7
- charset-normalizer: 3.1.0
- click: 8.1.3
- cmake: 3.26.3
- croniter: 1.3.15
- dateutils: 0.6.12
- deepdiff: 6.3.0
- exceptiongroup: 1.1.1
- fastapi: 0.88.0
- filelock: 3.12.0
- frozenlist: 1.3.3
- fsspec: 2023.5.0
- h11: 0.14.0
- idna: 3.4
- inquirer: 3.1.3
- itsdangerous: 2.1.2
- jinja2: 3.1.2
- lightning: 2.0.2
- lightning-cloud: 0.5.36
- lightning-utilities: 0.8.0
- lit: 16.0.5
- markdown-it-py: 2.2.0
- markupsafe: 2.1.2
- mdurl: 0.1.2
- mpmath: 1.3.0
- multidict: 6.0.4
- networkx: 3.1
- numpy: 1.24.3
- nvidia-cublas-cu11: 11.10.3.66
- nvidia-cuda-cupti-cu11: 11.7.101
- nvidia-cuda-nvrtc-cu11: 11.7.99
- nvidia-cuda-runtime-cu11: 11.7.99
- nvidia-cudnn-cu11: 8.5.0.96
- nvidia-cufft-cu11: 10.9.0.58
- nvidia-curand-cu11: 10.2.10.91
- nvidia-cusolver-cu11: 11.4.0.1
- nvidia-cusparse-cu11: 11.7.4.91
- nvidia-nccl-cu11: 2.14.3
- nvidia-nvtx-cu11: 11.7.91
- ordered-set: 4.1.0
- packaging: 23.1
- pip: 23.1.2
- psutil: 5.9.5
- pydantic: 1.10.8
- pygments: 2.15.1
- pyjwt: 2.7.0
- python-dateutil: 2.8.2
- python-editor: 1.0.4
- python-multipart: 0.0.6
- pytorch-lightning: 2.0.2
- pytz: 2023.3
- pyyaml: 6.0
- readchar: 4.0.5
- requests: 2.31.0
- rich: 13.4.1
- setuptools: 59.6.0
- six: 1.16.0
- sniffio: 1.3.0
- soupsieve: 2.4.1
- starlette: 0.22.0
- starsessions: 1.3.0
- sympy: 1.12
- torch: 2.0.1
- torchmetrics: 0.11.4
- tqdm: 4.65.0
- traitlets: 5.9.0
- triton: 2.0.0
- typing-extensions: 4.6.2
- urllib3: 2.0.2
- uvicorn: 0.22.0
- wcwidth: 0.2.6
- websocket-client: 1.5.2
- websockets: 11.0.3
- wheel: 0.40.0
- yarl: 1.9.2
- System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.10.6
- release: 5.19.0-38-generic
- version: #39~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Fri Mar 17 21:16:15 UTC 2
More info
It is the user's mistake to use `torch.load` instead of `fabric.load`, but it was tricky to debug. Perhaps there should be a warning, similar to the one Fabric emits when it detects that `model.forward()` is not being called?
PossibleUserWarning: You are calling `torch.load` after `fabric.launch()`. This will bypass the fabric strategy and result in incorrect behavior. You should use `fabric.load` instead.
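As a sketch of what such a check could look like: the function name and the idea of hooking direct `torch.load` calls after `fabric.launch()` are hypothetical here, not Lightning's actual implementation; only the `PossibleUserWarning` category name mirrors Lightning's.

```python
import warnings


class PossibleUserWarning(UserWarning):
    """Stand-in for Lightning's PossibleUserWarning category."""


def check_direct_torch_load(fabric_launched: bool) -> None:
    # Hypothetical guard a launched Fabric could run whenever torch.load is
    # called directly, e.g. by wrapping torch.load after fabric.launch().
    if fabric_launched:
        warnings.warn(
            "You are calling `torch.load` after `fabric.launch()`. This will "
            "bypass the fabric strategy and result in incorrect behavior. "
            "You should use `fabric.load` instead.",
            category=PossibleUserWarning,
            stacklevel=2,
        )


# The warning fires only once Fabric has launched its processes.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    check_direct_torch_load(fabric_launched=True)
assert caught[0].category is PossibleUserWarning
```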