Add LightningLite Example #9991
Merged
Commits (21)
a96ff63  update (tchaton)
617f638  update (tchaton)
1680039  Merge branch 'lite-poc' into testing_lite (tchaton)
c303bee  update (tchaton)
4c61e44  Merge branch 'testing_lite' of https://github.com/PyTorchLightning/py… (tchaton)
13f4686  Update pl_examples/lite_examples/pytorch_2_lite_2_lightning.py (kaushikb11)
cbcf7b5  update on comments (tchaton)
6db1cce  update on comments (tchaton)
a20fcc1  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
1618c1f  typo (tchaton)
4b23cb1  Merge branch 'testing_lite' of https://github.com/PyTorchLightning/py… (tchaton)
9461d1b  update (tchaton)
6d88c4c  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
cbd4baf  update (tchaton)
26de692  update (tchaton)
1cba3f4  update (tchaton)
e821d95  update (tchaton)
20d7ab6  update (tchaton)
5ec4040  Merge branch 'lite-poc' into testing_lite (tchaton)
d8a22c4  update (tchaton)
04dad36  Merge branch 'testing_lite' of https://github.com/PyTorchLightning/py… (tchaton)
pl_examples/lite_examples/pytorch_2_lite_2_lightning.py (246 additions, 0 deletions)
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import seed_everything
from pytorch_lightning.lite import LightningLite


#############################################################################################
#                           Section 1: PyTorch to Lightning Lite                            #
#                                                                                           #
#                                 What is LightningLite?                                    #
#                                                                                           #
# `LightningLite` is a Python class you can override to get access to Lightning            #
# accelerators and scale your training. Beyond that, it is intended to be the safest       #
# route to fully transition to Lightning.                                                   #
#                                                                                           #
#                         Does LightningLite require code changes?                          #
#                                                                                           #
# The `LightningLite` code changes are minimal, and this tutorial will show you how easy   #
# it is to convert to `lite` using a `BoringModel`.                                         #
#############################################################################################
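# In brief, the pattern this file builds up to looks like the sketch below (`MyLite` is an
# illustrative name; the full, runnable version is the `LiteTrainer` class further down):
#
#     class MyLite(LightningLite):
#         def run(self, *args):
#             ...  # your training loop, using self.setup / self.backward
#
#     MyLite().run(...)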


#############################################################################################
#                                   Pure PyTorch Section                                    #
#############################################################################################


# 1 / 6: Implement a `BoringModel` with only one layer.
class BoringModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        x = self.layer(x)
        return torch.nn.functional.mse_loss(x, torch.ones_like(x))


# 2 / 6: Implement a `configure_optimizers` function that takes a module and returns an optimizer.
def configure_optimizers(module: nn.Module):
    return torch.optim.SGD(module.parameters(), lr=0.001)


# 3 / 6: Implement a simple dataset returning random data with the specified shape.
class RandomDataset(Dataset):
    def __init__(self, length: int, size: int):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


# 4 / 6: Implement the functions to create the dataloaders.
def train_dataloader():
    return DataLoader(RandomDataset(64, 32))


def val_dataloader():
    return DataLoader(RandomDataset(64, 32))


# 5 / 6: Our main PyTorch loop to train the `BoringModel` on random data.
def main(model: nn.Module, train_dataloader: DataLoader, val_dataloader: DataLoader, num_epochs: int = 10):
    optimizer = configure_optimizers(model)

    for epoch in range(num_epochs):
        train_losses = []
        val_losses = []

        model.train()
        for batch in train_dataloader:
            optimizer.zero_grad()
            loss = model(batch)
            train_losses.append(loss)
            loss.backward()
            optimizer.step()

        model.eval()
        with torch.no_grad():
            for batch in val_dataloader:
                val_losses.append(model(batch))

        train_epoch_loss = torch.stack(train_losses).mean()
        val_epoch_loss = torch.stack(val_losses).mean()

        print(f"{epoch}/{num_epochs}| Train Epoch Loss: {train_epoch_loss}")
        print(f"{epoch}/{num_epochs}| Valid Epoch Loss: {val_epoch_loss}")

    return model.state_dict()


# 6 / 6: Run the pure PyTorch loop and train / validate the model.
seed_everything(42)
model = BoringModel()
pure_model_weights = main(model, train_dataloader(), val_dataloader())


#############################################################################################
#                                 Convert to LightningLite                                  #
#                                                                                           #
# By converting to `LightningLite`, you get the full power of Lightning accelerators       #
# while keeping your original code!                                                         #
# To get started, `from pytorch_lightning.lite import LightningLite`, subclass it, and     #
# override its `run` method.                                                                #
#############################################################################################


class LiteTrainer(LightningLite):
    def run(self, model: nn.Module, train_dataloader: DataLoader, val_dataloader: DataLoader, num_epochs: int = 10):
        optimizer = configure_optimizers(model)

        ###################################################################################
        # Call `self.setup` to wrap `model` and `optimizer`. If you have multiple models  #
        # (e.g. a GAN), call `setup` for each of them together with their associated      #
        # optimizers.                                                                     #
        model, optimizer = self.setup(model=model, optimizers=optimizer)
        ###################################################################################
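        # For instance, a GAN would wrap each model with its own optimizer, roughly like
        # the sketch below (`generator`, `discriminator`, `opt_g`, and `opt_d` are
        # hypothetical names, not defined in this example):
        #
        #     generator, opt_g = self.setup(model=generator, optimizers=opt_g)
        #     discriminator, opt_d = self.setup(model=discriminator, optimizers=opt_d)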
        ###################################################################################
        # Call `self.setup_dataloaders` to prepare the dataloaders in case you are        #
        # running in a distributed setting (for example, by adding a distributed sampler  #
        # and moving batches to the right device).                                        #
        train_dataloader = self.setup_dataloaders(train_dataloader)
        val_dataloader = self.setup_dataloaders(val_dataloader)
        ###################################################################################

        for epoch in range(num_epochs):
            train_losses = []
            val_losses = []

            model.train()
            for batch in train_dataloader:
                optimizer.zero_grad()
                loss = model(batch)
                train_losses.append(loss)
                ###########################################################################
                # By calling `self.backward` instead of `loss.backward()`,                #
                # `LightningLite` handles precision (e.g. loss scaling) and distributed   #
                # gradient synchronization for you.                                       #
                self.backward(loss)
                ###########################################################################
                optimizer.step()

            model.eval()
            with torch.no_grad():
                for batch in val_dataloader:
                    val_losses.append(model(batch))

            train_epoch_loss = torch.stack(train_losses).mean()
            val_epoch_loss = torch.stack(val_losses).mean()

            ###############################################################################
            # Optional: utility to print only on rank 0 in a distributed setting.         #
            self.print(f"{epoch}/{num_epochs}| Train Epoch Loss: {train_epoch_loss}")
            self.print(f"{epoch}/{num_epochs}| Valid Epoch Loss: {val_epoch_loss}")
            ###############################################################################


seed_everything(42)
lite_model = BoringModel()
lite = LiteTrainer()
lite.run(lite_model, train_dataloader(), val_dataloader())


#############################################################################################
#                              Assert the weights are the same                              #
#############################################################################################

for pure_w, lite_w in zip(pure_model_weights.values(), lite_model.state_dict().values()):
    assert torch.equal(pure_w, lite_w)
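# The weights match because both runs call `seed_everything(42)` before creating the model
# and the dataloaders, so the initialization, the random data, and the sequence of updates
# are identical.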


#############################################################################################
#                                   Convert to Lightning                                    #
#                                                                                           #
# By converting to Lightning, not only does your research code become interoperable        #
# (it can easily be shared), but you also get access to hundreds of extra features that    #
# make your research faster.                                                                #
# See Facebook AI's blog post on how Lightning enabled their research to scale:            #
# https://ai.facebook.com/blog/reengineering-facebook-ais-deep-learning-platforms-for-interoperability/
#############################################################################################

from pytorch_lightning import LightningDataModule, LightningModule, Trainer  # noqa: E402


class LightningBoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        x = self.layer(x)
        return torch.nn.functional.mse_loss(x, torch.ones_like(x))

    # LightningModule hooks
    def training_step(self, batch, batch_idx):
        x = self.forward(batch)
        self.log("train_loss", x)
        return x

    def validation_step(self, batch, batch_idx):
        x = self.forward(batch)
        self.log("val_loss", x)
        return x

    def configure_optimizers(self):
        return configure_optimizers(self)


class BoringDataModule(LightningDataModule):
    def train_dataloader(self):
        return train_dataloader()

    def val_dataloader(self):
        return val_dataloader()


seed_everything(42)
lightning_module = LightningBoringModel()
datamodule = BoringDataModule()
trainer = Trainer(max_epochs=10)
trainer.fit(lightning_module, datamodule)
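# The same Trainer can scale out without touching the module code, for example on a
# multi-GPU machine (a sketch; assumes GPUs are available):
#
#     trainer = Trainer(max_epochs=10, gpus=2)
#     trainer.fit(lightning_module, datamodule)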


#############################################################################################
#                              Assert the weights are the same                              #
#############################################################################################

for pure_w, lightning_w in zip(pure_model_weights.values(), lightning_module.state_dict().values()):
    assert torch.equal(pure_w, lightning_w)
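To exercise the accelerator plumbing that LightningLite provides, the `LiteTrainer` above can
also be constructed with scaling flags. A minimal sketch, assuming the released
`LightningLite` constructor arguments (`accelerator`, `devices`, `precision`), which may
differ from the revision in this PR:

    seed_everything(42)
    lite = LiteTrainer(accelerator="gpu", devices=2, precision=16)
    lite.run(BoringModel(), train_dataloader(), val_dataloader())

With `precision=16`, routing the backward pass through `self.backward` is what allows Lite
to apply loss scaling automatically.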