1010import pickle
1111import random
1212import tempfile
13+ import unittest .mock
1314import xml .etree .ElementTree as ET
1415from collections import defaultdict , Counter , UserDict
1516
2122from torch .nn .functional import one_hot
2223from torch .testing import make_tensor as _make_tensor
2324from torchvision .prototype import datasets
24- from torchvision .prototype .datasets ._api import DEFAULT_DECODER_MAP , DEFAULT_DECODER , find
25+ from torchvision .prototype .datasets ._api import find
26+ from torchvision .prototype .utils ._internal import sequence_to_str
2527
# `torch.testing.make_tensor` (imported above as `_make_tensor`) with `device="cpu"` pre-bound, so every
# tensor created by the mock data helpers in this module is created on the CPU.
make_tensor = functools.partial(_make_tensor, device="cpu")
# `make_tensor` with an empty tuple pre-bound as first positional argument.
# NOTE(review): presumably this is the shape argument, i.e. the helper yields 0-dim tensors - confirm.
make_scalar = functools.partial(make_tensor, ())
@@ -49,7 +51,7 @@ class DatasetMock:
def __init__(self, name, mock_data_fn, *, configs=None):
    """Set up the mock for the dataset registered under ``name``.

    Args:
        name: Name that is passed to ``find`` to look up the dataset.
        mock_data_fn: Callable that creates the mock data. It is stored unmodified and later invoked as
            ``mock_data_fn(info, root, config)``.
        configs: Configs this mock should be used with. Any falsy value (including the default ``None``)
            selects ``self.info._configs`` instead.
    """
    dataset = find(name)
    self.dataset = dataset
    # All mock files of this dataset live in a dataset specific folder below the test home.
    self.root = TEST_HOME / dataset.name
    self.mock_data_fn = mock_data_fn
    self.configs = configs if configs else self.info._configs
    # Filled lazily: maps a config to the mock info generated for it.
    self._cache = {}
5557
@@ -61,77 +63,71 @@ def info(self):
def name(self):
    """Return the ``name`` attribute of ``self.info``."""
    info = self.info
    return info.name
6365
64- def _parse_mock_data (self , mock_data_fn ):
65- def wrapper (info , root , config ):
66- mock_infos = mock_data_fn (info , root , config )
def _parse_mock_data(self, config, mock_infos):
    """Validate and normalize the value returned by the mock data function.

    Args:
        config: The config the mock data function was called with.
        mock_infos: The raw return value of the mock data function. Accepted are

            - an integer, which is interpreted as the number of samples for ``config``,
            - a dictionary with at least a ``"num_samples"`` entry, which is used as mock info for
              ``config``, or
            - a dictionary that exclusively uses ``DatasetConfig`` objects as keys and maps each of them
              to one of the two former options.

    Returns:
        A new dictionary that maps every config to its mock info dictionary. Each mock info dictionary
        contains at least the key ``"num_samples"``. The object passed as ``mock_infos`` is not modified.

    Raises:
        pytest.UsageError: If ``mock_infos`` is ``None``, mixes ``DatasetConfig`` keys with other key
            types, contains a config that is already cached, or contains a mock info that is neither an
            integer nor a dictionary with a ``"num_samples"`` entry.
    """
    if mock_infos is None:
        raise pytest.UsageError(
            f"The mock data function for dataset '{self.name}' returned nothing. It needs to at least return an "
            f"integer indicating the number of samples for the current `config`."
        )

    # Anything that is not keyed by `DatasetConfig` is treated as the mock info of the current `config`.
    key_types = {type(key) for key in mock_infos} if isinstance(mock_infos, dict) else set()
    if datasets.utils.DatasetConfig not in key_types:
        mock_infos = {config: mock_infos}
    elif len(key_types) > 1:
        raise pytest.UsageError(
            f"Unable to handle the returned dictionary of the mock data function for dataset {self.name}. If "
            f"returned dictionary uses `DatasetConfig` as key type, all keys should be of that type."
        )

    # Collect the normalized entries in a fresh dictionary rather than mutating the object that was
    # handed out by the mock data function.
    parsed_infos = {}
    for config_, mock_info in mock_infos.items():
        if config_ in self._cache:
            raise pytest.UsageError(
                f"The mock info for config {config_} of dataset {self.name} generated for config {config} "
                f"already exists in the cache."
            )
        if isinstance(mock_info, int):
            mock_info = dict(num_samples=mock_info)
        elif not isinstance(mock_info, dict):
            # Report the type of the offending entry, not of the enclosing mapping.
            raise pytest.UsageError(
                f"The mock data function for dataset '{self.name}' returned a {type(mock_info)} for `config` "
                f"{config_}. The returned object should be a dictionary containing at least the number of "
                f"samples for the key `'num_samples'`. If no additional information is required for specific "
                f"tests, the number of samples can also be returned as an integer."
            )
        elif "num_samples" not in mock_info:
            raise pytest.UsageError(
                f"The dictionary returned by the mock data function for dataset '{self.name}' and config "
                f"{config_} has to contain a `'num_samples'` entry indicating the number of samples."
            )
        parsed_infos[config_] = mock_info

    return parsed_infos
107104
108- def _load_mock (self , config ):
def _prepare_resources(self, config):
    """Generate the mock data for ``config`` if necessary and return its mock info.

    The mock data function is only invoked if ``config`` is not cached yet. Every config it reports is
    checked for missing resource files and then added to the cache.

    Args:
        config: The config to prepare the mock data for.

    Returns:
        The mock info that is cached for ``config``.

    Raises:
        pytest.UsageError: If a file required by one of the reported configs is not present in
            ``self.root`` after the mock data function ran.
        KeyError: If the mock data function did not report a mock info for ``config`` itself.
    """
    try:
        return self._cache[config]
    except KeyError:
        # Not generated yet, fall through to the generation below.
        pass

    self.root.mkdir(exist_ok=True)
    raw_infos = self.mock_data_fn(self.info, self.root, config)
    parsed_infos = self._parse_mock_data(config, raw_infos)

    # Only the names of the direct children of the root folder are compared against the resources.
    present = {entry.name for entry in self.root.glob("*")}
    for config_, mock_info in parsed_infos.items():
        required = {resource.file_name for resource in self.dataset.resources(config_)}
        missing = required.difference(present)
        if missing:
            raise pytest.UsageError(
                f"Dataset '{self.name}' requires the files {sequence_to_str(sorted(missing))} "
                f"for {config_}, but they were not created by the mock data function."
            )

        self._cache[config_] = mock_info

    return self._cache[config]
121125
122- def load (self , config , * , decoder = DEFAULT_DECODER ):
123- try :
124- self .info .check_dependencies ()
125- except ModuleNotFoundError as error :
126- pytest .skip (str (error ))
127-
128- mock_resources , mock_info = self ._load_mock (config )
129- datapipe = self .dataset ._make_datapipe (
130- [resource .load (self .root ) for resource in mock_resources ],
131- config = config ,
132- decoder = DEFAULT_DECODER_MAP .get (self .info .type ) if decoder is DEFAULT_DECODER else decoder ,
133- )
134- return datapipe , mock_info
@contextlib.contextmanager
def prepare(self, config):
    """Context manager that provides the mock data for ``config``.

    The mock resources are prepared before the patch is applied. While the context is active,
    ``torchvision.prototype.datasets._api.home`` is replaced by a mock that returns ``str(TEST_HOME)``.

    Args:
        config: The config to prepare the mock data for.

    Yields:
        The mock info returned by ``self._prepare_resources(config)``.
    """
    mock_info = self._prepare_resources(config)
    home_patch = unittest.mock.patch(
        "torchvision.prototype.datasets._api.home",
        return_value=str(TEST_HOME),
    )
    with home_patch:
        yield mock_info
135131
136132
137133def config_id (name , config ):
@@ -1000,7 +996,7 @@ def dtd(info, root, _):
1000996def fer2013 (info , root , config ):
1001997 num_samples = 5 if config .split == "train" else 3
1002998
1003- path = root / f"{ config .split } .txt "
999+ path = root / f"{ config .split } .csv "
10041000 with open (path , "w" , newline = "" ) as file :
10051001 field_names = ["emotion" ] if config .split == "train" else []
10061002 field_names .append ("pixels" )
@@ -1061,7 +1057,7 @@ def clevr(info, root, config):
10611057 file ,
10621058 )
10631059
1064- make_zip (root , f"{ data_folder .name } .zip" )
1060+ make_zip (root , f"{ data_folder .name } .zip" , data_folder )
10651061
10661062 return {config_ : num_samples_map [config_ .split ] for config_ in info ._configs }
10671063
@@ -1121,8 +1117,8 @@ def generate(self, root):
11211117 for path in segmentation_files :
11221118 path .with_name (f".{ path .name } " ).touch ()
11231119
1124- make_tar (root , "images.tar" )
1125- make_tar (root , anns_folder .with_suffix (".tar" ).name )
1120+ make_tar (root , "images.tar.gz" , compression = "gz " )
1121+ make_tar (root , anns_folder .with_suffix (".tar.gz " ).name , compression = "gz" )
11261122
11271123 return num_samples_map
11281124
@@ -1211,7 +1207,7 @@ def _make_segmentations(cls, root, image_files):
12111207 size = [1 , * make_tensor ((2 ,), low = 3 , dtype = torch .int ).tolist ()],
12121208 )
12131209
1214- make_tar (root , segmentations_folder .with_suffix (".tgz" ).name )
1210+ make_tar (root , segmentations_folder .with_suffix (".tgz" ).name , compression = "gz" )
12151211
12161212 @classmethod
12171213 def generate (cls , root ):
@@ -1298,3 +1294,23 @@ def generate(cls, root):
def cub200(info, root, config):
    """Generate the CUB-200 mock data for the year of ``config``.

    ``CUB2002011MockData`` is used if ``config.year`` is ``"2011"`` and ``CUB2002010MockData`` otherwise.

    Returns:
        Dictionary that maps every config in ``info._configs`` with the same year as ``config`` to the
        number of samples that were generated for its split.
    """
    if config.year == "2011":
        mock_data_cls = CUB2002011MockData
    else:
        mock_data_cls = CUB2002010MockData
    samples_per_split = mock_data_cls.generate(root)

    year = config.year
    return {
        candidate: samples_per_split[candidate.split]
        for candidate in info._configs
        if candidate.year == year
    }
1297+
1298+
@DATASET_MOCKS.set_from_named_callable
def svhn(info, root, config):
    """Write a random ``{split}_32x32.mat`` file for the split of ``config`` into ``root``.

    The file holds ``uint8`` image data of shape ``(32, 32, 3, num_samples)`` under the key ``"X"`` and
    ``uint8`` labels drawn from ``[0, 10)`` of shape ``(num_samples,)`` under the key ``"y"``.

    Returns:
        The number of generated samples: 2 for ``"train"``, 3 for ``"test"``, and 4 for ``"extra"``.
    """
    # Imported lazily so that scipy is only needed when this mock is actually generated.
    import scipy.io as sio

    split = config.split
    samples_per_split = {"train": 2, "test": 3, "extra": 4}
    num_samples = samples_per_split[split]

    images = np.random.randint(256, size=(32, 32, 3, num_samples), dtype=np.uint8)
    labels = np.random.randint(10, size=(num_samples,), dtype=np.uint8)
    sio.savemat(root / f"{split}_32x32.mat", {"X": images, "y": labels})

    return num_samples
0 commit comments