
Commit c614cf0

tchaton, awaelchli, and carmocca authored
Improve Lite Examples (#10195)
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 81636fe commit c614cf0

File tree: 17 files changed, +216 -239 lines changed

docs/source/advanced/mixed_precision.rst

Lines changed: 2 additions & 2 deletions
@@ -50,14 +50,14 @@ BFloat16 Mixed precision is similar to FP16 mixed precision, however we maintain
 Since BFloat16 is more stable than FP16 during training, we do not need to worry about any gradient scaling or nan gradient values that come with using FP16 mixed precision.

 .. testcode::
-    :skipif: not _TORCH_GREATER_EQUAL_DEV_1_10 or not torch.cuda.is_available()
+    :skipif: not _TORCH_GREATER_EQUAL_1_10 or not torch.cuda.is_available()

     Trainer(gpus=1, precision="bf16")

 It is also possible to use BFloat16 mixed precision on the CPU, relying on MKLDNN under the hood.

 .. testcode::
-    :skipif: not _TORCH_GREATER_EQUAL_DEV_1_10
+    :skipif: not _TORCH_GREATER_EQUAL_1_10

     Trainer(precision="bf16")
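The two `:skipif:` fixes above simply repoint the doctest guard at the stable `_TORCH_GREATER_EQUAL_1_10` flag. A minimal sketch of the documented usage itself, assuming PyTorch 1.10+ and a Lightning version with bf16 support:

```python
# Sketch of the bf16 usage documented above: GPU bf16 needs CUDA,
# while the CPU path relies on MKLDNN under the hood.
import torch
from pytorch_lightning import Trainer

if torch.cuda.is_available():
    trainer = Trainer(gpus=1, precision="bf16")  # BFloat16 mixed precision on GPU
else:
    trainer = Trainer(precision="bf16")  # BFloat16 mixed precision on CPU
```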

docs/source/conf.py

Lines changed: 1 addition & 1 deletion
@@ -377,7 +377,7 @@ def package_list_from_file(file):
     _XLA_AVAILABLE,
     _TPU_AVAILABLE,
     _TORCHVISION_AVAILABLE,
-    _TORCH_GREATER_EQUAL_DEV_1_10,
+    _TORCH_GREATER_EQUAL_1_10,
     _module_available,
 )
 _JSONARGPARSE_AVAILABLE = _module_available("jsonargparse")

pl_examples/basic_examples/README.md

Lines changed: 10 additions & 7 deletions
@@ -6,7 +6,7 @@ Use these examples to test how Lightning works.

 5 MNIST examples showing how to gradually convert from pure PyTorch to PyTorch Lightning.

-The transition through [LightningLite](https://pytorch-lightning.readthedocs.io/en/latest/starter/lightning_lite.rst) from pure PyTorch is optional but it might helpful to learn about it.
+The transition through [LightningLite](https://pytorch-lightning.readthedocs.io/en/latest/starter/lightning_lite.rst) from pure PyTorch is optional, but it might be helpful to learn about it.

 #### 1. Image Classifier with Vanilla PyTorch

@@ -21,7 +21,7 @@ ______________________________________________________________________

 #### 2. Image Classifier with LightningLite

-Trains a simple CNN over MNIST using [LightningLite](https://pytorch-lightning.readthedocs.io/en/latest/starter/lightning_lite.rst).
+This script shows you how to scale the previous script to enable GPU and multi-GPU training using [LightningLite](https://pytorch-lightning.readthedocs.io/en/latest/starter/lightning_lite.rst).

 ```bash
 # cpu / multiple gpus if available
@@ -30,7 +30,10 @@ python mnist_examples/image_classifier_2_lite.py

 ______________________________________________________________________

-Trains a simple CNN over MNIST where `LightningLite` is almost a `LightningModule`.
+#### 3. Image Classifier - Conversion Lite to Lightning
+
+This script shows you how to prepare your conversion from [LightningLite](https://pytorch-lightning.readthedocs.io/en/latest/starter/lightning_lite.rst)
+to a `LightningModule`.

 ```bash
 # cpu / multiple gpus if available
@@ -41,7 +44,7 @@ ______________________________________________________________________

 #### 4. Image Classifier with LightningModule

-Trains a simple CNN over MNIST with `Lightning Trainer` and the converted `LightningModule`.
+This script shows you the result of the conversion to a `LightningModule`, where you finally get all the benefits from Lightning.

 ```bash
 # cpu
@@ -55,7 +58,7 @@ ______________________________________________________________________

 #### 5. Image Classifier with LightningModule + LightningDataModule

-Trains a simple CNN over MNIST with `Lightning Trainer` and the converted `LightningModule` and `LightningDataModule`
+This script shows you how to extract the data-related components into a `LightningDataModule`.

 ```bash
 # cpu
@@ -64,8 +67,8 @@ python mnist_examples/image_classifier_5_lightning_datamodule.py
 # gpus (any number)
 python mnist_examples/image_classifier_5_lightning_datamodule.py --trainer.gpus 2

-# Distributed Data Parallel
-python backbone_image_classifier.py --trainer.gpus 2 --trainer.accelerator ddp
+# data parallel
+python mnist_examples/image_classifier_5_lightning_datamodule.py --trainer.gpus 2 --trainer.accelerator 'dp'
 ```

 ______________________________________________________________________
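Step 2 above converts the pure-PyTorch script to LightningLite. A minimal sketch of what that conversion looks like, assuming the 1.5-era Lite API (`setup`, `setup_dataloaders`, `backward`) and reusing the `Net` module from the step-1 script:

```python
# Minimal sketch of the pure PyTorch -> LightningLite step described above.
# `Net` is the CNN from image_classifier_1_pytorch.py; the constructor flags
# in the last comment are assumptions based on the Lite docs, not this diff.
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.datasets import MNIST

from pytorch_lightning.lite import LightningLite


class Lite(LightningLite):
    def run(self, lr=1.0):
        transform = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])
        train_loader = DataLoader(MNIST("./data", train=True, download=True, transform=transform), batch_size=64)

        model = Net()
        optimizer = optim.Adadelta(model.parameters(), lr=lr)
        model, optimizer = self.setup(model, optimizer)  # Lite handles device placement
        train_loader = self.setup_dataloaders(train_loader)  # and distributed sampling

        model.train()
        for data, target in train_loader:  # no manual data.to(device)
            optimizer.zero_grad()
            loss = F.nll_loss(model(data), target)
            self.backward(loss)  # replaces loss.backward()
            optimizer.step()


Lite().run()  # e.g. Lite(gpus=2, strategy="ddp").run() to scale out
```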

pl_examples/basic_examples/mnist_examples/README.md

Lines changed: 8 additions & 7 deletions
@@ -2,7 +2,7 @@

 5 MNIST examples showing how to gradually convert from pure PyTorch to PyTorch Lightning.

-The transition through [LightningLite](https://pytorch-lightning.readthedocs.io/en/latest/starter/lightning_lite.rst) from pure PyTorch is optional but it might helpful to learn about it.
+The transition through [LightningLite](https://pytorch-lightning.readthedocs.io/en/latest/starter/lightning_lite.rst) from pure PyTorch is optional, but it might be helpful to learn about it.

 #### 1. Image Classifier with Vanilla PyTorch

@@ -17,7 +17,7 @@ ______________________________________________________________________

 #### 2. Image Classifier with LightningLite

-Trains a simple CNN over MNIST using [LightningLite](https://pytorch-lightning.readthedocs.io/en/latest/starter/lightning_lite.rst).
+This script shows you how to scale the previous script to enable GPU and multi-GPU training using [LightningLite](https://pytorch-lightning.readthedocs.io/en/latest/starter/lightning_lite.rst).

 ```bash
 # cpu / multiple gpus if available
@@ -28,7 +28,8 @@ ______________________________________________________________________

 #### 3. Image Classifier - Conversion Lite to Lightning

-Trains a simple CNN over MNIST where `LightningLite` is almost a `LightningModule`.
+This script shows you how to prepare your conversion from [LightningLite](https://pytorch-lightning.readthedocs.io/en/latest/starter/lightning_lite.rst)
+to a `LightningModule`.

 ```bash
 # cpu / multiple gpus if available
@@ -39,21 +40,21 @@ ______________________________________________________________________

 #### 4. Image Classifier with LightningModule

-Trains a simple CNN over MNIST with `Lightning Trainer` and the converted `LightningModule`.
+This script shows you the result of the conversion to a `LightningModule`, where you finally get all the benefits from Lightning.

 ```bash
 # cpu
-python mnist_examples/image_classifier_4_lightning.py
+python image_classifier_4_lightning.py

 # gpus (any number)
-python mnist_examples/image_classifier_4_lightning.py --trainer.gpus 2
+python image_classifier_4_lightning.py --trainer.gpus 2
 ```

 ______________________________________________________________________

 #### 5. Image Classifier with LightningModule + LightningDataModule

-Trains a simple CNN over MNIST with `Lightning Trainer` and the converted `LightningModule` and `LightningDataModule`
+This script shows you how to extract the data-related components into a `LightningDataModule`.

 ```bash
 # cpu
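Step 5 moves the data pipeline into a `LightningDataModule`. A minimal sketch of that extraction, using the same MNIST transform as these scripts (the class body is illustrative, not the example's exact code):

```python
# Sketch of the step-5 extraction: the MNIST transforms, datasets and loaders
# move out of the script and into a LightningDataModule.
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.datasets import MNIST

from pytorch_lightning import LightningDataModule


class MNISTDataModule(LightningDataModule):
    def __init__(self, batch_size: int = 64):
        super().__init__()
        self.batch_size = batch_size
        self.transform = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])

    def prepare_data(self):
        # download once, on a single process
        MNIST("./data", train=True, download=True)
        MNIST("./data", train=False, download=True)

    def setup(self, stage=None):
        self.train_set = MNIST("./data", train=True, transform=self.transform)
        self.test_set = MNIST("./data", train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size)
```

The `Trainer` then consumes it via `trainer.fit(model, datamodule=MNISTDataModule())`.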

pl_examples/basic_examples/mnist_examples/image_classifier_1_pytorch.py

Lines changed: 70 additions & 76 deletions
@@ -52,64 +52,90 @@ def forward(self, x):
         return output


-def train(args, model, device, train_loader, optimizer, epoch):
-    model.train()
-    for batch_idx, (data, target) in enumerate(train_loader):
-        data, target = data.to(device), target.to(device)
-        optimizer.zero_grad()
-        output = model(data)
-        loss = F.nll_loss(output, target)
-        loss.backward()
-        optimizer.step()
-        if (batch_idx == 0) or ((batch_idx + 1) % args.log_interval == 0):
-            print(
-                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
-                    epoch,
-                    batch_idx * len(data),
-                    len(train_loader.dataset),
-                    100.0 * batch_idx / len(train_loader),
-                    loss.item(),
-                )
-            )
-            if args.dry_run:
-                break
+def run(hparams):
+
+    torch.manual_seed(hparams.seed)
+
+    use_cuda = torch.cuda.is_available()
+    device = torch.device("cuda" if use_cuda else "cpu")
+
+    transform = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])
+    train_dataset = MNIST("./data", train=True, download=True, transform=transform)
+    test_dataset = MNIST("./data", train=False, transform=transform)
+    train_loader = torch.utils.data.DataLoader(
+        train_dataset,
+        batch_size=hparams.batch_size,
+    )
+    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=hparams.batch_size)

+    model = Net().to(device)
+    optimizer = optim.Adadelta(model.parameters(), lr=hparams.lr)
+
+    scheduler = StepLR(optimizer, step_size=1, gamma=hparams.gamma)
+
+    # EPOCH LOOP
+    for epoch in range(1, hparams.epochs + 1):

-def test(args, model, device, test_loader):
-    model.eval()
-    test_loss = 0
-    correct = 0
-    with torch.no_grad():
-        for data, target in test_loader:
+        # TRAINING LOOP
+        model.train()
+        for batch_idx, (data, target) in enumerate(train_loader):
             data, target = data.to(device), target.to(device)
+            optimizer.zero_grad()
             output = model(data)
-            test_loss += F.nll_loss(output, target, reduction="sum").item()  # sum up batch loss
-            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
-            correct += pred.eq(target.view_as(pred)).sum().item()
-            if args.dry_run:
-                break
-
-    test_loss /= len(test_loader.dataset)
+            loss = F.nll_loss(output, target)
+            loss.backward()
+            optimizer.step()
+            if (batch_idx == 0) or ((batch_idx + 1) % hparams.log_interval == 0):
+                print(
+                    "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
+                        epoch,
+                        batch_idx * len(data),
+                        len(train_loader.dataset),
+                        100.0 * batch_idx / len(train_loader),
+                        loss.item(),
+                    )
+                )
+                if hparams.dry_run:
+                    break
+        scheduler.step()

-    print(
-        "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
-            test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset)
+        # TESTING LOOP
+        model.eval()
+        test_loss = 0
+        correct = 0
+        with torch.no_grad():
+            for data, target in test_loader:
+                data, target = data.to(device), target.to(device)
+                output = model(data)
+                test_loss += F.nll_loss(output, target, reduction="sum").item()  # sum up batch loss
+                pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
+                correct += pred.eq(target.view_as(pred)).sum().item()
+                if hparams.dry_run:
+                    break
+
+        test_loss /= len(test_loader.dataset)
+
+        print(
+            "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
+                test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset)
+            )
         )
-    )
+
+        if hparams.dry_run:
+            break
+
+    if hparams.save_model:
+        torch.save(model.state_dict(), "mnist_cnn.pt")


 def main():
     parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
     parser.add_argument(
         "--batch-size", type=int, default=64, metavar="N", help="input batch size for training (default: 64)"
     )
-    parser.add_argument(
-        "--test-batch-size", type=int, default=1000, metavar="N", help="input batch size for testing (default: 1000)"
-    )
     parser.add_argument("--epochs", type=int, default=14, metavar="N", help="number of epochs to train (default: 14)")
     parser.add_argument("--lr", type=float, default=1.0, metavar="LR", help="learning rate (default: 1.0)")
     parser.add_argument("--gamma", type=float, default=0.7, metavar="M", help="Learning rate step gamma (default: 0.7)")
-    parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training")
     parser.add_argument("--dry-run", action="store_true", default=False, help="quickly check a single pass")
     parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
     parser.add_argument(
@@ -120,40 +146,8 @@ def main():
         help="how many batches to wait before logging training status",
     )
     parser.add_argument("--save-model", action="store_true", default=False, help="For Saving the current Model")
-    args = parser.parse_args()
-    use_cuda = not args.no_cuda and torch.cuda.is_available()
-
-    torch.manual_seed(args.seed)
-
-    device = torch.device("cuda" if use_cuda else "cpu")
-
-    train_kwargs = {"batch_size": args.batch_size}
-    test_kwargs = {"batch_size": args.test_batch_size}
-    if use_cuda:
-        cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True}
-        train_kwargs.update(cuda_kwargs)
-        test_kwargs.update(cuda_kwargs)
-
-    transform = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])
-    train_dataset = MNIST("./data", train=True, download=True, transform=transform)
-    test_dataset = MNIST("./data", train=False, transform=transform)
-    train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)
-    test_loader = torch.utils.data.DataLoader(test_dataset, **test_kwargs)
-
-    model = Net().to(device)
-    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
-
-    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
-    for epoch in range(1, args.epochs + 1):
-        train(args, model, device, train_loader, optimizer, epoch)
-        test(args, model, device, test_loader)
-        scheduler.step()
-
-        if args.dry_run:
-            break
-
-    if args.save_model:
-        torch.save(model.state_dict(), "mnist_cnn.pt")
+    hparams = parser.parse_args()
+    run(hparams)


 if __name__ == "__main__":
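The refactor above folds `train()`/`test()` into a single `run(hparams)` entry point that the later Lite example can build on. A quick programmatic smoke test of that entry point (a sketch: the `Namespace` fields mirror the script's argparse flags, and the one-epoch/dry-run values are illustrative):

```python
# Hypothetical smoke test: build the namespace that main() would parse
# and call run() directly. Field names mirror the argparse flags above;
# log_interval=10 is assumed from the upstream MNIST example's default.
from argparse import Namespace

hparams = Namespace(
    batch_size=64,     # --batch-size default
    epochs=1,          # shortened from the default of 14
    lr=1.0,            # --lr default
    gamma=0.7,         # --gamma default
    dry_run=True,      # stop after a single batch per loop
    seed=1,            # --seed default
    log_interval=10,   # --log-interval
    save_model=False,  # --save-model
)
run(hparams)
```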
