
Improve docs to explain DeepSpeedPlugin with activation checkpointing #9152

@SeanNaren

Discussed in #9144

Originally posted by nachshonc August 26, 2021
Hi there,
I'm trying to train a PyTorch Lightning model with the DeepSpeed plugin and activation checkpointing so that larger batch sizes fit in memory, following https://pytorch-lightning.readthedocs.io/en/stable/advanced/advanced_gpu.html#deepspeed-activation-checkpointing.
As the docs specify, the model is run through the DeepSpeed checkpoint function. However, this function seems to return a tensor that does not require grad (it has no grad_fn). When I compute the loss from this output and return it from training_step, I get the following error:

```
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
```
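The symptom can be reproduced on CPU without DeepSpeed or Lightning. The sketch below uses PyTorch's reentrant `torch.utils.checkpoint.checkpoint` (PyTorch 1.11+ for the `use_reentrant` argument) as a stand-in, assuming `deepspeed.checkpointing.checkpoint` follows the same reentrant design: the wrapped function runs under `no_grad`, and the output only gets a `grad_fn` if at least one *tensor argument* requires grad. In the repro, the raw `batch` is the only argument, and the module's parameters are not arguments, so they do not count. The helper name `checkpointed_output` is hypothetical.

```python
import warnings

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
layer = nn.Linear(1, 1)


def checkpointed_output(batch):
    # Hypothetical helper. Reentrant checkpointing runs `layer` without building
    # a graph and attaches a grad_fn to the result only if some tensor argument
    # requires grad; `layer`'s parameters are not passed as arguments.
    with warnings.catch_warnings():
        # Silence "None of the inputs have requires_grad=True" for the first call.
        warnings.simplefilter("ignore")
        return checkpoint(layer, batch, use_reentrant=True)


# Same situation as the repro: a plain data batch with requires_grad=False.
batch = torch.rand(4, 1)
out = checkpointed_output(batch)
print(out.requires_grad, out.grad_fn)  # False None -> backward() would raise the RuntimeError

# If an argument of the checkpointed call requires grad, the output is attached
# to the graph, and backward() recomputes the forward pass to fill parameter grads.
batch_with_grad = torch.rand(4, 1).requires_grad_(True)
out = checkpointed_output(batch_with_grad)
loss = nn.MSELoss()(out, torch.zeros_like(out))
loss.backward()
print(out.requires_grad, layer.weight.grad is not None)  # True True
```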

Minimal code to reproduce

```python
import os

import deepspeed
import pytorch_lightning as pl
import torch
from deepspeed.ops.adam import FusedAdam
from pytorch_lightning.plugins import DeepSpeedPlugin
from torch import nn
from pytorch_lightning.utilities.types import STEP_OUTPUT
from torch.utils.data import DataLoader, RandomSampler


class PlModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = nn.Linear(1, 1)

    def forward(self, batch):
        return self.model(batch)

    def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
        res = deepspeed.checkpointing.checkpoint(self.model, batch)
        return nn.MSELoss()(res, torch.zeros_like(res, device=res.device))

    def configure_optimizers(self):
        return FusedAdam(self.parameters(), lr=0.1)


if __name__ == '__main__':
    trainer = pl.Trainer(gpus=-1, precision=16, plugins=DeepSpeedPlugin(stage=3, partition_activations=True))
    model = PlModel()
    dataset = torch.rand(100, 1)
    dl = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=os.cpu_count(),
                                     sampler=RandomSampler(dataset))
    trainer.fit(model, dl)
```
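One possible way to avoid the error, sketched under the same assumption that DeepSpeed's checkpoint behaves like PyTorch's reentrant checkpoint, is to run the first layer normally and checkpoint only an inner block. The first layer's output depends on trainable parameters, so it requires grad even though the raw batch does not. This runs on CPU with plain PyTorch; `torch.utils.checkpoint.checkpoint` stands in for `deepspeed.checkpointing.checkpoint`, and `CheckpointedModel`, `embed`, `block`, `training_loss`, and the layer sizes are hypothetical (the repro's single `nn.Linear(1, 1)` has no inner block to checkpoint). It is not necessarily the pattern the Lightning docs will end up recommending.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class CheckpointedModel(nn.Module):
    """Hypothetical restructuring of the repro's model (names are illustrative)."""

    def __init__(self):
        super().__init__()
        # Run normally: its output requires grad because of its own parameters.
        self.embed = nn.Linear(1, 16)
        # Only this inner block is checkpointed; its intermediate activations are
        # not kept after the forward pass and are recomputed during backward.
        self.block = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 1))

    def forward(self, batch):
        hidden = self.embed(batch)
        # In the Lightning module this would be the assumed drop-in
        # deepspeed.checkpointing.checkpoint(self.block, hidden).
        return checkpoint(self.block, hidden, use_reentrant=True)


def training_loss(model, batch):
    # Mirrors the repro's training_step: MSE against an all-zeros target.
    res = model(batch)
    return nn.MSELoss()(res, torch.zeros_like(res))


torch.manual_seed(0)
model = CheckpointedModel()
loss = training_loss(model, torch.rand(8, 1))
loss.backward()  # no RuntimeError: the loss has a grad_fn
print(all(p.grad is not None for p in model.parameters()))  # True
```

Because `hidden` is an argument to the checkpointed call, its gradient flows back into `embed`, while the parameters of `block` receive gradients during the recomputation.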

pytorch-lightning version: 1.3.3
deepspeed version: 0.5.0
Thanks!

Labels: docs (Documentation related)