66import itertools
77import os
88import pathlib
9+ import random
10+ import string
911import unittest
1012import unittest .mock
1113from typing import Any , Callable , Dict , Iterator , List , Optional , Sequence , Tuple , Union
3234 "create_image_folder" ,
3335 "create_video_file" ,
3436 "create_video_folder" ,
37+ "create_random_string" ,
3538]
3639
3740
@@ -93,14 +96,6 @@ def inner_wrapper(*args, **kwargs):
9396 return outer_wrapper
9497
9598
96- # As of Python 3.7 this is provided by contextlib
97- # https://docs.python.org/3.7/library/contextlib.html#contextlib.nullcontext
98- # TODO: If the minimum Python requirement is >= 3.7, replace this
99- @contextlib .contextmanager
100- def nullcontext (enter_result = None ):
101- yield enter_result
102-
103-
10499def test_all_configs (test ):
105100 """Decorator to run test against all configurations.
106101
@@ -116,7 +111,7 @@ def test_foo(self, config):
116111
117112 @functools .wraps (test )
118113 def wrapper (self ):
119- for config in self .CONFIGS :
114+ for config in self .CONFIGS or ( self . _DEFAULT_CONFIG ,) :
120115 with self .subTest (** config ):
121116 test (self , config )
122117
@@ -207,6 +202,8 @@ def test_baz(self):
207202 CONFIGS = None
208203 REQUIRED_PACKAGES = None
209204
205+ _DEFAULT_CONFIG = None
206+
210207 _TRANSFORM_KWARGS = {
211208 "transform" ,
212209 "target_transform" ,
@@ -268,7 +265,7 @@ def create_dataset(
268265 self ,
269266 config : Optional [Dict [str , Any ]] = None ,
270267 inject_fake_data : bool = True ,
271- disable_download_extract : Optional [bool ] = None ,
268+ patch_checks : Optional [bool ] = None ,
272269 ** kwargs : Any ,
273270 ) -> Iterator [Tuple [torchvision .datasets .VisionDataset , Dict [str , Any ]]]:
274271 r"""Create the dataset in a temporary directory.
@@ -278,8 +275,8 @@ def create_dataset(
278275 default configuration is used.
279276 inject_fake_data (bool): If ``True`` (default) inject the fake data with :meth:`.inject_fake_data` before
280277 creating the dataset.
281- disable_download_extract (Optional[bool]): If ``True`` disable download and extract logic while creating
282- the dataset. If ``None`` (default) this takes the same value as ``inject_fake_data``.
278+ patch_checks (Optional[bool]): If ``True`` disable integrity check logic while creating the dataset. If
279+ omitted defaults to the same value as ``inject_fake_data``.
283280 **kwargs (Any): Additional parameters passed to the dataset. These parameters take precedence in case they
284281 overlap with ``config``.
285282
@@ -288,43 +285,28 @@ def create_dataset(
288285 info (Dict[str, Any]): Additional information about the injected fake data. See :meth:`.inject_fake_data`
289286 for details.
290287 """
291- if config is None :
292- config = self .CONFIGS [0 ].copy ()
288+ default_config = self ._DEFAULT_CONFIG .copy ()
289+ if config is not None :
290+ default_config .update (config )
291+ config = default_config
292+
293+ if patch_checks is None :
294+ patch_checks = inject_fake_data
293295
294296 special_kwargs , other_kwargs = self ._split_kwargs (kwargs )
297+ if "download" in self ._HAS_SPECIAL_KWARG :
298+ special_kwargs ["download" ] = False
295299 config .update (other_kwargs )
296300
297- if disable_download_extract is None :
298- disable_download_extract = inject_fake_data
301+ patchers = self ._patch_download_extract ()
302+ if patch_checks :
303+ patchers .update (self ._patch_checks ())
299304
300305 with get_tmp_dir () as tmpdir :
301306 args = self .dataset_args (tmpdir , config )
307+ info = self ._inject_fake_data (tmpdir , config ) if inject_fake_data else None
302308
303- if inject_fake_data :
304- info = self .inject_fake_data (tmpdir , config )
305- if info is None :
306- raise UsageError (
307- "The method 'inject_fake_data' needs to return at least an integer indicating the number of "
308- "examples for the current configuration."
309- )
310- elif isinstance (info , int ):
311- info = dict (num_examples = info )
312- elif not isinstance (info , dict ):
313- raise UsageError (
314- f"The additional information returned by the method 'inject_fake_data' must be either an "
315- f"integer indicating the number of examples for the current configuration or a dictionary with "
316- f"the same content. Got { type (info )} instead."
317- )
318- elif "num_examples" not in info :
319- raise UsageError (
320- "The information dictionary returned by the method 'inject_fake_data' must contain a "
321- "'num_examples' field that holds the number of examples for the current configuration."
322- )
323- else :
324- info = None
325-
326- cm = self ._disable_download_extract if disable_download_extract else nullcontext
327- with cm (special_kwargs ), disable_console_output ():
309+ with self ._maybe_apply_patches (patchers ), disable_console_output ():
328310 dataset = self .DATASET_CLASS (* args , ** config , ** special_kwargs )
329311
330312 yield dataset , info
@@ -352,19 +334,17 @@ def _verify_required_public_class_attributes(cls):
352334 @classmethod
353335 def _populate_private_class_attributes (cls ):
354336 argspec = inspect .getfullargspec (cls .DATASET_CLASS .__init__ )
337+
338+ cls ._DEFAULT_CONFIG = {
339+ kwarg : default
340+ for kwarg , default in zip (argspec .args [- len (argspec .defaults ):], argspec .defaults )
341+ if kwarg not in cls ._SPECIAL_KWARGS
342+ }
343+
355344 cls ._HAS_SPECIAL_KWARG = {name for name in cls ._SPECIAL_KWARGS if name in argspec .args }
356345
357346 @classmethod
358347 def _process_optional_public_class_attributes (cls ):
359- argspec = inspect .getfullargspec (cls .DATASET_CLASS .__init__ )
360- if cls .CONFIGS is None :
361- config = {
362- kwarg : default
363- for kwarg , default in zip (argspec .args [- len (argspec .defaults ):], argspec .defaults )
364- if kwarg not in cls ._SPECIAL_KWARGS
365- }
366- cls .CONFIGS = (config ,)
367-
368348 if cls .REQUIRED_PACKAGES is not None :
369349 try :
370350 for pkg in cls .REQUIRED_PACKAGES :
@@ -380,28 +360,44 @@ def _split_kwargs(self, kwargs):
380360 other_kwargs = {key : special_kwargs .pop (key ) for key in set (special_kwargs .keys ()) - self ._SPECIAL_KWARGS }
381361 return special_kwargs , other_kwargs
382362
383- @contextlib .contextmanager
384- def _disable_download_extract (self , special_kwargs ):
385- inject_download_kwarg = "download" in self ._HAS_SPECIAL_KWARG and "download" not in special_kwargs
386- if inject_download_kwarg :
387- special_kwargs ["download" ] = False
363+ def _inject_fake_data (self , tmpdir , config ):
364+ info = self .inject_fake_data (tmpdir , config )
365+ if info is None :
366+ raise UsageError (
367+ "The method 'inject_fake_data' needs to return at least an integer indicating the number of "
368+ "examples for the current configuration."
369+ )
370+ elif isinstance (info , int ):
371+ info = dict (num_examples = info )
372+ elif not isinstance (info , dict ):
373+ raise UsageError (
374+ f"The additional information returned by the method 'inject_fake_data' must be either an "
375+ f"integer indicating the number of examples for the current configuration or a dictionary with "
376+ f"the same content. Got { type (info )} instead."
377+ )
378+ elif "num_examples" not in info :
379+ raise UsageError (
380+ "The information dictionary returned by the method 'inject_fake_data' must contain a "
381+ "'num_examples' field that holds the number of examples for the current configuration."
382+ )
383+ return info
384+
385+ def _patch_download_extract (self ):
386+ module = inspect .getmodule (self .DATASET_CLASS ).__name__
387+ return {unittest .mock .patch (f"{ module } .{ function } " ) for function in self ._DOWNLOAD_EXTRACT_FUNCTIONS }
388388
389+ def _patch_checks (self ):
389390 module = inspect .getmodule (self .DATASET_CLASS ).__name__
391+ return {unittest .mock .patch (f"{ module } .{ function } " , return_value = True ) for function in self ._CHECK_FUNCTIONS }
392+
393+ @contextlib .contextmanager
394+ def _maybe_apply_patches (self , patchers ):
390395 with contextlib .ExitStack () as stack :
391396 mocks = {}
392- for function , kwargs in itertools .chain (
393- zip (self ._CHECK_FUNCTIONS , [dict (return_value = True )] * len (self ._CHECK_FUNCTIONS )),
394- zip (self ._DOWNLOAD_EXTRACT_FUNCTIONS , [dict ()] * len (self ._DOWNLOAD_EXTRACT_FUNCTIONS )),
395- ):
397+ for patcher in patchers :
396398 with contextlib .suppress (AttributeError ):
397- patcher = unittest .mock .patch (f"{ module } .{ function } " , ** kwargs )
398- mocks [function ] = stack .enter_context (patcher )
399-
400- try :
401- yield mocks
402- finally :
403- if inject_download_kwarg :
404- del special_kwargs ["download" ]
399+ mocks [patcher .target ] = stack .enter_context (patcher )
400+ yield mocks
405401
406402 def test_not_found_or_corrupted (self ):
407403 with self .assertRaises ((FileNotFoundError , RuntimeError )):
@@ -469,13 +465,13 @@ def create_dataset(
469465 self ,
470466 config : Optional [Dict [str , Any ]] = None ,
471467 inject_fake_data : bool = True ,
472- disable_download_extract : Optional [bool ] = None ,
468+ patch_checks : Optional [bool ] = None ,
473469 ** kwargs : Any ,
474470 ) -> Iterator [Tuple [torchvision .datasets .VisionDataset , Dict [str , Any ]]]:
475471 with super ().create_dataset (
476472 config = config ,
477473 inject_fake_data = inject_fake_data ,
478- disable_download_extract = disable_download_extract ,
474+ patch_checks = patch_checks ,
479475 ** kwargs ,
480476 ) as (dataset , info ):
481477 # PIL.Image.open() only loads the image meta data upfront and keeps the file open until the first access
@@ -711,3 +707,18 @@ def size(idx):
711707 create_video_file (root , file_name_fn (idx ), size = size (idx ) if callable (size ) else size , ** kwargs )
712708 for idx in range (num_examples )
713709 ]
710+
711+
712+ def create_random_string (length : int , * digits : str ) -> str :
713+ """Create a random string.
714+
715+ Args:
716+ length (int): Number of characters in the generated string.
717+ *characters (str): Characters to sample from. If omitted defaults to :attr:`string.ascii_lowercase`.
718+ """
719+ if not digits :
720+ digits = string .ascii_lowercase
721+ else :
722+ digits = "" .join (itertools .chain (* digits ))
723+
724+ return "" .join (random .choice (digits ) for _ in range (length ))
0 commit comments