
Commit bb4ce5f

parmeet authored and facebook-github-bot committed
[fbsync] Enable custom default config for dataset tests (#3578)
Reviewed By: fmassa
Differential Revision: D27433915
fbshipit-source-id: 70d5fcd0a8b68c2de7362ddf4f63a072cf658d7c
1 parent: 23f2421

File tree

2 files changed: +145, −49 lines


test/datasets_utils.py

Lines changed: 130 additions & 36 deletions
@@ -117,19 +117,45 @@ def inner_wrapper(*args, **kwargs):
 def test_all_configs(test):
     """Decorator to run test against all configurations.
 
-    Add this as decorator to an arbitrary test to run it against all configurations. The current configuration is
-    provided as the first parameter:
+    Add this as decorator to an arbitrary test to run it against all configurations. This includes
+    :attr:`DatasetTestCase.DEFAULT_CONFIG` and :attr:`DatasetTestCase.ADDITIONAL_CONFIGS`.
+
+    The current configuration is provided as the first parameter for the test:
 
     .. code-block::
 
-        @test_all_configs
+        @test_all_configs()
         def test_foo(self, config):
             pass
+
+    .. note::
+
+        This will try to remove duplicate configurations. During this process it will not preserve a potential
+        ordering of the configurations or an inner ordering of a configuration.
     """
 
+    def maybe_remove_duplicates(configs):
+        try:
+            return [dict(config_) for config_ in set(tuple(sorted(config.items())) for config in configs)]
+        except TypeError:
+            # A TypeError will be raised if a value of any config is not hashable, e.g. a list. In that case duplicate
+            # removal would be a lot more elaborate and we simply bail out.
+            return configs
+
     @functools.wraps(test)
     def wrapper(self):
-        for config in self.CONFIGS or (self._DEFAULT_CONFIG,):
+        configs = []
+        if self.DEFAULT_CONFIG is not None:
+            configs.append(self.DEFAULT_CONFIG)
+        if self.ADDITIONAL_CONFIGS is not None:
+            configs.extend(self.ADDITIONAL_CONFIGS)
+
+        if not configs:
+            configs = [self._KWARG_DEFAULTS.copy()]
+        else:
+            configs = maybe_remove_duplicates(configs)
+
+        for config in configs:
             with self.subTest(**config):
                 test(self, config)
 
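
The deduplication trick in maybe_remove_duplicates() is worth unpacking. A minimal self-contained sketch, with made-up configs:

# Sorting the items makes key order irrelevant, and tuples are hashable, so a
# set can collapse duplicates; dict() rebuilds the configs afterwards.
configs = [
    dict(split="train", fold=1),
    dict(fold=1, split="train"),  # duplicate of the first, different key order
    dict(split="test", fold=1),
]
unique = [dict(items) for items in set(tuple(sorted(config.items())) for config in configs)]
assert len(unique) == 2  # neither list order nor key order is preserved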

@@ -166,9 +192,13 @@ class DatasetTestCase(unittest.TestCase):
 
     Optionally, you can overwrite the following class attributes:
 
-    - CONFIGS (Sequence[Dict[str, Any]]): Additional configs that should be tested. Each dictonary can contain an
-      arbitrary combination of dataset parameters that are **not** ``transform``, ``target_transform``,
-      ``transforms``, or ``download``. The first element will be used as default configuration.
+    - DEFAULT_CONFIG (Dict[str, Any]): Config that will be used by default. If omitted, this defaults to all
+      keyword arguments of the dataset minus ``transform``, ``target_transform``, ``transforms``, and
+      ``download``. Overwrite this if you want to use a default value for a parameter for which the dataset does
+      not provide one.
+    - ADDITIONAL_CONFIGS (Sequence[Dict[str, Any]]): Additional configs that should be tested. Each dictionary can
+      contain an arbitrary combination of dataset parameters that are **not** ``transform``, ``target_transform``,
+      ``transforms``, or ``download``.
     - REQUIRED_PACKAGES (Iterable[str]): Additional dependencies to use the dataset. If these packages are not
       available, the tests are skipped.
 
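
As a hedged illustration of how a test case combines the two new attributes (the DEFAULT_CONFIG value here is assumed for the sake of the example, and the fake-data plumbing is omitted):

class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase):
    DATASET_CLASS = datasets.CIFAR10

    # Plain tests run against this config; parameters that are left out fall
    # back to the defaults of the dataset constructor.
    DEFAULT_CONFIG = dict(train=False)

    # Tests decorated with @test_all_configs additionally run each of these.
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False))
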
@@ -218,22 +248,31 @@ def test_baz(self):
     DATASET_CLASS = None
     FEATURE_TYPES = None
 
-    CONFIGS = None
+    DEFAULT_CONFIG = None
+    ADDITIONAL_CONFIGS = None
     REQUIRED_PACKAGES = None
 
-    _DEFAULT_CONFIG = None
-
+    # These keyword arguments are checked by test_transforms in case they are available in DATASET_CLASS.
     _TRANSFORM_KWARGS = {
         "transform",
         "target_transform",
         "transforms",
     }
+    # These keyword arguments get a 'special' treatment and should not be set in DEFAULT_CONFIG or ADDITIONAL_CONFIGS.
     _SPECIAL_KWARGS = {
         *_TRANSFORM_KWARGS,
         "download",
     }
+
+    # These fields are populated during setupClass() within _populate_private_class_attributes()
+
+    # This will be a dictionary containing all keyword arguments with their respective default values extracted from
+    # the dataset constructor.
+    _KWARG_DEFAULTS = None
+    # This will be a set of all _SPECIAL_KWARGS that the dataset constructor takes.
     _HAS_SPECIAL_KWARG = None
 
+    # These functions are disabled during dataset creation in create_dataset().
     _CHECK_FUNCTIONS = {
         "check_md5",
         "check_integrity",
@@ -256,7 +295,8 @@ def dataset_args(self, tmpdir: str, config: Dict[str, Any]) -> Sequence[Any]:
         Args:
             tmpdir (str): Path to a temporary directory. For most cases this acts as root directory for the dataset
                 to be created and in turn also for the fake data injected here.
-            config (Dict[str, Any]): Configuration that will be used to create the dataset.
+            config (Dict[str, Any]): Configuration that will be passed to the dataset constructor. It provides at least
+                fields for all dataset parameters with default values.
 
         Returns:
             (Tuple[str]): ``tmpdir`` which corresponds to ``root`` for most datasets.
@@ -273,7 +313,8 @@ def inject_fake_data(self, tmpdir: str, config: Dict[str, Any]) -> Union[int, Di
         Args:
             tmpdir (str): Path to a temporary directory. For most cases this acts as root directory for the dataset
                 to be created and in turn also for the fake data injected here.
-            config (Dict[str, Any]): Configuration that will be used to create the dataset.
+            config (Dict[str, Any]): Configuration that will be passed to the dataset constructor. It provides at least
+                fields for all dataset parameters with default values.
 
         Needs to return one of the following:
 
@@ -293,9 +334,16 @@ def create_dataset(
     ) -> Iterator[Tuple[torchvision.datasets.VisionDataset, Dict[str, Any]]]:
         r"""Create the dataset in a temporary directory.
 
+        The configuration passed to the dataset is populated to contain at least all parameters with default values.
+        For this the following order of precedence is used:
+
+        1. Parameters in :attr:`kwargs`.
+        2. Configuration in :attr:`config`.
+        3. Configuration in :attr:`~DatasetTestCase.DEFAULT_CONFIG`.
+        4. Default parameters of the dataset.
+
         Args:
-            config (Optional[Dict[str, Any]]): Configuration that will be used to create the dataset. If omitted, the
-                default configuration is used.
+            config (Optional[Dict[str, Any]]): Configuration that will be used to create the dataset.
             inject_fake_data (bool): If ``True`` (default) inject the fake data with :meth:`.inject_fake_data` before
                 creating the dataset.
             patch_checks (Optional[bool]): If ``True`` disable integrity check logic while creating the dataset. If
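
The order of precedence maps onto successive dict.update() calls, where later updates win. A standalone sketch with made-up parameters:

kwarg_defaults = dict(split="train", mode="fine")  # 4. dataset defaults
default_config = dict(mode="coarse")               # 3. DEFAULT_CONFIG
config = dict(split="val")                         # 2. config argument
kwargs = dict(split="test")                        # 1. kwargs

complete_config = kwarg_defaults.copy()
for override in (default_config, config, kwargs):
    complete_config.update(override)  # later sources overwrite earlier ones

assert complete_config == dict(split="test", mode="coarse")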
@@ -308,30 +356,33 @@ def create_dataset(
             info (Dict[str, Any]): Additional information about the injected fake data. See :meth:`.inject_fake_data`
                 for details.
         """
-        default_config = self._DEFAULT_CONFIG.copy()
-        if config is not None:
-            default_config.update(config)
-        config = default_config
-
         if patch_checks is None:
             patch_checks = inject_fake_data
 
         special_kwargs, other_kwargs = self._split_kwargs(kwargs)
+
+        complete_config = self._KWARG_DEFAULTS.copy()
+        if self.DEFAULT_CONFIG:
+            complete_config.update(self.DEFAULT_CONFIG)
+        if config:
+            complete_config.update(config)
+        if other_kwargs:
+            complete_config.update(other_kwargs)
+
         if "download" in self._HAS_SPECIAL_KWARG and special_kwargs.get("download", False):
             # override download param to False param if its default is truthy
             special_kwargs["download"] = False
-        config.update(other_kwargs)
 
         patchers = self._patch_download_extract()
         if patch_checks:
             patchers.update(self._patch_checks())
 
         with get_tmp_dir() as tmpdir:
-            args = self.dataset_args(tmpdir, config)
-            info = self._inject_fake_data(tmpdir, config) if inject_fake_data else None
+            args = self.dataset_args(tmpdir, complete_config)
+            info = self._inject_fake_data(tmpdir, complete_config) if inject_fake_data else None
 
             with self._maybe_apply_patches(patchers), disable_console_output():
-                dataset = self.DATASET_CLASS(*args, **config, **special_kwargs)
+                dataset = self.DATASET_CLASS(*args, **complete_config, **special_kwargs)
 
             yield dataset, info
 
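
For context, this is roughly how a dataset test consumes create_dataset() — a sketch that assumes, as elsewhere in this suite, that inject_fake_data() reports the number of examples under the key "num_examples":

@datasets_utils.test_all_configs
def test_num_examples(self, config):
    # create_dataset() is a context manager: it yields the dataset built from
    # the completed config together with the info dict from inject_fake_data().
    with self.create_dataset(config) as (dataset, info):
        assert len(dataset) == info["num_examples"]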

@@ -357,26 +408,69 @@ def _verify_required_public_class_attributes(cls):
 
     @classmethod
     def _populate_private_class_attributes(cls):
-        argspec = inspect.getfullargspec(cls.DATASET_CLASS.__init__)
+        defaults = []
+        for cls_ in cls.DATASET_CLASS.__mro__:
+            if cls_ is torchvision.datasets.VisionDataset:
+                break
 
-        cls._DEFAULT_CONFIG = {
-            kwarg: default
-            for kwarg, default in zip(argspec.args[-len(argspec.defaults):], argspec.defaults)
-            if kwarg not in cls._SPECIAL_KWARGS
-        }
+            argspec = inspect.getfullargspec(cls_.__init__)
 
-        cls._HAS_SPECIAL_KWARG = {name for name in cls._SPECIAL_KWARGS if name in argspec.args}
+            if not argspec.defaults:
+                continue
+
+            defaults.append(
+                {kwarg: default for kwarg, default in zip(argspec.args[-len(argspec.defaults):], argspec.defaults)}
+            )
+
+            if not argspec.varkw:
+                break
+
+        kwarg_defaults = dict()
+        for config in reversed(defaults):
+            kwarg_defaults.update(config)
+
+        has_special_kwargs = set()
+        for name in cls._SPECIAL_KWARGS:
+            if name not in kwarg_defaults:
+                continue
+
+            del kwarg_defaults[name]
+            has_special_kwargs.add(name)
+
+        cls._KWARG_DEFAULTS = kwarg_defaults
+        cls._HAS_SPECIAL_KWARG = has_special_kwargs
 
     @classmethod
     def _process_optional_public_class_attributes(cls):
-        if cls.REQUIRED_PACKAGES is not None:
-            try:
-                for pkg in cls.REQUIRED_PACKAGES:
+        def check_config(config, name):
+            special_kwargs = tuple(f"'{name}'" for name in cls._SPECIAL_KWARGS if name in config)
+            if special_kwargs:
+                raise UsageError(
+                    f"{name} contains a value for the parameter(s) {', '.join(special_kwargs)}. "
+                    f"These are handled separately by the test case and should not be set here. "
+                    f"If you need to test some custom behavior regarding these parameters, "
+                    f"you need to write a custom test (*not* test case), e.g. test_custom_transform()."
+                )
+
+        if cls.DEFAULT_CONFIG is not None:
+            check_config(cls.DEFAULT_CONFIG, "DEFAULT_CONFIG")
+
+        if cls.ADDITIONAL_CONFIGS is not None:
+            for idx, config in enumerate(cls.ADDITIONAL_CONFIGS):
+                check_config(config, f"CONFIGS[{idx}]")
+
+        if cls.REQUIRED_PACKAGES:
+            missing_pkgs = []
+            for pkg in cls.REQUIRED_PACKAGES:
+                try:
                     importlib.import_module(pkg)
-            except ImportError as error:
+                except ImportError:
+                    missing_pkgs.append(f"'{pkg}'")
+
+            if missing_pkgs:
                 raise unittest.SkipTest(
-                    f"The package '{error.name}' is required to load the dataset '{cls.DATASET_CLASS.__name__}' but is "
-                    f"not installed."
+                    f"The package(s) {', '.join(missing_pkgs)} are required to load the dataset "
+                    f"'{cls.DATASET_CLASS.__name__}', but are not installed."
                 )
 
     def _split_kwargs(self, kwargs):
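
The MRO walk in _populate_private_class_attributes() can be reproduced standalone. A sketch with a toy hierarchy, using object as the stop sentinel instead of torchvision.datasets.VisionDataset:

import inspect


class Base:
    def __init__(self, root, split="train", **kwargs):
        pass


class Child(Base):
    def __init__(self, root, download=False, **kwargs):
        super().__init__(root, **kwargs)


# Collect keyword defaults along the MRO; classes visited first win, which is
# why the collected dicts are merged in reverse order.
defaults = []
for cls_ in Child.__mro__:
    if cls_ is object:
        break
    argspec = inspect.getfullargspec(cls_.__init__)
    if not argspec.defaults:
        continue
    defaults.append(dict(zip(argspec.args[-len(argspec.defaults):], argspec.defaults)))
    if not argspec.varkw:
        # Without a **kwargs catch-all, nothing is forwarded further up.
        break

kwarg_defaults = dict()
for config in reversed(defaults):
    kwarg_defaults.update(config)

assert kwarg_defaults == dict(split="train", download=False)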

test/test_datasets.py

Lines changed: 15 additions & 13 deletions
@@ -369,7 +369,9 @@ class Caltech101TestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.Caltech101
     FEATURE_TYPES = (PIL.Image.Image, (int, np.ndarray, tuple))
 
-    CONFIGS = datasets_utils.combinations_grid(target_type=("category", "annotation", ["category", "annotation"]))
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
+        target_type=("category", "annotation", ["category", "annotation"])
+    )
     REQUIRED_PACKAGES = ("scipy",)
 
     def inject_fake_data(self, tmpdir, config):
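
combinations_grid() is a helper from test/datasets_utils.py. Conceptually it expands keyword options into one config per point of the cartesian product; a sketch of such a helper (the actual implementation may differ):

import itertools


def combinations_grid(**kwargs):
    # One config dict per combination of the given options.
    return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())]


configs = combinations_grid(target_type=("category", "annotation", ["category", "annotation"]))
assert configs == [
    dict(target_type="category"),
    dict(target_type="annotation"),
    dict(target_type=["category", "annotation"]),
]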
@@ -466,7 +468,7 @@ def inject_fake_data(self, tmpdir, config):
 class WIDERFaceTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.WIDERFace
     FEATURE_TYPES = (PIL.Image.Image, (dict, type(None)))  # test split returns None as target
-    CONFIGS = datasets_utils.combinations_grid(split=('train', 'val', 'test'))
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=('train', 'val', 'test'))
 
     def inject_fake_data(self, tmpdir, config):
         widerface_dir = pathlib.Path(tmpdir) / 'widerface'
@@ -521,7 +523,7 @@ def inject_fake_data(self, tmpdir, config):
 class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.ImageNet
     REQUIRED_PACKAGES = ('scipy',)
-    CONFIGS = datasets_utils.combinations_grid(split=('train', 'val'))
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=('train', 'val'))
 
     def inject_fake_data(self, tmpdir, config):
         tmpdir = pathlib.Path(tmpdir)
@@ -551,7 +553,7 @@ def inject_fake_data(self, tmpdir, config):
 
 class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.CIFAR10
-    CONFIGS = datasets_utils.combinations_grid(train=(True, False))
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False))
 
     _VERSION_CONFIG = dict(
         base_folder="cifar-10-batches-py",
@@ -623,7 +625,7 @@ class CelebATestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.CelebA
     FEATURE_TYPES = (PIL.Image.Image, (torch.Tensor, int, tuple, type(None)))
 
-    CONFIGS = datasets_utils.combinations_grid(
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
         split=("train", "valid", "test", "all"),
         target_type=("attr", "identity", "bbox", "landmarks", ["attr", "identity"]),
     )
@@ -740,7 +742,7 @@ class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.VOCSegmentation
     FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image)
 
-    CONFIGS = (
+    ADDITIONAL_CONFIGS = (
         *datasets_utils.combinations_grid(
             year=[f"20{year:02d}" for year in range(7, 13)], image_set=("train", "val", "trainval")
         ),
@@ -929,7 +931,7 @@ def test_captions(self):
 class UCF101TestCase(datasets_utils.VideoDatasetTestCase):
     DATASET_CLASS = datasets.UCF101
 
-    CONFIGS = datasets_utils.combinations_grid(fold=(1, 2, 3), train=(True, False))
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(fold=(1, 2, 3), train=(True, False))
 
     _VIDEO_FOLDER = "videos"
     _ANNOTATIONS_FOLDER = "annotations"
@@ -990,7 +992,7 @@ class LSUNTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.LSUN
 
     REQUIRED_PACKAGES = ("lmdb",)
-    CONFIGS = datasets_utils.combinations_grid(
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
         classes=("train", "test", "val", ["bedroom_train", "church_outdoor_train"])
     )
 
@@ -1097,7 +1099,7 @@ def test_not_found_or_corrupted(self):
 class HMDB51TestCase(datasets_utils.VideoDatasetTestCase):
     DATASET_CLASS = datasets.HMDB51
 
-    CONFIGS = datasets_utils.combinations_grid(fold=(1, 2, 3), train=(True, False))
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(fold=(1, 2, 3), train=(True, False))
 
     _VIDEO_FOLDER = "videos"
     _SPLITS_FOLDER = "splits"
@@ -1157,7 +1159,7 @@ def _create_split_files(self, root, video_files, fold, train):
 class OmniglotTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.Omniglot
 
-    CONFIGS = datasets_utils.combinations_grid(background=(True, False))
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(background=(True, False))
 
     def inject_fake_data(self, tmpdir, config):
         target_folder = (
@@ -1237,7 +1239,7 @@ def inject_fake_data(self, tmpdir, config):
 class USPSTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.USPS
 
-    CONFIGS = datasets_utils.combinations_grid(train=(True, False))
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False))
 
     def inject_fake_data(self, tmpdir, config):
         num_images = 2 if config["train"] else 1
@@ -1259,7 +1261,7 @@ class SBDatasetTestCase(datasets_utils.ImageDatasetTestCase):
 
     REQUIRED_PACKAGES = ("scipy.io", "scipy.sparse")
 
-    CONFIGS = datasets_utils.combinations_grid(
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
         image_set=("train", "val", "train_noval"), mode=("boundaries", "segmentation")
     )
 
@@ -1345,7 +1347,7 @@ class PhotoTourTestCase(datasets_utils.ImageDatasetTestCase):
     _TRAIN_FEATURE_TYPES = (torch.Tensor,)
     _TEST_FEATURE_TYPES = (torch.Tensor, torch.Tensor, torch.Tensor)
 
-    CONFIGS = datasets_utils.combinations_grid(train=(True, False))
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False))
 
     _NAME = "liberty"
 