Skip to content

Commit aec682d

Browse files
gan3sh500Borda
andcommitted
update DALIClassificationLoader to not use deprecated arguments (#4925)
* update DALIClassificationLoader to not use deprecated arguments * fix line length * dali version check added and changed args accordingly * versions * checking version using disutils.version.LooseVersion now * . * ver * import Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Jirka Borovec <[email protected]>
1 parent be2bdd0 commit aec682d

File tree

2 files changed

+19
-7
lines changed

2 files changed

+19
-7
lines changed

.drone.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,7 @@ steps:
3636
- pip install -r ./requirements/devel.txt --upgrade-strategy only-if-needed -v --no-cache-dir
3737
- pip install git+https://${AUTH_TOKEN}@github.com/PyTorchLightning/[email protected] -v --no-cache-dir
3838
# when Image has defined CUDa version we can switch to this package spec "nvidia-dali-cuda${CUDA_VERSION%%.*}0"
39-
# todo: temprarl fix till https://github.com/PyTorchLightning/pytorch-lightning/pull/4922 is resolved
40-
- pip install --extra-index-url https://developer.download.nvidia.com/compute/redist "nvidia-dali-cuda100<0.27" --upgrade-strategy only-if-needed
39+
- pip install --extra-index-url https://developer.download.nvidia.com/compute/redist nvidia-dali-cuda100 --upgrade-strategy only-if-needed
4140
- pip list
4241
- python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=25 # --flake8
4342
# Running special tests

pl_examples/basic_examples/dali_image_classifier.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from argparse import ArgumentParser
1616
from random import shuffle
1717
from warnings import warn
18+
from distutils.version import LooseVersion
1819

1920
import numpy as np
2021
import torch
@@ -31,12 +32,17 @@
3132
from tests.base.datasets import MNIST
3233

3334
if DALI_AVAILABLE:
34-
import nvidia.dali.ops as ops
35+
from nvidia.dali import ops
3536
from nvidia.dali.pipeline import Pipeline
3637
from nvidia.dali.plugin.pytorch import DALIClassificationIterator
38+
from nvidia.dali import __version__ as dali_version
39+
40+
NEW_DALI_API = LooseVersion(dali_version) >= LooseVersion('0.28.0')
41+
if NEW_DALI_API:
42+
from nvidia.dali.plugin.base_iterator import LastBatchPolicy
3743
else:
3844
warn('NVIDIA DALI is not available')
39-
ops, Pipeline, DALIClassificationIterator = ..., ABC, ABC
45+
ops, Pipeline, DALIClassificationIterator, LastBatchPolicy = ..., ABC, ABC, ABC
4046

4147

4248
class ExternalMNISTInputIterator(object):
@@ -98,11 +104,18 @@ def __init__(
98104
dynamic_shape=False,
99105
last_batch_padded=False,
100106
):
101-
super().__init__(pipelines, size, reader_name, auto_reset, fill_last_batch, dynamic_shape, last_batch_padded)
107+
if NEW_DALI_API:
108+
last_batch_policy = LastBatchPolicy.FILL if fill_last_batch else LastBatchPolicy.DROP
109+
super().__init__(pipelines, size, reader_name, auto_reset, dynamic_shape,
110+
last_batch_policy=last_batch_policy, last_batch_padded=last_batch_padded)
111+
else:
112+
super().__init__(pipelines, size, reader_name, auto_reset, fill_last_batch,
113+
dynamic_shape, last_batch_padded)
114+
self._fill_last_batch = fill_last_batch
102115

103116
def __len__(self):
104117
batch_count = self._size // (self._num_gpus * self.batch_size)
105-
last_batch = 1 if self._fill_last_batch else 0
118+
last_batch = 1 if self._fill_last_batch else 1
106119
return batch_count + last_batch
107120

108121

@@ -179,7 +192,7 @@ def cli_main():
179192
eii_test = ExternalMNISTInputIterator(mnist_test, args.batch_size)
180193

181194
pipe_train = ExternalSourcePipeline(batch_size=args.batch_size, eii=eii_train, num_threads=2, device_id=0)
182-
train_loader = DALIClassificationLoader(pipe_train, size=len(mnist_train), auto_reset=True, fill_last_batch=False)
195+
train_loader = DALIClassificationLoader(pipe_train, size=len(mnist_train), auto_reset=True, fill_last_batch=True)
183196

184197
pipe_val = ExternalSourcePipeline(batch_size=args.batch_size, eii=eii_val, num_threads=2, device_id=0)
185198
val_loader = DALIClassificationLoader(pipe_val, size=len(mnist_val), auto_reset=True, fill_last_batch=False)

0 commit comments

Comments
 (0)