
Commit 6e5f232

irustandi, SeanNaren, and Borda authored
Add Dali MNIST example (#3721)
* add MNIST DALI example, update README.md
* Fix PEP8 warnings
* reformatted using black
* add mnist_dali to test_examples.py
* Add documentation as docstrings
* add nvidia-pyindex and nvidia-dali-cuda100
* replace nvidia-pyindex with --extra-index-url
* mark mnist_dali test as Linux and GPU only
* adjust CUDA docker and examples.txt, fix import error in test_examples.py
* adjust the GPU check
* Exit when DALI is not available
* remove requirements-examples.txt and DALI pip install
* Refactored example, moved to new logging api, added runtime check for test and dali script
* Patch to reflect the mnist example module
* add req.
* Apply suggestions from code review
* Removed requirement as it breaks CPU install, added note in README to install DALI
* add DALI to Drone
* test examples
* Apply suggestions from code review
* imports
* ABC
* cuda
* cuda
* pip DALI
* Move build into init function

Co-authored-by: SeanNaren <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Sean Naren <[email protected]>
1 parent f3dfb98 commit 6e5f232

File tree

5 files changed: +249 -10 lines

.drone.yml

Lines changed: 2 additions & 0 deletions
@@ -32,6 +32,8 @@ steps:
 - pip --version
 - nvidia-smi
 - pip install -r ./requirements/devel.txt --upgrade-strategy only-if-needed -v --no-cache-dir
+# when the image has a defined CUDA version we can switch to this package spec "nvidia-dali-cuda${CUDA_VERSION%%.*}0"
+- pip install --extra-index-url https://developer.download.nvidia.com/compute/redist nvidia-dali-cuda100 --upgrade-strategy only-if-needed
 - pip list
 - coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --color=yes --durations=25 # --flake8
 - python -m pytest benchmarks pl_examples -v --color=yes --maxfail=2 --durations=0 # --flake8
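For local runs outside CI, the same DALI wheel can be installed manually. A minimal sketch, assuming a CUDA 10.0 environment; pick the `nvidia-dali-cuda*` package that matches your CUDA version, per the DALI installation docs linked in the README below:

```bash
# Pull DALI from NVIDIA's extra package index; the cuda100 build targets CUDA 10.0.
pip install --extra-index-url https://developer.download.nvidia.com/compute/redist nvidia-dali-cuda100
```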

pl_examples/basic_examples/README.md

Lines changed: 9 additions & 1 deletion
@@ -14,7 +14,15 @@ python mnist.py
 python mnist.py --gpus 2 --distributed_backend 'dp'
 ```
 
----
+---
+#### MNIST with DALI
+The MNIST example above, adapted to use [NVIDIA DALI](https://developer.nvidia.com/DALI) for data loading.
+Requires NVIDIA DALI to be installed for your CUDA version; see [here](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html).
+```bash
+python mnist_dali.py
+```
+
+---
 #### Image classifier
 Generic image classifier with an arbitrary backbone (ie: a simple system)
 ```bash
pl_examples/basic_examples/mnist_dali.py

Lines changed: 204 additions & 0 deletions
@@ -0,0 +1,204 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC
from argparse import ArgumentParser
from random import shuffle
from warnings import warn

import numpy as np
import torch
from torch.nn import functional as F
from torch.utils.data import random_split

import pytorch_lightning as pl

try:
    from torchvision.datasets.mnist import MNIST
    from torchvision import transforms
except Exception:
    from tests.base.datasets import MNIST

try:
    import nvidia.dali.ops as ops
    import nvidia.dali.types as types
    from nvidia.dali.pipeline import Pipeline
    from nvidia.dali.plugin.pytorch import DALIClassificationIterator
except (ImportError, ModuleNotFoundError):
    warn('NVIDIA DALI is not available')
    ops, types, Pipeline, DALIClassificationIterator = ..., ..., ABC, ABC


class ExternalMNISTInputIterator(object):
    """
    This iterator class wraps torchvision's MNIST dataset and returns the images and labels in batches
    """

    def __init__(self, mnist_ds, batch_size):
        self.batch_size = batch_size
        self.mnist_ds = mnist_ds
        self.indices = list(range(len(self.mnist_ds)))
        shuffle(self.indices)

    def __iter__(self):
        self.i = 0
        self.n = len(self.mnist_ds)
        return self

    def __next__(self):
        batch = []
        labels = []
        for _ in range(self.batch_size):
            index = self.indices[self.i]
            img, label = self.mnist_ds[index]
            batch.append(img.numpy())
            labels.append(np.array([label], dtype=np.uint8))
            self.i = (self.i + 1) % self.n
        return (batch, labels)


class ExternalSourcePipeline(Pipeline):
    """
    This DALI pipeline class just contains the MNIST iterator
    """

    def __init__(self, batch_size, eii, num_threads, device_id):
        super(ExternalSourcePipeline, self).__init__(batch_size, num_threads, device_id, seed=12)
        self.source = ops.ExternalSource(source=eii, num_outputs=2)
        self.build()

    def define_graph(self):
        images, labels = self.source()
        return images, labels


class DALIClassificationLoader(DALIClassificationIterator):
    """
    This class extends DALI's original DALIClassificationIterator with the __len__() function so that we can call len() on it
    """

    def __init__(
        self,
        pipelines,
        size=-1,
        reader_name=None,
        auto_reset=False,
        fill_last_batch=True,
        dynamic_shape=False,
        last_batch_padded=False,
    ):
        super().__init__(pipelines, size, reader_name, auto_reset, fill_last_batch, dynamic_shape, last_batch_padded)

    def __len__(self):
        batch_count = self._size // (self._num_gpus * self.batch_size)
        last_batch = 1 if self._fill_last_batch else 0
        return batch_count + last_batch


class LitClassifier(pl.LightningModule):
    def __init__(self, hidden_dim=128, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()

        self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
        self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.l1(x))
        x = torch.relu(self.l2(x))
        return x

    def split_batch(self, batch):
        return batch[0]["data"], batch[0]["label"].squeeze().long()

    def training_step(self, batch, batch_idx):
        x, y = self.split_batch(batch)
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = self.split_batch(batch)
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('valid_loss', loss)

    def test_step(self, batch, batch_idx):
        x, y = self.split_batch(batch)
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('test_loss', loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--hidden_dim', type=int, default=128)
        parser.add_argument('--learning_rate', type=float, default=0.0001)
        return parser


def cli_main():
    pl.seed_everything(1234)

    # ------------
    # args
    # ------------
    parser = ArgumentParser()
    parser.add_argument('--batch_size', default=32, type=int)
    parser = pl.Trainer.add_argparse_args(parser)
    parser = LitClassifier.add_model_specific_args(parser)
    args = parser.parse_args()

    # ------------
    # data
    # ------------
    dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
    mnist_test = MNIST('', train=False, download=True, transform=transforms.ToTensor())
    mnist_train, mnist_val = random_split(dataset, [55000, 5000])

    eii_train = ExternalMNISTInputIterator(mnist_train, args.batch_size)
    eii_val = ExternalMNISTInputIterator(mnist_val, args.batch_size)
    eii_test = ExternalMNISTInputIterator(mnist_test, args.batch_size)

    pipe_train = ExternalSourcePipeline(batch_size=args.batch_size, eii=eii_train, num_threads=2, device_id=0)
    train_loader = DALIClassificationLoader(pipe_train, size=len(mnist_train), auto_reset=True, fill_last_batch=False)

    pipe_val = ExternalSourcePipeline(batch_size=args.batch_size, eii=eii_val, num_threads=2, device_id=0)
    val_loader = DALIClassificationLoader(pipe_val, size=len(mnist_val), auto_reset=True, fill_last_batch=False)

    pipe_test = ExternalSourcePipeline(batch_size=args.batch_size, eii=eii_test, num_threads=2, device_id=0)
    test_loader = DALIClassificationLoader(pipe_test, size=len(mnist_test), auto_reset=True, fill_last_batch=False)

    # ------------
    # model
    # ------------
    model = LitClassifier(args.hidden_dim, args.learning_rate)

    # ------------
    # training
    # ------------
    trainer = pl.Trainer.from_argparse_args(args)
    trainer.fit(model, train_loader, val_loader)

    # ------------
    # testing
    # ------------
    trainer.test(test_dataloaders=test_loader)


if __name__ == "__main__":
    cli_main()
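Since `cli_main()` builds its parser through `pl.Trainer.add_argparse_args`, the script accepts the standard Trainer flags alongside its own `--batch_size`, `--hidden_dim`, and `--learning_rate`. A usage sketch (the flag values are illustrative; the DALI pipelines above are created with `device_id=0`, so a GPU is required):

```bash
# Run the DALI-backed MNIST example on one GPU for a few epochs.
python mnist_dali.py --gpus 1 --max_epochs 3 --batch_size 64 --hidden_dim 128
```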

pl_examples/test_examples.py

Lines changed: 33 additions & 8 deletions
@@ -1,6 +1,15 @@
+import platform
 from unittest import mock
-import torch
+
 import pytest
+import torch
+
+try:
+    from nvidia.dali import ops, types, pipeline, plugin
+except (ImportError, ModuleNotFoundError):
+    DALI_AVAILABLE = False
+else:
+    DALI_AVAILABLE = True
 
 dp_16_args = """
 --max_epochs 1 \
@@ -28,7 +37,7 @@
 --precision 16 \
 """
 
-
+# TODO
 # @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 # @pytest.mark.parametrize('cli_args', [dp_16_args])
 # def test_examples_dp_mnist(cli_args):
@@ -38,15 +47,17 @@
 #     cli_main()
 
 
+# TODO
 # @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 # @pytest.mark.parametrize('cli_args', [dp_16_args])
 # def test_examples_dp_image_classifier(cli_args):
 #     from pl_examples.basic_examples.image_classifier import cli_main
 #
 #     with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()):
 #         cli_main()
-#
-#
+
+
+# TODO
 # @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 # @pytest.mark.parametrize('cli_args', [dp_16_args])
 # def test_examples_dp_autoencoder(cli_args):
@@ -56,24 +67,27 @@
 #     cli_main()
 
 
+# TODO
 # @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 # @pytest.mark.parametrize('cli_args', [ddp_args])
 # def test_examples_ddp_mnist(cli_args):
 #     from pl_examples.basic_examples.mnist import cli_main
 #
 #     with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()):
 #         cli_main()
-#
-#
+
+
+# TODO
 # @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 # @pytest.mark.parametrize('cli_args', [ddp_args])
 # def test_examples_ddp_image_classifier(cli_args):
 #     from pl_examples.basic_examples.image_classifier import cli_main
 #
 #     with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()):
 #         cli_main()
-#
-#
+
+
+# TODO
 # @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 # @pytest.mark.parametrize('cli_args', [ddp_args])
 # def test_examples_ddp_autoencoder(cli_args):
@@ -92,3 +106,14 @@ def test_examples_cpu(cli_args):
     for cli_cmd in [mnist_cli, ic_cli, ae_cli]:
         with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()):
             cli_cmd()
+
+
+@pytest.mark.skipif(not DALI_AVAILABLE, reason="Nvidia DALI required")
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
+@pytest.mark.skipif(platform.system() != 'Linux', reason='Only applies to Linux platform.')
+@pytest.mark.parametrize('cli_args', [cpu_args])
+def test_examples_mnist_dali(cli_args):
+    from pl_examples.basic_examples.mnist_dali import cli_main
+
+    with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()):
+        cli_main()
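Per the skip conditions above, the new test only runs on a Linux machine with a CUDA GPU and DALI installed; otherwise it is skipped. A sketch of one way to select just this test locally with pytest's `-k` filter:

```bash
# Run only the DALI MNIST example test; it is skipped automatically if DALI, CUDA, or Linux is missing.
python -m pytest pl_examples/test_examples.py -k test_examples_mnist_dali -v
```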

requirements/examples.txt

Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
 torchvision>=0.4.1,<0.9.0
-gym>=0.17.0
+gym>=0.17.0
