-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Add a notebook example to reach a quick baseline of ~94% accuracy on CIFAR #4818
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
bd5bc96
Add a notebook example to reach a quick baseline of ~94% accuracy on …
hemildesai b4d2c3d
Remove outputs
hemildesai 8d20646
PR Feedback
hemildesai 94d6ca5
Merge branch 'master' into cifar10-baseline
tchaton 366e797
some changes
rohitgr7 e58c20c
Merge branch 'master' into cifar10-baseline
rohitgr7 d3913b9
some more changes
rohitgr7 7e0bd65
Merge branch 'master' into cifar10-baseline
rohitgr7 a715751
Merge branch 'master' into cifar10-baseline
rohitgr7 331c92c
Merge branch 'master' into cifar10-baseline
rohitgr7 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,394 @@ | ||
| { | ||
rohitgr7 marked this conversation as resolved.
Show resolved
Hide resolved
rohitgr7 marked this conversation as resolved.
Show resolved
Hide resolved
rohitgr7 marked this conversation as resolved.
Show resolved
Hide resolved
rohitgr7 marked this conversation as resolved.
Show resolved
Hide resolved
rohitgr7 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| "nbformat": 4, | ||
| "nbformat_minor": 0, | ||
| "metadata": { | ||
| "accelerator": "GPU", | ||
| "colab": { | ||
| "name": "06_cifar10_baseline.ipynb", | ||
| "provenance": [], | ||
| "collapsed_sections": [] | ||
| }, | ||
| "kernelspec": { | ||
| "display_name": "Python 3", | ||
| "name": "python3" | ||
| } | ||
| }, | ||
| "cells": [ | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": { | ||
| "id": "qMDj0BYNECU8" | ||
| }, | ||
| "source": [ | ||
| "<a href=\"https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/06-cifar10-pytorch-lightning.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": { | ||
| "id": "ECu0zDh8UXU8" | ||
| }, | ||
| "source": [ | ||
| "# PyTorch Lightning CIFAR10 ~94% Baseline Tutorial ⚡\n", | ||
| "\n", | ||
| "Train a Resnet to 94% accuracy on Cifar10!\n", | ||
| "\n", | ||
| "Main takeaways:\n", | ||
| "1. Experiment with different Learning Rate schedules and frequencies in the configure_optimizers method in pl.LightningModule\n", | ||
| "2. Use an existing Resnet architecture with modifications directly with Lightning\n", | ||
| "\n", | ||
| "---\n", | ||
| "\n", | ||
| " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", | ||
| " - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n", | ||
| " - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": { | ||
| "id": "HYpMlx7apuHq" | ||
| }, | ||
| "source": [ | ||
| "### Setup\n", | ||
| "Lightning is easy to install. Simply `pip install pytorch-lightning`.\n", | ||
| "Also check out [bolts](https://github.com/PyTorchLightning/pytorch-lightning-bolts/) for pre-existing data modules and models." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "metadata": { | ||
| "id": "ziAQCrE-TYWG" | ||
| }, | ||
| "source": [ | ||
| "! pip install pytorch-lightning pytorch-lightning-bolts -qU" | ||
| ], | ||
| "execution_count": null, | ||
| "outputs": [] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "metadata": { | ||
| "id": "L-W_Gq2FORoU" | ||
| }, | ||
| "source": [ | ||
| "# Run this if you intend to use TPUs\n", | ||
| "# !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py\n", | ||
| "# !python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev" | ||
| ], | ||
| "execution_count": null, | ||
| "outputs": [] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "metadata": { | ||
| "id": "wjov-2N_TgeS" | ||
| }, | ||
| "source": [ | ||
| "import torch\n", | ||
| "import torch.nn as nn\n", | ||
| "import torch.nn.functional as F\n", | ||
| "from torch.optim.lr_scheduler import OneCycleLR\n", | ||
| "from torch.optim.swa_utils import AveragedModel, update_bn\n", | ||
| "import torchvision\n", | ||
| "\n", | ||
| "import pytorch_lightning as pl\n", | ||
| "from pytorch_lightning.callbacks import LearningRateMonitor\n", | ||
| "from pytorch_lightning.metrics.functional import accuracy\n", | ||
| "from pl_bolts.datamodules import CIFAR10DataModule\n", | ||
| "from pl_bolts.transforms.dataset_normalizations import cifar10_normalization" | ||
| ], | ||
| "execution_count": null, | ||
| "outputs": [] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "metadata": { | ||
| "id": "54JMU1N-0y0g" | ||
| }, | ||
| "source": [ | ||
| "pl.seed_everything(7);" | ||
| ], | ||
| "execution_count": null, | ||
| "outputs": [] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": { | ||
| "id": "FA90qwFcqIXR" | ||
| }, | ||
| "source": [ | ||
| "### CIFAR10 Data Module\n", | ||
| "\n", | ||
| "Import the existing data module from `bolts` and modify the train and test transforms." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "metadata": { | ||
| "id": "S9e-W8CSa8nH" | ||
| }, | ||
| "source": [ | ||
| "batch_size = 32\n", | ||
| "\n", | ||
| "train_transforms = torchvision.transforms.Compose([\n", | ||
| " torchvision.transforms.RandomCrop(32, padding=4),\n", | ||
| " torchvision.transforms.RandomHorizontalFlip(),\n", | ||
| " torchvision.transforms.ToTensor(),\n", | ||
| " cifar10_normalization(),\n", | ||
| "])\n", | ||
| "\n", | ||
| "test_transforms = torchvision.transforms.Compose([\n", | ||
| " torchvision.transforms.ToTensor(),\n", | ||
| " cifar10_normalization(),\n", | ||
| "])\n", | ||
| "\n", | ||
| "cifar10_dm = CIFAR10DataModule(\n", | ||
| " batch_size=batch_size,\n", | ||
| " train_transforms=train_transforms,\n", | ||
| " test_transforms=test_transforms,\n", | ||
| " val_transforms=test_transforms,\n", | ||
| ")" | ||
| ], | ||
| "execution_count": null, | ||
| "outputs": [] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": { | ||
| "id": "SfCsutp3qUMc" | ||
| }, | ||
| "source": [ | ||
| "### Resnet\n", | ||
| "Modify the pre-existing Resnet architecture from TorchVision. The pre-existing architecture is based on ImageNet images (224x224) as input. So we need to modify it for CIFAR10 images (32x32)." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "metadata": { | ||
| "id": "GNSeJgwvhHp-" | ||
| }, | ||
| "source": [ | ||
| "def create_model():\n", | ||
| " model = torchvision.models.resnet18(pretrained=False, num_classes=10)\n", | ||
| " model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | ||
| " model.maxpool = nn.Identity()\n", | ||
| " return model" | ||
| ], | ||
| "execution_count": null, | ||
| "outputs": [] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": { | ||
| "id": "HUCj5TKsqty1" | ||
| }, | ||
| "source": [ | ||
| "### Lightning Module\n", | ||
| "Check out the [`configure_optimizers`](https://pytorch-lightning.readthedocs.io/en/stable/lightning_module.html#configure-optimizers) method to use custom Learning Rate schedulers. The OneCycleLR with SGD will get you to around 92-93% accuracy in 20-30 epochs and 93-94% accuracy in 40-50 epochs. Feel free to experiment with different LR schedules from https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "metadata": { | ||
| "id": "03OMrBa5iGtT" | ||
| }, | ||
| "source": [ | ||
| "class LitResnet(pl.LightningModule):\n", | ||
| " def __init__(self, lr=0.05):\n", | ||
| " super().__init__()\n", | ||
| "\n", | ||
| " self.save_hyperparameters()\n", | ||
| " self.model = create_model()\n", | ||
| "\n", | ||
| " def forward(self, x):\n", | ||
| " out = self.model(x)\n", | ||
| " return F.log_softmax(out, dim=1)\n", | ||
| "\n", | ||
| " def training_step(self, batch, batch_idx):\n", | ||
| " x, y = batch\n", | ||
| " logits = F.log_softmax(self.model(x), dim=1)\n", | ||
| " loss = F.nll_loss(logits, y)\n", | ||
| " self.log('train_loss', loss)\n", | ||
| " return loss\n", | ||
| "\n", | ||
| " def evaluate(self, batch, stage=None):\n", | ||
| " x, y = batch\n", | ||
| " logits = self(x)\n", | ||
| " loss = F.nll_loss(logits, y)\n", | ||
| " preds = torch.argmax(logits, dim=1)\n", | ||
| " acc = accuracy(preds, y)\n", | ||
| "\n", | ||
| " if stage:\n", | ||
| " self.log(f'{stage}_loss', loss, prog_bar=True)\n", | ||
| " self.log(f'{stage}_acc', acc, prog_bar=True)\n", | ||
| "\n", | ||
| " def validation_step(self, batch, batch_idx):\n", | ||
| " self.evaluate(batch, 'val')\n", | ||
| "\n", | ||
| " def test_step(self, batch, batch_idx):\n", | ||
| " self.evaluate(batch, 'test')\n", | ||
| "\n", | ||
| " def configure_optimizers(self):\n", | ||
| " optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4)\n", | ||
| " steps_per_epoch = 45000 // batch_size\n", | ||
| " scheduler_dict = {\n", | ||
| " 'scheduler': OneCycleLR(optimizer, 0.1, epochs=self.trainer.max_epochs, steps_per_epoch=steps_per_epoch),\n", | ||
| " 'interval': 'step',\n", | ||
| " }\n", | ||
| " return {'optimizer': optimizer, 'lr_scheduler': scheduler_dict}" | ||
| ], | ||
| "execution_count": null, | ||
| "outputs": [] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "metadata": { | ||
| "id": "3FFPgpAFi9KU" | ||
| }, | ||
| "source": [ | ||
| "model = LitResnet(lr=0.05)\n", | ||
| "model.datamodule = cifar10_dm\n", | ||
| "\n", | ||
| "trainer = pl.Trainer(\n", | ||
| " progress_bar_refresh_rate=20,\n", | ||
| " max_epochs=40,\n", | ||
| " gpus=1,\n", | ||
| " logger=pl.loggers.TensorBoardLogger('lightning_logs/', name='resnet'),\n", | ||
| " callbacks=[LearningRateMonitor(logging_interval='step')],\n", | ||
| ")\n", | ||
| "\n", | ||
| "trainer.fit(model, cifar10_dm)\n", | ||
| "trainer.test(model, datamodule=cifar10_dm);" | ||
| ], | ||
| "execution_count": null, | ||
| "outputs": [] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": { | ||
| "id": "lWL_WpeVIXWQ" | ||
| }, | ||
| "source": [ | ||
| "### Bonus: Use [Stochastic Weight Averaging](https://arxiv.org/abs/1803.05407) to get a boost on performance\n", | ||
| "\n", | ||
| "Use SWA from torch.optim to get a quick performance boost. Also shows a couple of cool features from Lightning:\n", | ||
| "- Use `training_epoch_end` to run code after the end of every epoch\n", | ||
| "- Use a pretrained model directly with this wrapper for SWA" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "metadata": { | ||
| "id": "bsSwqKv0t9uY" | ||
| }, | ||
| "source": [ | ||
| "class SWAResnet(LitResnet):\n", | ||
| " def __init__(self, trained_model, lr=0.01):\n", | ||
| " super().__init__()\n", | ||
| "\n", | ||
| " self.save_hyperparameters('lr')\n", | ||
| " self.model = trained_model\n", | ||
| " self.swa_model = AveragedModel(self.model)\n", | ||
| "\n", | ||
| " def forward(self, x):\n", | ||
| " out = self.swa_model(x)\n", | ||
| " return F.log_softmax(out, dim=1)\n", | ||
| "\n", | ||
| " def training_epoch_end(self, training_step_outputs):\n", | ||
| " self.swa_model.update_parameters(self.model)\n", | ||
| "\n", | ||
| " def validation_step(self, batch, batch_idx, stage=None):\n", | ||
| " x, y = batch\n", | ||
| " logits = F.log_softmax(self.model(x), dim=1)\n", | ||
| " loss = F.nll_loss(logits, y)\n", | ||
| " preds = torch.argmax(logits, dim=1)\n", | ||
| " acc = accuracy(preds, y)\n", | ||
| "\n", | ||
| " self.log(f'val_loss', loss, prog_bar=True)\n", | ||
| " self.log(f'val_acc', acc, prog_bar=True)\n", | ||
| "\n", | ||
| " def configure_optimizers(self):\n", | ||
| " optimizer = torch.optim.SGD(self.model.parameters(), lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4)\n", | ||
| " return optimizer\n", | ||
| "\n", | ||
| " def on_train_end(self):\n", | ||
| " update_bn(self.datamodule.train_dataloader(), self.swa_model, device=self.device)" | ||
| ], | ||
| "execution_count": null, | ||
| "outputs": [] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "metadata": { | ||
| "id": "cA6ZG7C74rjL" | ||
| }, | ||
| "source": [ | ||
| "swa_model = SWAResnet(model.model, lr=0.01)\n", | ||
| "swa_model.datamodule = cifar10_dm\n", | ||
| "\n", | ||
| "swa_trainer = pl.Trainer(\n", | ||
| " progress_bar_refresh_rate=20,\n", | ||
| " max_epochs=20,\n", | ||
| " gpus=1,\n", | ||
| " logger=pl.loggers.TensorBoardLogger('lightning_logs/', name='swa_resnet'),\n", | ||
| ")\n", | ||
| "\n", | ||
| "swa_trainer.fit(swa_model, cifar10_dm)\n", | ||
| "swa_trainer.test(swa_model, datamodule=cifar10_dm);" | ||
| ], | ||
| "execution_count": null, | ||
| "outputs": [] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "metadata": { | ||
| "id": "RRHMfGiDpZ2M" | ||
| }, | ||
| "source": [ | ||
| "# Start tensorboard.\n", | ||
| "%reload_ext tensorboard\n", | ||
| "%tensorboard --logdir lightning_logs/" | ||
| ], | ||
| "execution_count": null, | ||
| "outputs": [] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": { | ||
| "id": "RltpFGS-s0M1" | ||
| }, | ||
| "source": [ | ||
| "<code style=\"color:#792ee5;\">\n", | ||
| " <h1> <strong> Congratulations - Time to Join the Community! </strong> </h1>\n", | ||
| "</code>\n", | ||
| "\n", | ||
| "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!\n", | ||
| "\n", | ||
| "### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n", | ||
| "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n", | ||
| "\n", | ||
| "* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n", | ||
| "\n", | ||
| "### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)!\n", | ||
| "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in `#general` channel\n", | ||
| "\n", | ||
| "### Interested by SOTA AI models ! Check out [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", | ||
| "Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) and can be easily integrated within your own projects.\n", | ||
| "\n", | ||
| "* Please, star [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts)\n", | ||
| "\n", | ||
| "### Contributions !\n", | ||
| "The best way to contribute to our community is to become a code contributor! At any time you can go to [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/pytorch-lightning-bolts) GitHub Issues page and filter for \"good first issue\". \n", | ||
| "\n", | ||
| "* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", | ||
| "* [Bolt good first issue](https://github.com/PyTorchLightning/pytorch-lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", | ||
| "* You can also contribute your own notebooks with useful examples !\n", | ||
| "\n", | ||
| "### Great thanks from the entire Pytorch Lightning Team for your interest !\n", | ||
| "\n", | ||
| "<img src=\"https://github.com/PyTorchLightning/pytorch-lightning/blob/master/docs/source/_images/logos/lightning_logo-name.png?raw=true\" width=\"800\" height=\"200\" />" | ||
| ] | ||
| } | ||
| ] | ||
| } | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.