## 🚀 Feature
Currently, Lightning users working with sharded models and the DeepSpeed or FSDP plugins need to know about `configure_sharded_model`. Here is an example:
```python
import torch
import torch.nn as nn
import pytorch_lightning as pl
from fairscale.nn import checkpoint_wrapper, wrap, auto_wrap


class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.linear_layer = nn.Linear(32, 32)
        self.block = nn.Sequential(nn.Linear(32, 32), nn.ReLU())
        self.final_block = nn.Sequential(nn.Linear(32, 32), nn.ReLU())

    def configure_sharded_model(self):
        # Modules are sharded across processes as soon as they are wrapped
        # with ``wrap`` or ``auto_wrap``. During the forward/backward passes,
        # weights get synced across processes and de-allocated once
        # computation is complete, saving memory.

        # Wraps the layer in a Fully Sharded Wrapper automatically
        linear_layer = wrap(self.linear_layer)

        # Wraps the module recursively based on a minimum number of
        # parameters (default 100M parameters)
        block = auto_wrap(self.block)

        # For best memory efficiency, add FairScale activation checkpointing
        final_block = auto_wrap(checkpoint_wrapper(self.final_block))

        self.model = nn.Sequential(linear_layer, nn.ReLU(), block, final_block)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.model.parameters())
```

It could be possible to unify this API using the new meta device introduced in PyTorch 1.9.0: https://pytorch.org/tutorials/prototype/skip_param_init.html
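For reference, here is a minimal sketch of that underlying PyTorch mechanism (plain PyTorch, not the proposed Lightning API): parameters created on the meta device carry no real storage and skip initialization until the module is explicitly materialized later.

```python
import torch.nn as nn

# Create the layer without allocating or initializing real parameter storage
# (recent PyTorch versions accept a ``device`` kwarg in module constructors).
layer = nn.Linear(32, 32, device="meta")
assert layer.weight.is_meta  # no concrete data behind this tensor yet

# Later, e.g. once the distributed environment is set up, materialize the
# parameters on a concrete device and re-run the usual initialization.
layer = layer.to_empty(device="cpu")
layer.reset_parameters()
```

A Lightning-level API on top of this could then look roughly as follows: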
```python
from pytorch_lightning.distributed import skip_param_init, apply_param_init


class Model(LightningModule):
    def __init__(self):
        super().__init__()
        with skip_param_init():
            # create parameters as before, possibly ones which don't fit on 1 GPU
            ...

    def setup(self, stage):
        ...


# This model doesn't have any materialized parameters for the modules
# defined within the ``skip_param_init`` context manager.
model = Model()

# In DeepSpeed / FSDP, once distributed is fully available,
# ``apply_param_init`` is applied to create the sharded weights.
```
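To make the proposal concrete, here is a hedged sketch of what `apply_param_init` could do internally, built on the same meta-device mechanism. The helper names come from the pitch above; the body is purely illustrative and not an existing Lightning or PyTorch API (`skip_param_init` would similarly just route module construction onto the meta device).

```python
import torch.nn as nn


def apply_param_init(module: nn.Module, device: str = "cpu") -> nn.Module:
    # Illustrative only. Assumes ``module`` was constructed on the meta device:
    # allocate real (uninitialized) storage on the target device, then re-run
    # each submodule's own ``reset_parameters`` to initialize the weights.
    module.to_empty(device=device)
    for submodule in module.modules():
        if hasattr(submodule, "reset_parameters"):
            submodule.reset_parameters()
    return module
```

In a DeepSpeed / FSDP plugin, `device` would instead be the local rank's device, and the materialization could happen shard by shard so that the full model never has to fit on a single GPU.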
### Motivation

### Pitch
### Alternatives
### Additional context
If you enjoy Lightning, check out our other projects! ⚡

- Metrics: Machine learning metrics for distributed, scalable PyTorch applications.
- Flash: The fastest way to get a Lightning baseline! A collection of tasks for fast prototyping, baselining, fine-tuning and solving problems with deep learning.
- Bolts: Pretrained SOTA deep learning models, callbacks, and more for research and production with PyTorch Lightning and PyTorch.
- Lightning Transformers: Flexible interface for high-performance research using SOTA Transformers, leveraging PyTorch Lightning, Transformers, and Hydra.