Skip to content

Commit fd2bbe0

Browse files
committed
Merge branch 'master' into remove-old-cifar-tests
2 parents be8978b + 2e8c124 commit fd2bbe0

File tree

3 files changed

+158
-115
lines changed

3 files changed

+158
-115
lines changed

test/datasets_utils.py

Lines changed: 115 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import itertools
77
import os
88
import pathlib
9+
import random
10+
import string
911
import unittest
1012
import unittest.mock
1113
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union
@@ -32,10 +34,11 @@
3234
"create_image_folder",
3335
"create_video_file",
3436
"create_video_folder",
37+
"create_random_string",
3538
]
3639

3740

38-
class UsageError(RuntimeError):
41+
class UsageError(Exception):
3942
"""Should be raised in case an error happens in the setup rather than the test."""
4043

4144

@@ -93,14 +96,6 @@ def inner_wrapper(*args, **kwargs):
9396
return outer_wrapper
9497

9598

96-
# As of Python 3.7 this is provided by contextlib
97-
# https://docs.python.org/3.7/library/contextlib.html#contextlib.nullcontext
98-
# TODO: If the minimum Python requirement is >= 3.7, replace this
99-
@contextlib.contextmanager
100-
def nullcontext(enter_result=None):
101-
yield enter_result
102-
103-
10499
def test_all_configs(test):
105100
"""Decorator to run test against all configurations.
106101
@@ -116,7 +111,7 @@ def test_foo(self, config):
116111

117112
@functools.wraps(test)
118113
def wrapper(self):
119-
for config in self.CONFIGS:
114+
for config in self.CONFIGS or (self._DEFAULT_CONFIG,):
120115
with self.subTest(**config):
121116
test(self, config)
122117

@@ -165,7 +160,8 @@ class DatasetTestCase(unittest.TestCase):
165160
166161
Without further configuration, the testcase will test if
167162
168-
1. the dataset raises a ``RuntimeError`` if the data files are not found,
163+
1. the dataset raises a :class:`FileNotFoundError` or a :class:`RuntimeError` if the data files are not found or
164+
corrupted,
169165
2. the dataset inherits from `torchvision.datasets.VisionDataset`,
170166
3. the dataset can be turned into a string,
171167
4. the feature types of a returned example matches ``FEATURE_TYPES``,
@@ -206,6 +202,8 @@ def test_baz(self):
206202
CONFIGS = None
207203
REQUIRED_PACKAGES = None
208204

205+
_DEFAULT_CONFIG = None
206+
209207
_TRANSFORM_KWARGS = {
210208
"transform",
211209
"target_transform",
@@ -228,9 +226,25 @@ def test_baz(self):
228226
"download_and_extract_archive",
229227
}
230228

231-
def inject_fake_data(
232-
self, tmpdir: str, config: Dict[str, Any]
233-
) -> Union[int, Dict[str, Any], Tuple[Sequence[Any], Union[int, Dict[str, Any]]]]:
229+
def dataset_args(self, tmpdir: str, config: Dict[str, Any]) -> Sequence[Any]:
230+
"""Define positional arguments passed to the dataset.
231+
232+
.. note::
233+
234+
The default behavior is only valid if the dataset to be tested has ``root`` as the only required parameter.
235+
Otherwise you need to overwrite this method.
236+
237+
Args:
238+
tmpdir (str): Path to a temporary directory. For most cases this acts as root directory for the dataset
239+
to be created and in turn also for the fake data injected here.
240+
config (Dict[str, Any]): Configuration that will be used to create the dataset.
241+
242+
Returns:
243+
(Tuple[str]): ``tmpdir`` which corresponds to ``root`` for most datasets.
244+
"""
245+
return (tmpdir,)
246+
247+
def inject_fake_data(self, tmpdir: str, config: Dict[str, Any]) -> Union[int, Dict[str, Any]]:
234248
"""Inject fake data for dataset into a temporary directory.
235249
236250
Args:
@@ -240,15 +254,9 @@ def inject_fake_data(
240254
241255
Needs to return one of the following:
242256
243-
1. (int): Number of examples in the dataset to be created,
257+
1. (int): Number of examples in the dataset to be created, or
244258
2. (Dict[str, Any]): Additional information about the injected fake data. Must contain the field
245-
``"num_examples"`` that corresponds to the number of examples in the dataset to be created, or
246-
3. (Tuple[Sequence[Any], Union[int, Dict[str, Any]]]): Additional required parameters that are passed to
247-
the dataset constructor. The second element corresponds to cases 1. and 2.
248-
249-
If no ``args`` is returned (case 1. and 2.), the ``tmp_dir`` is passed as first parameter to the dataset
250-
constructor. In most cases this corresponds to ``root``. If the dataset has more parameters without default
251-
values you need to explicitly pass them as explained in case 3.
259+
``"num_examples"`` that corresponds to the number of examples in the dataset to be created.
252260
"""
253261
raise NotImplementedError("You need to provide fake data in order for the tests to run.")
254262

@@ -257,7 +265,7 @@ def create_dataset(
257265
self,
258266
config: Optional[Dict[str, Any]] = None,
259267
inject_fake_data: bool = True,
260-
disable_download_extract: Optional[bool] = None,
268+
patch_checks: Optional[bool] = None,
261269
**kwargs: Any,
262270
) -> Iterator[Tuple[torchvision.datasets.VisionDataset, Dict[str, Any]]]:
263271
r"""Create the dataset in a temporary directory.
@@ -267,8 +275,8 @@ def create_dataset(
267275
default configuration is used.
268276
inject_fake_data (bool): If ``True`` (default) inject the fake data with :meth:`.inject_fake_data` before
269277
creating the dataset.
270-
disable_download_extract (Optional[bool]): If ``True`` disable download and extract logic while creating
271-
the dataset. If ``None`` (default) this takes the same value as ``inject_fake_data``.
278+
patch_checks (Optional[bool]): If ``True`` disable integrity check logic while creating the dataset. If
279+
omitted defaults to the same value as ``inject_fake_data``.
272280
**kwargs (Any): Additional parameters passed to the dataset. These parameters take precedence in case they
273281
overlap with ``config``.
274282
@@ -277,46 +285,28 @@ def create_dataset(
277285
info (Dict[str, Any]): Additional information about the injected fake data. See :meth:`.inject_fake_data`
278286
for details.
279287
"""
280-
if config is None:
281-
config = self.CONFIGS[0].copy()
288+
default_config = self._DEFAULT_CONFIG.copy()
289+
if config is not None:
290+
default_config.update(config)
291+
config = default_config
292+
293+
if patch_checks is None:
294+
patch_checks = inject_fake_data
282295

283296
special_kwargs, other_kwargs = self._split_kwargs(kwargs)
297+
if "download" in self._HAS_SPECIAL_KWARG:
298+
special_kwargs["download"] = False
284299
config.update(other_kwargs)
285300

286-
if disable_download_extract is None:
287-
disable_download_extract = inject_fake_data
301+
patchers = self._patch_download_extract()
302+
if patch_checks:
303+
patchers.update(self._patch_checks())
288304

289305
with get_tmp_dir() as tmpdir:
290-
output = self.inject_fake_data(tmpdir, config) if inject_fake_data else None
291-
if output is None:
292-
raise UsageError(
293-
"The method 'inject_fake_data' needs to return at least an integer indicating the number of "
294-
"examples for the current configuration."
295-
)
306+
args = self.dataset_args(tmpdir, config)
307+
info = self._inject_fake_data(tmpdir, config) if inject_fake_data else None
296308

297-
if isinstance(output, collections.abc.Sequence) and len(output) == 2:
298-
args, info = output
299-
else:
300-
args = (tmpdir,)
301-
info = output
302-
303-
if isinstance(info, int):
304-
info = dict(num_examples=info)
305-
elif isinstance(info, dict):
306-
if "num_examples" not in info:
307-
raise UsageError(
308-
"The information dictionary returned by the method 'inject_fake_data' must contain a "
309-
"'num_examples' field that holds the number of examples for the current configuration."
310-
)
311-
else:
312-
raise UsageError(
313-
f"The additional information returned by the method 'inject_fake_data' must be either an integer "
314-
f"indicating the number of examples for the current configuration or a dictionary with the the "
315-
f"same content. Got {type(info)} instead."
316-
)
317-
318-
cm = self._disable_download_extract if disable_download_extract else nullcontext
319-
with cm(special_kwargs), disable_console_output():
309+
with self._maybe_apply_patches(patchers), disable_console_output():
320310
dataset = self.DATASET_CLASS(*args, **config, **special_kwargs)
321311

322312
yield dataset, info
@@ -344,19 +334,17 @@ def _verify_required_public_class_attributes(cls):
344334
@classmethod
def _populate_private_class_attributes(cls):
    """Derive ``_DEFAULT_CONFIG`` and ``_HAS_SPECIAL_KWARG`` from the signature of ``DATASET_CLASS.__init__``.

    ``_DEFAULT_CONFIG`` maps every positional-or-keyword parameter of the constructor that has a default value to
    that default, leaving out the names listed in ``_SPECIAL_KWARGS``. ``_HAS_SPECIAL_KWARG`` holds the names from
    ``_SPECIAL_KWARGS`` that the constructor actually accepts.
    """
    argspec = inspect.getfullargspec(cls.DATASET_CLASS.__init__)

    # ``argspec.defaults`` is ``None`` rather than an empty tuple if no parameter has a default value.
    defaults = argspec.defaults or ()
    # Defaults always belong to the trailing parameters. Computing the start index explicitly avoids the
    # ``args[-0:]`` pitfall, which would select every parameter instead of none.
    args_with_default = argspec.args[len(argspec.args) - len(defaults):]

    cls._DEFAULT_CONFIG = {
        kwarg: default
        for kwarg, default in zip(args_with_default, defaults)
        if kwarg not in cls._SPECIAL_KWARGS
    }

    cls._HAS_SPECIAL_KWARG = {name for name in cls._SPECIAL_KWARGS if name in argspec.args}
348345

349346
@classmethod
350347
def _process_optional_public_class_attributes(cls):
351-
argspec = inspect.getfullargspec(cls.DATASET_CLASS.__init__)
352-
if cls.CONFIGS is None:
353-
config = {
354-
kwarg: default
355-
for kwarg, default in zip(argspec.args[-len(argspec.defaults):], argspec.defaults)
356-
if kwarg not in cls._SPECIAL_KWARGS
357-
}
358-
cls.CONFIGS = (config,)
359-
360348
if cls.REQUIRED_PACKAGES is not None:
361349
try:
362350
for pkg in cls.REQUIRED_PACKAGES:
@@ -372,31 +360,47 @@ def _split_kwargs(self, kwargs):
372360
other_kwargs = {key: special_kwargs.pop(key) for key in set(special_kwargs.keys()) - self._SPECIAL_KWARGS}
373361
return special_kwargs, other_kwargs
374362

375-
@contextlib.contextmanager
376-
def _disable_download_extract(self, special_kwargs):
377-
inject_download_kwarg = "download" in self._HAS_SPECIAL_KWARG and "download" not in special_kwargs
378-
if inject_download_kwarg:
379-
special_kwargs["download"] = False
363+
def _inject_fake_data(self, tmpdir, config):
364+
info = self.inject_fake_data(tmpdir, config)
365+
if info is None:
366+
raise UsageError(
367+
"The method 'inject_fake_data' needs to return at least an integer indicating the number of "
368+
"examples for the current configuration."
369+
)
370+
elif isinstance(info, int):
371+
info = dict(num_examples=info)
372+
elif not isinstance(info, dict):
373+
raise UsageError(
374+
f"The additional information returned by the method 'inject_fake_data' must be either an "
375+
f"integer indicating the number of examples for the current configuration or a dictionary with "
376+
f"the same content. Got {type(info)} instead."
377+
)
378+
elif "num_examples" not in info:
379+
raise UsageError(
380+
"The information dictionary returned by the method 'inject_fake_data' must contain a "
381+
"'num_examples' field that holds the number of examples for the current configuration."
382+
)
383+
return info
384+
385+
def _patch_download_extract(self):
386+
module = inspect.getmodule(self.DATASET_CLASS).__name__
387+
return {unittest.mock.patch(f"{module}.{function}") for function in self._DOWNLOAD_EXTRACT_FUNCTIONS}
380388

389+
def _patch_checks(self):
381390
module = inspect.getmodule(self.DATASET_CLASS).__name__
391+
return {unittest.mock.patch(f"{module}.{function}", return_value=True) for function in self._CHECK_FUNCTIONS}
392+
393+
@contextlib.contextmanager
394+
def _maybe_apply_patches(self, patchers):
382395
with contextlib.ExitStack() as stack:
383396
mocks = {}
384-
for function, kwargs in itertools.chain(
385-
zip(self._CHECK_FUNCTIONS, [dict(return_value=True)] * len(self._CHECK_FUNCTIONS)),
386-
zip(self._DOWNLOAD_EXTRACT_FUNCTIONS, [dict()] * len(self._DOWNLOAD_EXTRACT_FUNCTIONS)),
387-
):
397+
for patcher in patchers:
388398
with contextlib.suppress(AttributeError):
389-
patcher = unittest.mock.patch(f"{module}.{function}", **kwargs)
390-
mocks[function] = stack.enter_context(patcher)
391-
392-
try:
393-
yield mocks
394-
finally:
395-
if inject_download_kwarg:
396-
del special_kwargs["download"]
399+
mocks[patcher.target] = stack.enter_context(patcher)
400+
yield mocks
397401

398-
def test_not_found(self):
399-
with self.assertRaises(RuntimeError):
402+
def test_not_found_or_corrupted(self):
403+
with self.assertRaises((FileNotFoundError, RuntimeError)):
400404
with self.create_dataset(inject_fake_data=False):
401405
pass
402406

@@ -461,13 +465,13 @@ def create_dataset(
461465
self,
462466
config: Optional[Dict[str, Any]] = None,
463467
inject_fake_data: bool = True,
464-
disable_download_extract: Optional[bool] = None,
468+
patch_checks: Optional[bool] = None,
465469
**kwargs: Any,
466470
) -> Iterator[Tuple[torchvision.datasets.VisionDataset, Dict[str, Any]]]:
467471
with super().create_dataset(
468472
config=config,
469473
inject_fake_data=inject_fake_data,
470-
disable_download_extract=disable_download_extract,
474+
patch_checks=patch_checks,
471475
**kwargs,
472476
) as (dataset, info):
473477
# PIL.Image.open() only loads the image meta data upfront and keeps the file open until the first access
@@ -511,26 +515,20 @@ class VideoDatasetTestCase(DatasetTestCase):
511515

512516
def __init__(self, *args, **kwargs):
513517
super().__init__(*args, **kwargs)
514-
self.inject_fake_data = self._set_default_frames_per_clip(self.inject_fake_data)
518+
self.dataset_args = self._set_default_frames_per_clip(self.dataset_args)
515519

516520
def _set_default_frames_per_clip(self, inject_fake_data):
517521
argspec = inspect.getfullargspec(self.DATASET_CLASS.__init__)
518522
args_without_default = argspec.args[1:-len(argspec.defaults)]
519523
frames_per_clip_last = args_without_default[-1] == "frames_per_clip"
520-
only_root_and_frames_per_clip = (len(args_without_default) == 2) and frames_per_clip_last
521524

522525
@functools.wraps(inject_fake_data)
523526
def wrapper(tmpdir, config):
524-
output = inject_fake_data(tmpdir, config)
525-
if isinstance(output, collections.abc.Sequence) and len(output) == 2:
526-
args, info = output
527-
if frames_per_clip_last and len(args) == len(args_without_default) - 1:
528-
args = (*args, self.DEFAULT_FRAMES_PER_CLIP)
529-
return args, info
530-
elif isinstance(output, (int, dict)) and only_root_and_frames_per_clip:
531-
return (tmpdir, self.DEFAULT_FRAMES_PER_CLIP)
532-
else:
533-
return output
527+
args = inject_fake_data(tmpdir, config)
528+
if frames_per_clip_last and len(args) == len(args_without_default) - 1:
529+
args = (*args, self.DEFAULT_FRAMES_PER_CLIP)
530+
531+
return args
534532

535533
return wrapper
536534

@@ -570,7 +568,7 @@ def create_image_file(
570568

571569
image = create_image_or_video_tensor(size)
572570
file = pathlib.Path(root) / name
573-
PIL.Image.fromarray(image.permute(2, 1, 0).numpy()).save(file)
571+
PIL.Image.fromarray(image.permute(2, 1, 0).numpy()).save(file, **kwargs)
574572
return file
575573

576574

@@ -706,6 +704,21 @@ def size(idx):
706704
os.makedirs(root)
707705

708706
return [
709-
create_video_file(root, file_name_fn(idx), size=size(idx) if callable(size) else size)
707+
create_video_file(root, file_name_fn(idx), size=size(idx) if callable(size) else size, **kwargs)
710708
for idx in range(num_examples)
711709
]
710+
711+
712+
def create_random_string(length: int, *digits: str) -> str:
    """Create a random string.

    Args:
        length (int): Number of characters in the generated string.
        *digits (str): Characters to sample from. Each argument may hold several characters; all of them are
            merged into a single pool. If omitted defaults to :attr:`string.ascii_lowercase`.

    Returns:
        (str): String of ``length`` characters, each drawn independently from the pool with :func:`random.choice`.
            Empty if ``length`` is not positive.
    """
    # Flatten all given character groups into one pool; fall back to the lowercase ASCII letters otherwise.
    pool = "".join(itertools.chain.from_iterable(digits)) if digits else string.ascii_lowercase
    return "".join(random.choice(pool) for _ in range(length))

0 commit comments

Comments
 (0)