Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .drone.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ steps:
- pip install -r ./requirements/devel.txt --upgrade-strategy only-if-needed -v --no-cache-dir
- pip install git+https://${AUTH_TOKEN}@github.com/PyTorchLightning/[email protected] -v --no-cache-dir
# when Image has defined CUDa version we can switch to this package spec "nvidia-dali-cuda${CUDA_VERSION%%.*}0"
# todo: temprarl fix till https://github.com/PyTorchLightning/pytorch-lightning/pull/4922 is resolved
- pip install --extra-index-url https://developer.download.nvidia.com/compute/redist "nvidia-dali-cuda100<0.27" --upgrade-strategy only-if-needed
- pip install --extra-index-url https://developer.download.nvidia.com/compute/redist nvidia-dali-cuda100 --upgrade-strategy only-if-needed
- pip list
- python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=25 # --flake8
# Running special tests
Expand Down
23 changes: 18 additions & 5 deletions pl_examples/basic_examples/dali_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from argparse import ArgumentParser
from random import shuffle
from warnings import warn
from distutils.version import LooseVersion

import numpy as np
import torch
Expand All @@ -31,12 +32,17 @@
from tests.base.datasets import MNIST

if DALI_AVAILABLE:
import nvidia.dali.ops as ops
from nvidia.dali import ops
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIClassificationIterator
from nvidia.dali import __version__ as dali_version

NEW_DALI_API = LooseVersion(dali_version) >= LooseVersion('0.28.0')
if NEW_DALI_API:
from nvidia.dali.plugin.base_iterator import LastBatchPolicy
else:
warn('NVIDIA DALI is not available')
ops, Pipeline, DALIClassificationIterator = ..., ABC, ABC
ops, Pipeline, DALIClassificationIterator, LastBatchPolicy = ..., ABC, ABC, ABC


class ExternalMNISTInputIterator(object):
Expand Down Expand Up @@ -97,11 +103,18 @@ def __init__(
dynamic_shape=False,
last_batch_padded=False,
):
super().__init__(pipelines, size, reader_name, auto_reset, fill_last_batch, dynamic_shape, last_batch_padded)
if NEW_DALI_API:
last_batch_policy = LastBatchPolicy.FILL if fill_last_batch else LastBatchPolicy.DROP
super().__init__(pipelines, size, reader_name, auto_reset, dynamic_shape,
last_batch_policy=last_batch_policy, last_batch_padded=last_batch_padded)
else:
super().__init__(pipelines, size, reader_name, auto_reset, fill_last_batch,
dynamic_shape, last_batch_padded)
self._fill_last_batch = fill_last_batch

def __len__(self):
batch_count = self._size // (self._num_gpus * self.batch_size)
last_batch = 1 if self._fill_last_batch else 0
last_batch = 1 if self._fill_last_batch else 1
return batch_count + last_batch


Expand Down Expand Up @@ -178,7 +191,7 @@ def cli_main():
eii_test = ExternalMNISTInputIterator(mnist_test, args.batch_size)

pipe_train = ExternalSourcePipeline(batch_size=args.batch_size, eii=eii_train, num_threads=2, device_id=0)
train_loader = DALIClassificationLoader(pipe_train, size=len(mnist_train), auto_reset=True, fill_last_batch=False)
train_loader = DALIClassificationLoader(pipe_train, size=len(mnist_train), auto_reset=True, fill_last_batch=True)

pipe_val = ExternalSourcePipeline(batch_size=args.batch_size, eii=eii_val, num_threads=2, device_id=0)
val_loader = DALIClassificationLoader(pipe_val, size=len(mnist_val), auto_reset=True, fill_last_batch=False)
Expand Down