Skip to content

Commit d566082

Browse files
authored
Add Mnist examples with lite (#10131)
Add MNIST PyTorch to Lightning examples
1 parent 889e319 commit d566082

20 files changed

+973
-388
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,8 @@ ENV/
137137
Datasets/
138138
mnist/
139139
legacy/checkpoints/
140+
*.gz
141+
*ubyte
140142

141143
# pl tests
142144
ml-runs/

grid_generated_0.png

6.84 KB
Loading

grid_ori_0.png

1.19 KB
Loading

pl_examples/README.md

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,31 @@ can be found in our sister library [lightning-bolts](https://pytorch-lightning.r
55

66
______________________________________________________________________
77

8-
## Basic examples
8+
## MNIST Examples
99

10-
In this folder we add 3 simple examples:
10+
5 MNIST examples showing how to gradually convert from pure PyTorch to PyTorch Lightning.
11+
12+
The transition through [LightningLite](https://pytorch-lightning.readthedocs.io/en/latest/starter/lightning_lite.rst) from pure PyTorch is optional but it might helpful to learn about it.
13+
14+
- [MNIST with vanilla PyTorch](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/basic_examples/mnist_examples/image_classifier_1_pytorch.py)
15+
- [MNIST with LightningLite](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/basic_examples/mnist_examples/image_classifier_2_lite.py)
16+
- [MNIST LightningLite to LightningModule](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/basic_examples/mnist_examples/image_classifier_3_lite_to_lightning.py)
17+
- [MNIST with LightningModule](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/basic_examples/mnist_examples/image_classifier_4_lightning.py)
18+
- [MNIST with LightningModule + LightningDataModule](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/basic_examples/mnist_examples/image_classifier_5_lightning_datamodule.py)
19+
20+
______________________________________________________________________
21+
22+
## Basic Examples
23+
24+
In this folder, we add 2 simple examples:
1125

12-
- [MNIST Classifier](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/basic_examples/simple_image_classifier.py) (defines the model inside the `LightningModule`).
1326
- [Image Classifier](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/basic_examples/backbone_image_classifier.py) (trains arbitrary datasets with arbitrary backbones).
27+
- [Image Classifier + DALI](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/basic_examples/mnist_examples/image_classifier_4_dali.py) (defines the model inside the `LightningModule`).
1428
- [Autoencoder](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/basic_examples/autoencoder.py) (shows how the `LightningModule` can be used as a system)
1529

1630
______________________________________________________________________
1731

18-
## Domain examples
32+
## Domain Examples
1933

2034
This folder contains older examples. You should instead use the examples
2135
in [lightning-bolts](https://pytorch-lightning.readthedocs.io/en/latest/ecosystem/bolts.html)

pl_examples/basic_examples/README.md

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,47 +2,70 @@
22

33
Use these examples to test how lightning works.
44

5-
#### MNIST
5+
## MNIST Examples
66

7-
Trains MNIST where the model is defined inside the `LightningModule`.
7+
5 MNIST examples showing how to gradually convert from pure PyTorch to PyTorch Lightning.
8+
9+
The transition through [LightningLite](https://pytorch-lightning.readthedocs.io/en/latest/starter/lightning_lite.rst) from pure PyTorch is optional but it might helpful to learn about it.
10+
11+
#### 1 . Image Classifier with Vanilla PyTorch
12+
13+
Trains a simple CNN over MNIST using vanilla PyTorch.
814

915
```bash
1016
# cpu
11-
python simple_image_classifier.py
17+
python mnist_examples/image_classifier_1_pytorch.py
18+
```
1219

13-
# gpus (any number)
14-
python simple_image_classifier.py --trainer.gpus 2
20+
______________________________________________________________________
1521

16-
# dataparallel
17-
python simple_image_classifier.py --trainer.gpus 2 --trainer.accelerator 'dp'
22+
#### 2. Image Classifier with LightningLite
23+
24+
Trains a simple CNN over MNIST using [LightningLite](https://pytorch-lightning.readthedocs.io/en/latest/starter/lightning_lite.rst).
25+
26+
```bash
27+
# cpu / multiple gpus if available
28+
python mnist_examples/image_classifier_2_lite.py
1829
```
1930

2031
______________________________________________________________________
2132

22-
#### MNIST with DALI
33+
Trains a simple CNN over MNIST where `LightningLite` is almost a `LightningModule`.
2334

24-
The MNIST example above using [NVIDIA DALI](https://developer.nvidia.com/DALI).
25-
Requires NVIDIA DALI to be installed based on your CUDA version, see [here](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html).
35+
```bash
36+
# cpu / multiple gpus if available
37+
python mnist_examples/image_classifier_3_lite_to_lightning.py
38+
```
39+
40+
______________________________________________________________________
41+
42+
#### 4. Image Classifier with LightningModule
43+
44+
Trains a simple CNN over MNIST with `Lightning Trainer` and the converted `LightningModule`.
2645

2746
```bash
28-
python dali_image_classifier.py
47+
# cpu
48+
python mnist_examples/image_classifier_4_lightning.py
49+
50+
# gpus (any number)
51+
python mnist_examples/image_classifier_4_lightning.py --trainer.gpus 2
2952
```
3053

3154
______________________________________________________________________
3255

33-
#### Image classifier
56+
#### 5. Image Classifier with LightningModule + LightningDataModule
3457

35-
Generic image classifier with an arbitrary backbone (ie: a simple system)
58+
Trains a simple CNN over MNIST with `Lightning Trainer` and the converted `LightningModule` and `LightningDataModule`
3659

3760
```bash
3861
# cpu
39-
python backbone_image_classifier.py
62+
python mnist_examples/image_classifier_5_lightning_datamodule.py
4063

4164
# gpus (any number)
42-
python backbone_image_classifier.py --trainer.gpus 2
65+
python mnist_examples/image_classifier_5_lightning_datamodule.py --trainer.gpus 2
4366

44-
# dataparallel
45-
python backbone_image_classifier.py --trainer.gpus 2 --trainer.accelerator 'dp'
67+
# data parallel
68+
python mnist_examples/image_classifier_5_lightning_datamodule.py --trainer.gpus 2 --trainer.accelerator 'dp'
4669
```
4770

4871
______________________________________________________________________
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
## MNIST Examples
2+
3+
5 MNIST examples showing how to gradually convert from pure PyTorch to PyTorch Lightning.
4+
5+
The transition through [LightningLite](https://pytorch-lightning.readthedocs.io/en/latest/starter/lightning_lite.rst) from pure PyTorch is optional but it might helpful to learn about it.
6+
7+
#### 1 . Image Classifier with Vanilla PyTorch
8+
9+
Trains a simple CNN over MNIST using vanilla PyTorch.
10+
11+
```bash
12+
# cpu
13+
python image_classifier_1_pytorch.py
14+
```
15+
16+
______________________________________________________________________
17+
18+
#### 2. Image Classifier with LightningLite
19+
20+
Trains a simple CNN over MNIST using [LightningLite](https://pytorch-lightning.readthedocs.io/en/latest/starter/lightning_lite.rst).
21+
22+
```bash
23+
# cpu / multiple gpus if available
24+
python image_classifier_2_lite.py
25+
```
26+
27+
______________________________________________________________________
28+
29+
#### 3. Image Classifier - Conversion Lite to Lightning
30+
31+
Trains a simple CNN over MNIST where `LightningLite` is almost a `LightningModule`.
32+
33+
```bash
34+
# cpu / multiple gpus if available
35+
python image_classifier_3_lite_to_lightning.py
36+
```
37+
38+
______________________________________________________________________
39+
40+
#### 4. Image Classifier with LightningModule
41+
42+
Trains a simple CNN over MNIST with `Lightning Trainer` and the converted `LightningModule`.
43+
44+
```bash
45+
# cpu
46+
python mnist_examples/image_classifier_4_lightning.py
47+
48+
# gpus (any number)
49+
python mnist_examples/image_classifier_4_lightning.py --trainer.gpus 2
50+
```
51+
52+
______________________________________________________________________
53+
54+
#### 5. Image Classifier with LightningModule + LightningDataModule
55+
56+
Trains a simple CNN over MNIST with `Lightning Trainer` and the converted `LightningModule` and `LightningDataModule`
57+
58+
```bash
59+
# cpu
60+
python image_classifier_5_lightning_datamodule.py
61+
62+
# gpus (any number)
63+
python image_classifier_5_lightning_datamodule.py --trainer.gpus 2
64+
65+
# dataparallel
66+
python image_classifier_5_lightning_datamodule.py --trainer.gpus 2 --trainer.accelerator 'dp'
67+
```
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import argparse
15+
16+
import torch
17+
import torch.nn as nn
18+
import torch.nn.functional as F
19+
import torch.optim as optim
20+
import torchvision.transforms as T
21+
from torch.optim.lr_scheduler import StepLR
22+
23+
from pl_examples.basic_examples.mnist_datamodule import MNIST
24+
25+
# Credit to the PyTorch Team
26+
# Taken from https://github.com/pytorch/examples/blob/master/mnist/main.py and slightly adapted.
27+
28+
29+
class Net(nn.Module):
30+
def __init__(self):
31+
super().__init__()
32+
self.conv1 = nn.Conv2d(1, 32, 3, 1)
33+
self.conv2 = nn.Conv2d(32, 64, 3, 1)
34+
self.dropout1 = nn.Dropout(0.25)
35+
self.dropout2 = nn.Dropout(0.5)
36+
self.fc1 = nn.Linear(9216, 128)
37+
self.fc2 = nn.Linear(128, 10)
38+
39+
def forward(self, x):
40+
x = self.conv1(x)
41+
x = F.relu(x)
42+
x = self.conv2(x)
43+
x = F.relu(x)
44+
x = F.max_pool2d(x, 2)
45+
x = self.dropout1(x)
46+
x = torch.flatten(x, 1)
47+
x = self.fc1(x)
48+
x = F.relu(x)
49+
x = self.dropout2(x)
50+
x = self.fc2(x)
51+
output = F.log_softmax(x, dim=1)
52+
return output
53+
54+
55+
def train(args, model, device, train_loader, optimizer, epoch):
56+
model.train()
57+
for batch_idx, (data, target) in enumerate(train_loader):
58+
data, target = data.to(device), target.to(device)
59+
optimizer.zero_grad()
60+
output = model(data)
61+
loss = F.nll_loss(output, target)
62+
loss.backward()
63+
optimizer.step()
64+
if (batch_idx == 0) or ((batch_idx + 1) % args.log_interval == 0):
65+
print(
66+
"Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
67+
epoch,
68+
batch_idx * len(data),
69+
len(train_loader.dataset),
70+
100.0 * batch_idx / len(train_loader),
71+
loss.item(),
72+
)
73+
)
74+
if args.dry_run:
75+
break
76+
77+
78+
def test(args, model, device, test_loader):
79+
model.eval()
80+
test_loss = 0
81+
correct = 0
82+
with torch.no_grad():
83+
for data, target in test_loader:
84+
data, target = data.to(device), target.to(device)
85+
output = model(data)
86+
test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss
87+
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
88+
correct += pred.eq(target.view_as(pred)).sum().item()
89+
if args.dry_run:
90+
break
91+
92+
test_loss /= len(test_loader.dataset)
93+
94+
print(
95+
"\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
96+
test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset)
97+
)
98+
)
99+
100+
101+
def main():
102+
parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
103+
parser.add_argument(
104+
"--batch-size", type=int, default=64, metavar="N", help="input batch size for training (default: 64)"
105+
)
106+
parser.add_argument(
107+
"--test-batch-size", type=int, default=1000, metavar="N", help="input batch size for testing (default: 1000)"
108+
)
109+
parser.add_argument("--epochs", type=int, default=14, metavar="N", help="number of epochs to train (default: 14)")
110+
parser.add_argument("--lr", type=float, default=1.0, metavar="LR", help="learning rate (default: 1.0)")
111+
parser.add_argument("--gamma", type=float, default=0.7, metavar="M", help="Learning rate step gamma (default: 0.7)")
112+
parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training")
113+
parser.add_argument("--dry-run", action="store_true", default=False, help="quickly check a single pass")
114+
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
115+
parser.add_argument(
116+
"--log-interval",
117+
type=int,
118+
default=10,
119+
metavar="N",
120+
help="how many batches to wait before logging training status",
121+
)
122+
parser.add_argument("--save-model", action="store_true", default=False, help="For Saving the current Model")
123+
args = parser.parse_args()
124+
use_cuda = not args.no_cuda and torch.cuda.is_available()
125+
126+
torch.manual_seed(args.seed)
127+
128+
device = torch.device("cuda" if use_cuda else "cpu")
129+
130+
train_kwargs = {"batch_size": args.batch_size}
131+
test_kwargs = {"batch_size": args.test_batch_size}
132+
if use_cuda:
133+
cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True}
134+
train_kwargs.update(cuda_kwargs)
135+
test_kwargs.update(cuda_kwargs)
136+
137+
transform = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])
138+
train_dataset = MNIST("./data", train=True, download=True, transform=transform)
139+
test_dataset = MNIST("./data", train=False, transform=transform)
140+
train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)
141+
test_loader = torch.utils.data.DataLoader(test_dataset, **test_kwargs)
142+
143+
model = Net().to(device)
144+
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
145+
146+
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
147+
for epoch in range(1, args.epochs + 1):
148+
train(args, model, device, train_loader, optimizer, epoch)
149+
test(args, model, device, test_loader)
150+
scheduler.step()
151+
152+
if args.dry_run:
153+
break
154+
155+
if args.save_model:
156+
torch.save(model.state_dict(), "mnist_cnn.pt")
157+
158+
159+
if __name__ == "__main__":
160+
main()

0 commit comments

Comments
 (0)