
[RFC] Simplify sharding API instantiation #9375

@tchaton


🚀 Feature

Currently, Lightning users working with sharded models through the DeepSpeed or FSDP plugin need to know about `configure_sharded_model`, as in the following example:

    import pytorch_lightning as pl
    import torch
    from torch import nn
    from fairscale.nn import checkpoint_wrapper, wrap, auto_wrap


    class MyModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.linear_layer = nn.Linear(32, 32)
            self.block = nn.Sequential(nn.Linear(32, 32), nn.ReLU())
            self.final_block = nn.Sequential(nn.Linear(32, 32), nn.ReLU())

        def configure_sharded_model(self):
            # Modules are sharded across processes
            # as soon as they are wrapped with ``wrap`` or ``auto_wrap``.
            # During the forward/backward passes, weights get synced across processes
            # and de-allocated once computation is complete, saving memory.

            # Wraps the layer in a Fully Sharded Wrapper automatically
            linear_layer = wrap(self.linear_layer)

            # Wraps the module recursively
            # based on a minimum number of parameters (default 100M parameters)
            block = auto_wrap(self.block)

            # For best memory efficiency,
            # add FairScale activation checkpointing
            final_block = auto_wrap(checkpoint_wrapper(self.final_block))
            self.model = nn.Sequential(linear_layer, nn.ReLU(), block, final_block)

        def configure_optimizers(self):
            return torch.optim.AdamW(self.model.parameters())
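
For context, here is a minimal launch sketch. The plugin alias `"fsdp"` and the exact `Trainer` arguments are assumptions based on the Lightning 1.4-era FairScale integration, not something defined by this proposal:

    # Hypothetical launch sketch: with a sharded plugin selected, the plugin invokes
    # configure_sharded_model() once the distributed environment is set up, so the
    # wrapped sub-modules are created directly in sharded form.
    model = MyModel()
    trainer = pl.Trainer(gpus=4, precision=16, plugins="fsdp")  # plugin alias assumed
    trainer.fit(model)  # dataloaders / datamodule omitted for brevity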

It could be possible to unify the API using the new meta device and skipped parameter initialization introduced in PyTorch 1.9.0: https://pytorch.org/tutorials/prototype/skip_param_init.html

    from pytorch_lightning import LightningModule
    from pytorch_lightning.distributed import skip_param_init, apply_param_init  # proposed API


    class Model(LightningModule):
        def __init__(self):
            super().__init__()

            with skip_param_init():
                # Create parameters as before, possibly ones which don't fit on 1 GPU.
                ...

        def setup(self, stage):
            ...


    # This model doesn't hold any parameters for the modules defined within the
    # ``skip_param_init`` context manager.
    model = Model()

    # With DeepSpeed / FSDP, once distributed is fully available, ``apply_param_init``
    # is applied to create the sharded weights.
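
For reference, here is a minimal sketch of the underlying PyTorch 1.9 mechanism this builds on, using only stock PyTorch APIs (`torch.nn.utils.skip_init` and the `device="meta"` factory kwarg); `skip_param_init` / `apply_param_init` above are the proposed Lightning wrappers and do not exist yet:

    import torch
    from torch import nn

    # Construct a module without running its default parameter initialization
    # (memory is allocated, but the init code is skipped).
    layer = nn.utils.skip_init(nn.Linear, 32, 32)

    # Construct a module on the "meta" device: parameters carry shape/dtype but
    # no storage, so arbitrarily large models can be described cheaply and
    # materialized later, e.g. shard-by-shard by DeepSpeed / FSDP.
    meta_layer = nn.Linear(32, 32, device="meta")
    print(meta_layer.weight.is_meta)  # True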

Motivation

Pitch

Alternatives

Additional context


If you enjoy Lightning, check out our other projects! ⚡

  • Metrics: Machine learning metrics for distributed, scalable PyTorch applications.

  • Flash: The fastest way to get a Lightning baseline! A collection of tasks for fast prototyping, baselining, fine-tuning, and solving problems with deep learning.

  • Bolts: Pretrained SOTA deep learning models, callbacks, and more for research and production with PyTorch Lightning and PyTorch.

  • Lightning Transformers: Flexible interface for high-performance research using SOTA Transformers, leveraging PyTorch Lightning, Transformers, and Hydra.


Labels: design, feature, help wanted
