Skip to content

Commit 16e819e

Browse files
gan3sh500Borda
andauthored
update DALIClassificationLoader to not use deprecated arguments (#4925)
* update DALIClassificationLoader to not use deprecated arguments * fix line length * dali version check added and changed args accordingly * versions * checking version using disutils.version.LooseVersion now * . * ver * import Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Jirka Borovec <[email protected]>
1 parent 81070be commit 16e819e

File tree

2 files changed

+19
-7
lines changed

2 files changed

+19
-7
lines changed

.drone.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,7 @@ steps:
3636
- pip install -r ./requirements/devel.txt --upgrade-strategy only-if-needed -v --no-cache-dir
3737
- pip install git+https://${AUTH_TOKEN}@github.com/PyTorchLightning/[email protected] -v --no-cache-dir
3838
# when Image has defined CUDa version we can switch to this package spec "nvidia-dali-cuda${CUDA_VERSION%%.*}0"
39-
# todo: temprarl fix till https://github.com/PyTorchLightning/pytorch-lightning/pull/4922 is resolved
40-
- pip install --extra-index-url https://developer.download.nvidia.com/compute/redist "nvidia-dali-cuda100<0.27" --upgrade-strategy only-if-needed
39+
- pip install --extra-index-url https://developer.download.nvidia.com/compute/redist nvidia-dali-cuda100 --upgrade-strategy only-if-needed
4140
- pip list
4241
- python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=25 # --flake8
4342
# Running special tests

pl_examples/basic_examples/dali_image_classifier.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from argparse import ArgumentParser
1616
from random import shuffle
1717
from warnings import warn
18+
from distutils.version import LooseVersion
1819

1920
import numpy as np
2021
import torch
@@ -31,12 +32,17 @@
3132
from tests.base.datasets import MNIST
3233

3334
if DALI_AVAILABLE:
34-
import nvidia.dali.ops as ops
35+
from nvidia.dali import ops
3536
from nvidia.dali.pipeline import Pipeline
3637
from nvidia.dali.plugin.pytorch import DALIClassificationIterator
38+
from nvidia.dali import __version__ as dali_version
39+
40+
NEW_DALI_API = LooseVersion(dali_version) >= LooseVersion('0.28.0')
41+
if NEW_DALI_API:
42+
from nvidia.dali.plugin.base_iterator import LastBatchPolicy
3743
else:
3844
warn('NVIDIA DALI is not available')
39-
ops, Pipeline, DALIClassificationIterator = ..., ABC, ABC
45+
ops, Pipeline, DALIClassificationIterator, LastBatchPolicy = ..., ABC, ABC, ABC
4046

4147

4248
class ExternalMNISTInputIterator(object):
@@ -97,11 +103,18 @@ def __init__(
97103
dynamic_shape=False,
98104
last_batch_padded=False,
99105
):
100-
super().__init__(pipelines, size, reader_name, auto_reset, fill_last_batch, dynamic_shape, last_batch_padded)
106+
if NEW_DALI_API:
107+
last_batch_policy = LastBatchPolicy.FILL if fill_last_batch else LastBatchPolicy.DROP
108+
super().__init__(pipelines, size, reader_name, auto_reset, dynamic_shape,
109+
last_batch_policy=last_batch_policy, last_batch_padded=last_batch_padded)
110+
else:
111+
super().__init__(pipelines, size, reader_name, auto_reset, fill_last_batch,
112+
dynamic_shape, last_batch_padded)
113+
self._fill_last_batch = fill_last_batch
101114

102115
def __len__(self):
103116
batch_count = self._size // (self._num_gpus * self.batch_size)
104-
last_batch = 1 if self._fill_last_batch else 0
117+
last_batch = 1 if self._fill_last_batch else 1
105118
return batch_count + last_batch
106119

107120

@@ -178,7 +191,7 @@ def cli_main():
178191
eii_test = ExternalMNISTInputIterator(mnist_test, args.batch_size)
179192

180193
pipe_train = ExternalSourcePipeline(batch_size=args.batch_size, eii=eii_train, num_threads=2, device_id=0)
181-
train_loader = DALIClassificationLoader(pipe_train, size=len(mnist_train), auto_reset=True, fill_last_batch=False)
194+
train_loader = DALIClassificationLoader(pipe_train, size=len(mnist_train), auto_reset=True, fill_last_batch=True)
182195

183196
pipe_val = ExternalSourcePipeline(batch_size=args.batch_size, eii=eii_val, num_threads=2, device_id=0)
184197
val_loader = DALIClassificationLoader(pipe_val, size=len(mnist_val), auto_reset=True, fill_last_batch=False)

0 commit comments

Comments
 (0)