diff --git a/.drone.yml b/.drone.yml index c87130844c040..1041ebdf872c8 100644 --- a/.drone.yml +++ b/.drone.yml @@ -36,8 +36,7 @@ steps: - pip install -r ./requirements/devel.txt --upgrade-strategy only-if-needed -v --no-cache-dir - pip install git+https://${AUTH_TOKEN}@github.com/PyTorchLightning/lightning-dtrun.git@v0.0.2 -v --no-cache-dir # when Image has defined CUDa version we can switch to this package spec "nvidia-dali-cuda${CUDA_VERSION%%.*}0" - # todo: temprarl fix till https://github.com/PyTorchLightning/pytorch-lightning/pull/4922 is resolved - - pip install --extra-index-url https://developer.download.nvidia.com/compute/redist "nvidia-dali-cuda100<0.27" --upgrade-strategy only-if-needed + - pip install --extra-index-url https://developer.download.nvidia.com/compute/redist nvidia-dali-cuda100 --upgrade-strategy only-if-needed - pip list - python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=25 # --flake8 # Running special tests diff --git a/pl_examples/basic_examples/dali_image_classifier.py b/pl_examples/basic_examples/dali_image_classifier.py index 9f3ba5e08b37e..e628f5daf8a53 100644 --- a/pl_examples/basic_examples/dali_image_classifier.py +++ b/pl_examples/basic_examples/dali_image_classifier.py @@ -15,6 +15,7 @@ from argparse import ArgumentParser from random import shuffle from warnings import warn +from distutils.version import LooseVersion import numpy as np import torch @@ -31,12 +32,17 @@ from tests.base.datasets import MNIST if DALI_AVAILABLE: - import nvidia.dali.ops as ops + from nvidia.dali import ops from nvidia.dali.pipeline import Pipeline from nvidia.dali.plugin.pytorch import DALIClassificationIterator + from nvidia.dali import __version__ as dali_version + + NEW_DALI_API = LooseVersion(dali_version) >= LooseVersion('0.28.0') + if NEW_DALI_API: + from nvidia.dali.plugin.base_iterator import LastBatchPolicy else: warn('NVIDIA DALI is not available') - ops, Pipeline, DALIClassificationIterator = ..., ABC, ABC + ops, Pipeline, DALIClassificationIterator, LastBatchPolicy = ..., ABC, ABC, ABC class ExternalMNISTInputIterator(object): @@ -97,11 +103,18 @@ def __init__( dynamic_shape=False, last_batch_padded=False, ): - super().__init__(pipelines, size, reader_name, auto_reset, fill_last_batch, dynamic_shape, last_batch_padded) + if NEW_DALI_API: + last_batch_policy = LastBatchPolicy.FILL if fill_last_batch else LastBatchPolicy.DROP + super().__init__(pipelines, size, reader_name, auto_reset, dynamic_shape, + last_batch_policy=last_batch_policy, last_batch_padded=last_batch_padded) + else: + super().__init__(pipelines, size, reader_name, auto_reset, fill_last_batch, + dynamic_shape, last_batch_padded) + self._fill_last_batch = fill_last_batch def __len__(self): batch_count = self._size // (self._num_gpus * self.batch_size) - last_batch = 1 if self._fill_last_batch else 0 + last_batch = 1 if self._fill_last_batch else 1 return batch_count + last_batch @@ -178,7 +191,7 @@ def cli_main(): eii_test = ExternalMNISTInputIterator(mnist_test, args.batch_size) pipe_train = ExternalSourcePipeline(batch_size=args.batch_size, eii=eii_train, num_threads=2, device_id=0) - train_loader = DALIClassificationLoader(pipe_train, size=len(mnist_train), auto_reset=True, fill_last_batch=False) + train_loader = DALIClassificationLoader(pipe_train, size=len(mnist_train), auto_reset=True, fill_last_batch=True) pipe_val = ExternalSourcePipeline(batch_size=args.batch_size, eii=eii_val, num_threads=2, device_id=0) val_loader = DALIClassificationLoader(pipe_val, size=len(mnist_val), auto_reset=True, fill_last_batch=False)