66import itertools
77import os
88import pathlib
9+ import random
10+ import string
911import unittest
1012import unittest .mock
1113from typing import Any , Callable , Dict , Iterator , List , Optional , Sequence , Tuple , Union
3234 "create_image_folder" ,
3335 "create_video_file" ,
3436 "create_video_folder" ,
37+ "create_random_string" ,
3538]
3639
3740
38- class UsageError (RuntimeError ):
41+ class UsageError (Exception ):
3942 """Should be raised in case an error happens in the setup rather than the test."""
4043
4144
@@ -93,14 +96,6 @@ def inner_wrapper(*args, **kwargs):
9396 return outer_wrapper
9497
9598
96- # As of Python 3.7 this is provided by contextlib
97- # https://docs.python.org/3.7/library/contextlib.html#contextlib.nullcontext
98- # TODO: If the minimum Python requirement is >= 3.7, replace this
99- @contextlib .contextmanager
100- def nullcontext (enter_result = None ):
101- yield enter_result
102-
103-
10499def test_all_configs (test ):
105100 """Decorator to run test against all configurations.
106101
@@ -116,7 +111,7 @@ def test_foo(self, config):
116111
117112 @functools .wraps (test )
118113 def wrapper (self ):
119- for config in self .CONFIGS :
114+ for config in self .CONFIGS or ( self . _DEFAULT_CONFIG ,) :
120115 with self .subTest (** config ):
121116 test (self , config )
122117
@@ -165,7 +160,8 @@ class DatasetTestCase(unittest.TestCase):
165160
166161 Without further configuration, the testcase will test if
167162
168- 1. the dataset raises a ``RuntimeError`` if the data files are not found,
163+ 1. the dataset raises a :class:`FileNotFoundError` or a :class:`RuntimeError` if the data files are not found or
164+ corrupted,
169165 2. the dataset inherits from `torchvision.datasets.VisionDataset`,
170166 3. the dataset can be turned into a string,
171167 4. the feature types of a returned example matches ``FEATURE_TYPES``,
@@ -206,6 +202,8 @@ def test_baz(self):
206202 CONFIGS = None
207203 REQUIRED_PACKAGES = None
208204
205+ _DEFAULT_CONFIG = None
206+
209207 _TRANSFORM_KWARGS = {
210208 "transform" ,
211209 "target_transform" ,
@@ -228,9 +226,25 @@ def test_baz(self):
228226 "download_and_extract_archive" ,
229227 }
230228
231- def inject_fake_data (
232- self , tmpdir : str , config : Dict [str , Any ]
233- ) -> Union [int , Dict [str , Any ], Tuple [Sequence [Any ], Union [int , Dict [str , Any ]]]]:
def dataset_args(self, tmpdir: str, config: Dict[str, Any]) -> Sequence[Any]:
    """Define positional arguments passed to the dataset.

    .. note::

        The default behavior is only valid if the dataset to be tested has ``root`` as the only required parameter.
        Otherwise you need to override this method.

    Args:
        tmpdir (str): Path to a temporary directory. For most cases this acts as root directory for the dataset
            to be created and in turn also for the fake data injected here.
        config (Dict[str, Any]): Configuration that will be used to create the dataset. Unused by the default
            implementation.

    Returns:
        (Tuple[str]): ``tmpdir`` which corresponds to ``root`` for most datasets.
    """
    return (tmpdir,)
246+
247+ def inject_fake_data (self , tmpdir : str , config : Dict [str , Any ]) -> Union [int , Dict [str , Any ]]:
234248 """Inject fake data for dataset into a temporary directory.
235249
236250 Args:
@@ -240,15 +254,9 @@ def inject_fake_data(
240254
241255 Needs to return one of the following:
242256
243- 1. (int): Number of examples in the dataset to be created,
257+ 1. (int): Number of examples in the dataset to be created, or
244258 2. (Dict[str, Any]): Additional information about the injected fake data. Must contain the field
245- ``"num_examples"`` that corresponds to the number of examples in the dataset to be created, or
246- 3. (Tuple[Sequence[Any], Union[int, Dict[str, Any]]]): Additional required parameters that are passed to
247- the dataset constructor. The second element corresponds to cases 1. and 2.
248-
249- If no ``args`` is returned (case 1. and 2.), the ``tmp_dir`` is passed as first parameter to the dataset
250- constructor. In most cases this corresponds to ``root``. If the dataset has more parameters without default
251- values you need to explicitly pass them as explained in case 3.
259+ ``"num_examples"`` that corresponds to the number of examples in the dataset to be created.
252260 """
253261 raise NotImplementedError ("You need to provide fake data in order for the tests to run." )
254262
@@ -257,7 +265,7 @@ def create_dataset(
257265 self ,
258266 config : Optional [Dict [str , Any ]] = None ,
259267 inject_fake_data : bool = True ,
260- disable_download_extract : Optional [bool ] = None ,
268+ patch_checks : Optional [bool ] = None ,
261269 ** kwargs : Any ,
262270 ) -> Iterator [Tuple [torchvision .datasets .VisionDataset , Dict [str , Any ]]]:
263271 r"""Create the dataset in a temporary directory.
@@ -267,8 +275,8 @@ def create_dataset(
267275 default configuration is used.
268276 inject_fake_data (bool): If ``True`` (default) inject the fake data with :meth:`.inject_fake_data` before
269277 creating the dataset.
270- disable_download_extract (Optional[bool]): If ``True`` disable download and extract logic while creating
271- the dataset. If ``None`` (default) this takes the same value as ``inject_fake_data``.
278+ patch_checks (Optional[bool]): If ``True`` disable integrity check logic while creating the dataset. If
279+ omitted defaults to the same value as ``inject_fake_data``.
272280 **kwargs (Any): Additional parameters passed to the dataset. These parameters take precedence in case they
273281 overlap with ``config``.
274282
@@ -277,46 +285,28 @@ def create_dataset(
277285 info (Dict[str, Any]): Additional information about the injected fake data. See :meth:`.inject_fake_data`
278286 for details.
279287 """
280- if config is None :
281- config = self .CONFIGS [0 ].copy ()
288+ default_config = self ._DEFAULT_CONFIG .copy ()
289+ if config is not None :
290+ default_config .update (config )
291+ config = default_config
292+
293+ if patch_checks is None :
294+ patch_checks = inject_fake_data
282295
283296 special_kwargs , other_kwargs = self ._split_kwargs (kwargs )
297+ if "download" in self ._HAS_SPECIAL_KWARG :
298+ special_kwargs ["download" ] = False
284299 config .update (other_kwargs )
285300
286- if disable_download_extract is None :
287- disable_download_extract = inject_fake_data
301+ patchers = self ._patch_download_extract ()
302+ if patch_checks :
303+ patchers .update (self ._patch_checks ())
288304
289305 with get_tmp_dir () as tmpdir :
290- output = self .inject_fake_data (tmpdir , config ) if inject_fake_data else None
291- if output is None :
292- raise UsageError (
293- "The method 'inject_fake_data' needs to return at least an integer indicating the number of "
294- "examples for the current configuration."
295- )
306+ args = self .dataset_args (tmpdir , config )
307+ info = self ._inject_fake_data (tmpdir , config ) if inject_fake_data else None
296308
297- if isinstance (output , collections .abc .Sequence ) and len (output ) == 2 :
298- args , info = output
299- else :
300- args = (tmpdir ,)
301- info = output
302-
303- if isinstance (info , int ):
304- info = dict (num_examples = info )
305- elif isinstance (info , dict ):
306- if "num_examples" not in info :
307- raise UsageError (
308- "The information dictionary returned by the method 'inject_fake_data' must contain a "
309- "'num_examples' field that holds the number of examples for the current configuration."
310- )
311- else :
312- raise UsageError (
313- f"The additional information returned by the method 'inject_fake_data' must be either an integer "
314- f"indicating the number of examples for the current configuration or a dictionary with the the "
315- f"same content. Got { type (info )} instead."
316- )
317-
318- cm = self ._disable_download_extract if disable_download_extract else nullcontext
319- with cm (special_kwargs ), disable_console_output ():
309+ with self ._maybe_apply_patches (patchers ), disable_console_output ():
320310 dataset = self .DATASET_CLASS (* args , ** config , ** special_kwargs )
321311
322312 yield dataset , info
@@ -344,19 +334,17 @@ def _verify_required_public_class_attributes(cls):
@classmethod
def _populate_private_class_attributes(cls):
    """Derive ``_DEFAULT_CONFIG`` and ``_HAS_SPECIAL_KWARG`` from the signature of ``DATASET_CLASS.__init__``.

    ``_DEFAULT_CONFIG`` maps every constructor parameter that has a default value and is not listed in
    ``_SPECIAL_KWARGS`` to that default. ``_HAS_SPECIAL_KWARG`` is the subset of ``_SPECIAL_KWARGS`` that the
    constructor actually accepts.
    """
    argspec = inspect.getfullargspec(cls.DATASET_CLASS.__init__)

    # ``argspec.defaults`` is ``None`` rather than an empty tuple if no parameter has a default value.
    defaults = argspec.defaults or ()
    # Defaults always belong to the trailing parameters. Slicing with an explicit start index avoids the
    # ``args[-0:]`` pitfall, which would select *all* parameters if there are no defaults.
    kwargs_with_default = argspec.args[len(argspec.args) - len(defaults):]

    cls._DEFAULT_CONFIG = {
        kwarg: default
        for kwarg, default in zip(kwargs_with_default, defaults)
        if kwarg not in cls._SPECIAL_KWARGS
    }

    cls._HAS_SPECIAL_KWARG = {name for name in cls._SPECIAL_KWARGS if name in argspec.args}
348345
349346 @classmethod
350347 def _process_optional_public_class_attributes (cls ):
351- argspec = inspect .getfullargspec (cls .DATASET_CLASS .__init__ )
352- if cls .CONFIGS is None :
353- config = {
354- kwarg : default
355- for kwarg , default in zip (argspec .args [- len (argspec .defaults ):], argspec .defaults )
356- if kwarg not in cls ._SPECIAL_KWARGS
357- }
358- cls .CONFIGS = (config ,)
359-
360348 if cls .REQUIRED_PACKAGES is not None :
361349 try :
362350 for pkg in cls .REQUIRED_PACKAGES :
@@ -372,31 +360,47 @@ def _split_kwargs(self, kwargs):
372360 other_kwargs = {key : special_kwargs .pop (key ) for key in set (special_kwargs .keys ()) - self ._SPECIAL_KWARGS }
373361 return special_kwargs , other_kwargs
374362
375- @contextlib .contextmanager
376- def _disable_download_extract (self , special_kwargs ):
377- inject_download_kwarg = "download" in self ._HAS_SPECIAL_KWARG and "download" not in special_kwargs
378- if inject_download_kwarg :
379- special_kwargs ["download" ] = False
def _inject_fake_data(self, tmpdir, config):
    """Call :meth:`inject_fake_data` and normalize its return value to an info dictionary.

    Args:
        tmpdir (str): Forwarded unchanged to :meth:`inject_fake_data`.
        config (Dict[str, Any]): Forwarded unchanged to :meth:`inject_fake_data`.

    Returns:
        (Dict[str, Any]): If an integer was returned it is wrapped as ``{"num_examples": <int>}``. A dictionary
        is passed through unchanged after checking that it contains the ``"num_examples"`` field.

    Raises:
        UsageError: If :meth:`inject_fake_data` returned ``None``, something that is neither an integer nor a
            dictionary, or a dictionary without the ``"num_examples"`` field.
    """
    info = self.inject_fake_data(tmpdir, config)

    if info is None:
        raise UsageError(
            "The method 'inject_fake_data' needs to return at least an integer indicating the number of "
            "examples for the current configuration."
        )

    if isinstance(info, int):
        return dict(num_examples=info)

    if not isinstance(info, dict):
        raise UsageError(
            "The additional information returned by the method 'inject_fake_data' must be either an "
            "integer indicating the number of examples for the current configuration or a dictionary with "
            f"the same content. Got {type(info)} instead."
        )

    if "num_examples" not in info:
        raise UsageError(
            "The information dictionary returned by the method 'inject_fake_data' must contain a "
            "'num_examples' field that holds the number of examples for the current configuration."
        )

    return info
384+
def _patch_download_extract(self):
    """Create patchers for the download and extract functions used by the dataset.

    One :func:`unittest.mock.patch` object is created for every name in ``_DOWNLOAD_EXTRACT_FUNCTIONS``, each
    targeting that name inside the module that defines ``DATASET_CLASS``. The patchers are only created here,
    not started.

    Returns:
        (Set): The created patchers.
    """
    module = inspect.getmodule(self.DATASET_CLASS).__name__
    return {unittest.mock.patch(f"{module}.{function}") for function in self._DOWNLOAD_EXTRACT_FUNCTIONS}
380388
def _patch_checks(self):
    """Create patchers that make the integrity check functions used by the dataset always return ``True``.

    One :func:`unittest.mock.patch` object with ``return_value=True`` is created for every name in
    ``_CHECK_FUNCTIONS``, each targeting that name inside the module that defines ``DATASET_CLASS``. The
    patchers are only created here, not started.

    Returns:
        (Set): The created patchers.
    """
    module = inspect.getmodule(self.DATASET_CLASS).__name__
    return {unittest.mock.patch(f"{module}.{function}", return_value=True) for function in self._CHECK_FUNCTIONS}
392+
@contextlib.contextmanager
def _maybe_apply_patches(self, patchers):
    """Start every patcher that can be applied and undo all of them on exit.

    Args:
        patchers (Iterable): Patchers as created by :func:`unittest.mock.patch`.

    Yields:
        (Dict[str, Any]): Maps the name of each patched attribute to the object returned by entering the
        corresponding patcher. Patchers that could not be applied are not included.
    """
    with contextlib.ExitStack() as stack:
        mocks = {}
        for patcher in patchers:
            # Not every dataset module contains every function we try to patch. Entering the patcher raises
            # an AttributeError in that case and the patcher is simply skipped.
            try:
                mock = stack.enter_context(patcher)
            except AttributeError:
                continue

            # Key by the attribute name rather than by ``patcher.target``: the latter is the module object the
            # attribute lives in, which is identical for all patchers and would thus collapse the dictionary
            # to a single entry.
            mocks[patcher.attribute] = mock

        yield mocks
397401
398- def test_not_found (self ):
399- with self .assertRaises (RuntimeError ):
402+ def test_not_found_or_corrupted (self ):
403+ with self .assertRaises (( FileNotFoundError , RuntimeError ) ):
400404 with self .create_dataset (inject_fake_data = False ):
401405 pass
402406
@@ -461,13 +465,13 @@ def create_dataset(
461465 self ,
462466 config : Optional [Dict [str , Any ]] = None ,
463467 inject_fake_data : bool = True ,
464- disable_download_extract : Optional [bool ] = None ,
468+ patch_checks : Optional [bool ] = None ,
465469 ** kwargs : Any ,
466470 ) -> Iterator [Tuple [torchvision .datasets .VisionDataset , Dict [str , Any ]]]:
467471 with super ().create_dataset (
468472 config = config ,
469473 inject_fake_data = inject_fake_data ,
470- disable_download_extract = disable_download_extract ,
474+ patch_checks = patch_checks ,
471475 ** kwargs ,
472476 ) as (dataset , info ):
473477 # PIL.Image.open() only loads the image meta data upfront and keeps the file open until the first access
@@ -511,26 +515,20 @@ class VideoDatasetTestCase(DatasetTestCase):
511515
512516 def __init__ (self , * args , ** kwargs ):
513517 super ().__init__ (* args , ** kwargs )
514- self .inject_fake_data = self ._set_default_frames_per_clip (self .inject_fake_data )
518+ self .dataset_args = self ._set_default_frames_per_clip (self .dataset_args )
515519
def _set_default_frames_per_clip(self, inject_fake_data):
    """Wrap a callable returning dataset arguments so that ``frames_per_clip`` is appended if it is missing.

    Args:
        inject_fake_data (Callable): Callable taking ``(tmpdir, config)`` and returning the positional
            arguments for the dataset constructor. Despite its name, the visible caller passes
            :meth:`dataset_args` here.

    Returns:
        (Callable): Wrapper with the same call signature. If ``frames_per_clip`` is the last parameter of
        ``DATASET_CLASS.__init__`` without a default value and the wrapped callable returned exactly one
        argument less than there are such parameters, ``DEFAULT_FRAMES_PER_CLIP`` is appended. Otherwise the
        arguments are returned unchanged.
    """
    argspec = inspect.getfullargspec(self.DATASET_CLASS.__init__)

    # ``argspec.defaults`` is ``None`` if no parameter has a default value. An explicit stop index is used
    # since ``args[1:-0]`` would be empty instead of selecting everything after ``self``.
    num_defaults = len(argspec.defaults or ())
    args_without_default = argspec.args[1 : len(argspec.args) - num_defaults]

    # Guard against constructors for which every parameter has a default value.
    frames_per_clip_last = bool(args_without_default) and args_without_default[-1] == "frames_per_clip"

    @functools.wraps(inject_fake_data)
    def wrapper(tmpdir, config):
        args = inject_fake_data(tmpdir, config)
        if frames_per_clip_last and len(args) == len(args_without_default) - 1:
            args = (*args, self.DEFAULT_FRAMES_PER_CLIP)

        return args

    return wrapper
536534
@@ -570,7 +568,7 @@ def create_image_file(
570568
571569 image = create_image_or_video_tensor (size )
572570 file = pathlib .Path (root ) / name
573- PIL .Image .fromarray (image .permute (2 , 1 , 0 ).numpy ()).save (file )
571+ PIL .Image .fromarray (image .permute (2 , 1 , 0 ).numpy ()).save (file , ** kwargs )
574572 return file
575573
576574
@@ -706,6 +704,21 @@ def size(idx):
706704 os .makedirs (root )
707705
708706 return [
709- create_video_file (root , file_name_fn (idx ), size = size (idx ) if callable (size ) else size )
707+ create_video_file (root , file_name_fn (idx ), size = size (idx ) if callable (size ) else size , ** kwargs )
710708 for idx in range (num_examples )
711709 ]
710+
711+
def create_random_string(length: int, *digits: str) -> str:
    """Create a random string.

    Args:
        length (int): Number of characters in the generated string.
        *digits (str): Characters to sample from. Multiple arguments are concatenated into a single pool of
            characters. If omitted defaults to :attr:`string.ascii_lowercase`.

    Returns:
        (str): String of ``length`` characters, each drawn independently with :func:`random.choice` from the
        pool of characters.
    """
    if not digits:
        digits = string.ascii_lowercase
    else:
        digits = "".join(itertools.chain(*digits))

    return "".join(random.choice(digits) for _ in range(length))
0 commit comments