|
15 | 15 | from argparse import ArgumentParser |
16 | 16 | from random import shuffle |
17 | 17 | from warnings import warn |
| 18 | +from distutils.version import LooseVersion |
18 | 19 |
|
19 | 20 | import numpy as np |
20 | 21 | import torch |
|
31 | 32 | from tests.base.datasets import MNIST |
32 | 33 |
|
33 | 34 | if DALI_AVAILABLE: |
34 | | - import nvidia.dali.ops as ops |
| 35 | + from nvidia.dali import ops |
35 | 36 | from nvidia.dali.pipeline import Pipeline |
36 | 37 | from nvidia.dali.plugin.pytorch import DALIClassificationIterator |
| 38 | + from nvidia.dali import __version__ as dali_version |
| 39 | + |
| 40 | + NEW_DALI_API = LooseVersion(dali_version) >= LooseVersion('0.28.0') |
| 41 | + if NEW_DALI_API: |
| 42 | + from nvidia.dali.plugin.base_iterator import LastBatchPolicy |
37 | 43 | else: |
38 | 44 | warn('NVIDIA DALI is not available') |
39 | | - ops, Pipeline, DALIClassificationIterator = ..., ABC, ABC |
| 45 | + ops, Pipeline, DALIClassificationIterator, LastBatchPolicy = ..., ABC, ABC, ABC |
40 | 46 |
|
41 | 47 |
|
42 | 48 | class ExternalMNISTInputIterator(object): |
@@ -97,11 +103,18 @@ def __init__( |
97 | 103 | dynamic_shape=False, |
98 | 104 | last_batch_padded=False, |
99 | 105 | ): |
100 | | - super().__init__(pipelines, size, reader_name, auto_reset, fill_last_batch, dynamic_shape, last_batch_padded) |
| 106 | + if NEW_DALI_API: |
| 107 | + last_batch_policy = LastBatchPolicy.FILL if fill_last_batch else LastBatchPolicy.DROP |
| 108 | + super().__init__(pipelines, size, reader_name, auto_reset, dynamic_shape, |
| 109 | + last_batch_policy=last_batch_policy, last_batch_padded=last_batch_padded) |
| 110 | + else: |
| 111 | + super().__init__(pipelines, size, reader_name, auto_reset, fill_last_batch, |
| 112 | + dynamic_shape, last_batch_padded) |
| 113 | + self._fill_last_batch = fill_last_batch |
101 | 114 |
|
102 | 115 | def __len__(self): |
103 | 116 | batch_count = self._size // (self._num_gpus * self.batch_size) |
104 | | - last_batch = 1 if self._fill_last_batch else 0 |
| 117 | + last_batch = 1 if self._fill_last_batch else 1 |
105 | 118 | return batch_count + last_batch |
106 | 119 |
|
107 | 120 |
|
@@ -178,7 +191,7 @@ def cli_main(): |
178 | 191 | eii_test = ExternalMNISTInputIterator(mnist_test, args.batch_size) |
179 | 192 |
|
180 | 193 | pipe_train = ExternalSourcePipeline(batch_size=args.batch_size, eii=eii_train, num_threads=2, device_id=0) |
181 | | - train_loader = DALIClassificationLoader(pipe_train, size=len(mnist_train), auto_reset=True, fill_last_batch=False) |
| 194 | + train_loader = DALIClassificationLoader(pipe_train, size=len(mnist_train), auto_reset=True, fill_last_batch=True) |
182 | 195 |
|
183 | 196 | pipe_val = ExternalSourcePipeline(batch_size=args.batch_size, eii=eii_val, num_threads=2, device_id=0) |
184 | 197 | val_loader = DALIClassificationLoader(pipe_val, size=len(mnist_val), auto_reset=True, fill_last_batch=False) |
|
0 commit comments