@@ -117,19 +117,45 @@ def inner_wrapper(*args, **kwargs):
117117def test_all_configs (test ):
118118 """Decorator to run test against all configurations.
119119
120- Add this as decorator to an arbitrary test to run it against all configurations. The current configuration is
121- provided as the first parameter:
120+ Add this as decorator to an arbitrary test to run it against all configurations. This includes
121+ :attr:`DatasetTestCase.DEFAULT_CONFIG` and :attr:`DatasetTestCase.ADDITIONAL_CONFIGS`.
122+
123+ The current configuration is provided as the first parameter for the test:
122124
123125 .. code-block::
124126
125- @test_all_configs
127+ @test_all_configs()
126128 def test_foo(self, config):
127129 pass
130+
131+ .. note::
132+
133+ This will try to remove duplicate configurations. During this process it will not not preserve a potential
134+ ordering of the configurations or an inner ordering of a configuration.
128135 """
129136
137+ def maybe_remove_duplicates (configs ):
138+ try :
139+ return [dict (config_ ) for config_ in set (tuple (sorted (config .items ())) for config in configs )]
140+ except TypeError :
141+ # A TypeError will be raised if a value of any config is not hashable, e.g. a list. In that case duplicate
142+ # removal would be a lot more elaborate and we simply bail out.
143+ return configs
144+
130145 @functools .wraps (test )
131146 def wrapper (self ):
132- for config in self .CONFIGS or (self ._DEFAULT_CONFIG ,):
147+ configs = []
148+ if self .DEFAULT_CONFIG is not None :
149+ configs .append (self .DEFAULT_CONFIG )
150+ if self .ADDITIONAL_CONFIGS is not None :
151+ configs .extend (self .ADDITIONAL_CONFIGS )
152+
153+ if not configs :
154+ configs = [self ._KWARG_DEFAULTS .copy ()]
155+ else :
156+ configs = maybe_remove_duplicates (configs )
157+
158+ for config in configs :
133159 with self .subTest (** config ):
134160 test (self , config )
135161
@@ -166,9 +192,13 @@ class DatasetTestCase(unittest.TestCase):
166192
167193 Optionally, you can overwrite the following class attributes:
168194
169- - CONFIGS (Sequence[Dict[str, Any]]): Additional configs that should be tested. Each dictonary can contain an
170- arbitrary combination of dataset parameters that are **not** ``transform``, ``target_transform``,
171- ``transforms``, or ``download``. The first element will be used as default configuration.
195+ - DEFAULT_CONFIG (Dict[str, Any]): Config that will be used by default. If omitted, this defaults to all
196+ keyword arguments of the dataset minus ``transform``, ``target_transform``, ``transforms``, and
197+ ``download``. Overwrite this if you want to use a default value for a parameter for which the dataset does
198+ not provide one.
199+ - ADDITIONAL_CONFIGS (Sequence[Dict[str, Any]]): Additional configs that should be tested. Each dictionary can
200+ contain an arbitrary combination of dataset parameters that are **not** ``transform``, ``target_transform``,
201+ ``transforms``, or ``download``.
172202 - REQUIRED_PACKAGES (Iterable[str]): Additional dependencies to use the dataset. If these packages are not
173203 available, the tests are skipped.
174204
@@ -218,22 +248,31 @@ def test_baz(self):
218248 DATASET_CLASS = None
219249 FEATURE_TYPES = None
220250
221- CONFIGS = None
251+ DEFAULT_CONFIG = None
252+ ADDITIONAL_CONFIGS = None
222253 REQUIRED_PACKAGES = None
223254
224- _DEFAULT_CONFIG = None
225-
255+ # These keyword arguments are checked by test_transforms in case they are available in DATASET_CLASS.
226256 _TRANSFORM_KWARGS = {
227257 "transform" ,
228258 "target_transform" ,
229259 "transforms" ,
230260 }
261+ # These keyword arguments get a 'special' treatment and should not be set in DEFAULT_CONFIG or ADDITIONAL_CONFIGS.
231262 _SPECIAL_KWARGS = {
232263 * _TRANSFORM_KWARGS ,
233264 "download" ,
234265 }
266+
267+ # These fields are populated during setupClass() within _populate_private_class_attributes()
268+
269+ # This will be a dictionary containing all keyword arguments with their respective default values extracted from
270+ # the dataset constructor.
271+ _KWARG_DEFAULTS = None
272+ # This will be a set of all _SPECIAL_KWARGS that the dataset constructor takes.
235273 _HAS_SPECIAL_KWARG = None
236274
275+ # These functions are disabled during dataset creation in create_dataset().
237276 _CHECK_FUNCTIONS = {
238277 "check_md5" ,
239278 "check_integrity" ,
@@ -256,7 +295,8 @@ def dataset_args(self, tmpdir: str, config: Dict[str, Any]) -> Sequence[Any]:
256295 Args:
257296 tmpdir (str): Path to a temporary directory. For most cases this acts as root directory for the dataset
258297 to be created and in turn also for the fake data injected here.
259- config (Dict[str, Any]): Configuration that will be used to create the dataset.
298+ config (Dict[str, Any]): Configuration that will be passed to the dataset constructor. It provides at least
299+ fields for all dataset parameters with default values.
260300
261301 Returns:
262302 (Tuple[str]): ``tmpdir`` which corresponds to ``root`` for most datasets.
@@ -273,7 +313,8 @@ def inject_fake_data(self, tmpdir: str, config: Dict[str, Any]) -> Union[int, Di
273313 Args:
274314 tmpdir (str): Path to a temporary directory. For most cases this acts as root directory for the dataset
275315 to be created and in turn also for the fake data injected here.
276- config (Dict[str, Any]): Configuration that will be used to create the dataset.
316+ config (Dict[str, Any]): Configuration that will be passed to the dataset constructor. It provides at least
317+ fields for all dataset parameters with default values.
277318
278319 Needs to return one of the following:
279320
@@ -293,9 +334,16 @@ def create_dataset(
293334 ) -> Iterator [Tuple [torchvision .datasets .VisionDataset , Dict [str , Any ]]]:
294335 r"""Create the dataset in a temporary directory.
295336
337+ The configuration passed to the dataset is populated to contain at least all parameters with default values.
338+ For this the following order of precedence is used:
339+
340+ 1. Parameters in :attr:`kwargs`.
341+ 2. Configuration in :attr:`config`.
342+ 3. Configuration in :attr:`~DatasetTestCase.DEFAULT_CONFIG`.
343+ 4. Default parameters of the dataset.
344+
296345 Args:
297- config (Optional[Dict[str, Any]]): Configuration that will be used to create the dataset. If omitted, the
298- default configuration is used.
346+ config (Optional[Dict[str, Any]]): Configuration that will be used to create the dataset.
299347 inject_fake_data (bool): If ``True`` (default) inject the fake data with :meth:`.inject_fake_data` before
300348 creating the dataset.
301349 patch_checks (Optional[bool]): If ``True`` disable integrity check logic while creating the dataset. If
@@ -308,30 +356,33 @@ def create_dataset(
308356 info (Dict[str, Any]): Additional information about the injected fake data. See :meth:`.inject_fake_data`
309357 for details.
310358 """
311- default_config = self ._DEFAULT_CONFIG .copy ()
312- if config is not None :
313- default_config .update (config )
314- config = default_config
315-
316359 if patch_checks is None :
317360 patch_checks = inject_fake_data
318361
319362 special_kwargs , other_kwargs = self ._split_kwargs (kwargs )
363+
364+ complete_config = self ._KWARG_DEFAULTS .copy ()
365+ if self .DEFAULT_CONFIG :
366+ complete_config .update (self .DEFAULT_CONFIG )
367+ if config :
368+ complete_config .update (config )
369+ if other_kwargs :
370+ complete_config .update (other_kwargs )
371+
320372 if "download" in self ._HAS_SPECIAL_KWARG and special_kwargs .get ("download" , False ):
321373 # override download param to False param if its default is truthy
322374 special_kwargs ["download" ] = False
323- config .update (other_kwargs )
324375
325376 patchers = self ._patch_download_extract ()
326377 if patch_checks :
327378 patchers .update (self ._patch_checks ())
328379
329380 with get_tmp_dir () as tmpdir :
330- args = self .dataset_args (tmpdir , config )
331- info = self ._inject_fake_data (tmpdir , config ) if inject_fake_data else None
381+ args = self .dataset_args (tmpdir , complete_config )
382+ info = self ._inject_fake_data (tmpdir , complete_config ) if inject_fake_data else None
332383
333384 with self ._maybe_apply_patches (patchers ), disable_console_output ():
334- dataset = self .DATASET_CLASS (* args , ** config , ** special_kwargs )
385+ dataset = self .DATASET_CLASS (* args , ** complete_config , ** special_kwargs )
335386
336387 yield dataset , info
337388
@@ -357,26 +408,69 @@ def _verify_required_public_class_attributes(cls):
357408
358409 @classmethod
359410 def _populate_private_class_attributes (cls ):
360- argspec = inspect .getfullargspec (cls .DATASET_CLASS .__init__ )
411+ defaults = []
412+ for cls_ in cls .DATASET_CLASS .__mro__ :
413+ if cls_ is torchvision .datasets .VisionDataset :
414+ break
361415
362- cls ._DEFAULT_CONFIG = {
363- kwarg : default
364- for kwarg , default in zip (argspec .args [- len (argspec .defaults ):], argspec .defaults )
365- if kwarg not in cls ._SPECIAL_KWARGS
366- }
416+ argspec = inspect .getfullargspec (cls_ .__init__ )
367417
368- cls ._HAS_SPECIAL_KWARG = {name for name in cls ._SPECIAL_KWARGS if name in argspec .args }
418+ if not argspec .defaults :
419+ continue
420+
421+ defaults .append (
422+ {kwarg : default for kwarg , default in zip (argspec .args [- len (argspec .defaults ):], argspec .defaults )}
423+ )
424+
425+ if not argspec .varkw :
426+ break
427+
428+ kwarg_defaults = dict ()
429+ for config in reversed (defaults ):
430+ kwarg_defaults .update (config )
431+
432+ has_special_kwargs = set ()
433+ for name in cls ._SPECIAL_KWARGS :
434+ if name not in kwarg_defaults :
435+ continue
436+
437+ del kwarg_defaults [name ]
438+ has_special_kwargs .add (name )
439+
440+ cls ._KWARG_DEFAULTS = kwarg_defaults
441+ cls ._HAS_SPECIAL_KWARG = has_special_kwargs
369442
370443 @classmethod
371444 def _process_optional_public_class_attributes (cls ):
372- if cls .REQUIRED_PACKAGES is not None :
373- try :
374- for pkg in cls .REQUIRED_PACKAGES :
445+ def check_config (config , name ):
446+ special_kwargs = tuple (f"'{ name } '" for name in cls ._SPECIAL_KWARGS if name in config )
447+ if special_kwargs :
448+ raise UsageError (
449+ f"{ name } contains a value for the parameter(s) { ', ' .join (special_kwargs )} . "
450+ f"These are handled separately by the test case and should not be set here. "
451+ f"If you need to test some custom behavior regarding these parameters, "
452+ f"you need to write a custom test (*not* test case), e.g. test_custom_transform()."
453+ )
454+
455+ if cls .DEFAULT_CONFIG is not None :
456+ check_config (cls .DEFAULT_CONFIG , "DEFAULT_CONFIG" )
457+
458+ if cls .ADDITIONAL_CONFIGS is not None :
459+ for idx , config in enumerate (cls .ADDITIONAL_CONFIGS ):
460+ check_config (config , f"CONFIGS[{ idx } ]" )
461+
462+ if cls .REQUIRED_PACKAGES :
463+ missing_pkgs = []
464+ for pkg in cls .REQUIRED_PACKAGES :
465+ try :
375466 importlib .import_module (pkg )
376- except ImportError as error :
467+ except ImportError :
468+ missing_pkgs .append (f"'{ pkg } '" )
469+
470+ if missing_pkgs :
377471 raise unittest .SkipTest (
378- f"The package ' { error . name } ' is required to load the dataset ' { cls . DATASET_CLASS . __name__ } ' but is "
379- f"not installed."
472+ f"The package(s) { ', ' . join ( missing_pkgs ) } are required to load the dataset "
473+ f"' { cls . DATASET_CLASS . __name__ } ', but are not installed."
380474 )
381475
382476 def _split_kwargs (self , kwargs ):
0 commit comments