@@ -34,11 +34,17 @@ def _make_info(self) -> DatasetInfo:
3434 type = DatasetType .IMAGE ,
3535 categories = categories ,
3636 homepage = "https://www.image-net.org/" ,
37- valid_options = dict (split = ("train" , "val" )),
37+ valid_options = dict (split = ("train" , "val" , "test" )),
3838 extra = dict (
3939 wnid_to_category = FrozenMapping (zip (wnids , categories )),
4040 category_to_wnid = FrozenMapping (zip (categories , wnids )),
41- sizes = FrozenMapping ([(DatasetConfig (split = "train" ), 1281167 ), (DatasetConfig (split = "val" ), 50000 )]),
41+ sizes = FrozenMapping (
42+ [
43+ (DatasetConfig (split = "train" ), 1_281_167 ),
44+ (DatasetConfig (split = "val" ), 50_000 ),
45+ (DatasetConfig (split = "test" ), 100_000 ),
46+ ]
47+ ),
4248 ),
4349 )
4450
@@ -53,17 +59,15 @@ def category_to_wnid(self) -> Dict[str, str]:
5359 def wnid_to_category (self ) -> Dict [str , str ]:
5460 return cast (Dict [str , str ], self .info .extra .wnid_to_category )
5561
62+ _IMAGES_CHECKSUMS = {
63+ "train" : "b08200a27a8e34218a0e58fde36b0fe8f73bc377f4acea2d91602057c3ca45bb" ,
64+ "val" : "c7e06a6c0baccf06d8dbeb6577d71efff84673a5dbdd50633ab44f8ea0456ae0" ,
65+ "test_v10102019" : "9cf7f8249639510f17d3d8a0deb47cd22a435886ba8e29e2b3223e65a4079eb4" ,
66+ }
67+
5668 def resources (self , config : DatasetConfig ) -> List [OnlineResource ]:
57- if config .split == "train" :
58- images = HttpResource (
59- "ILSVRC2012_img_train.tar" ,
60- sha256 = "b08200a27a8e34218a0e58fde36b0fe8f73bc377f4acea2d91602057c3ca45bb" ,
61- )
62- else : # config.split == "val"
63- images = HttpResource (
64- "ILSVRC2012_img_val.tar" ,
65- sha256 = "c7e06a6c0baccf06d8dbeb6577d71efff84673a5dbdd50633ab44f8ea0456ae0" ,
66- )
69+ name = "test_v10102019" if config .split == "test" else config .split
70+ images = HttpResource (f"ILSVRC2012_img_{ name } .tar" , sha256 = self ._IMAGES_CHECKSUMS [name ])
6771
6872 devkit = HttpResource (
6973 "ILSVRC2012_devkit_t12.tar.gz" ,
@@ -81,11 +85,11 @@ def _collate_train_data(self, data: Tuple[str, io.IOBase]) -> Tuple[Tuple[int, s
8185 label = self .categories .index (category )
8286 return (label , category , wnid ), data
8387
84- _VAL_IMAGE_NAME_PATTERN = re .compile (r"ILSVRC2012_val_ (?P<id>\d{8})[.]JPEG" )
88+ _VAL_TEST_IMAGE_NAME_PATTERN = re .compile (r"ILSVRC2012_(val|test)_ (?P<id>\d{8})[.]JPEG" )
8589
86- def _val_image_key (self , data : Tuple [str , Any ]) -> int :
90+ def _val_test_image_key (self , data : Tuple [str , Any ]) -> int :
8791 path = pathlib .Path (data [0 ])
88- return int (self ._VAL_IMAGE_NAME_PATTERN .match (path .name ).group ("id" )) # type: ignore[union-attr]
92+ return int (self ._VAL_TEST_IMAGE_NAME_PATTERN .match (path .name ).group ("id" )) # type: ignore[union-attr]
8993
9094 def _collate_val_data (
9195 self , data : Tuple [Tuple [int , int ], Tuple [str , io .IOBase ]]
@@ -96,9 +100,12 @@ def _collate_val_data(
96100 wnid = self .category_to_wnid [category ]
97101 return (label , category , wnid ), image_data
98102
103+ def _collate_test_data (self , data : Tuple [str , io .IOBase ]) -> Tuple [Tuple [None , None , None ], Tuple [str , io .IOBase ]]:
104+ return (None , None , None ), data
105+
99106 def _collate_and_decode_sample (
100107 self ,
101- data : Tuple [Tuple [int , str , str ], Tuple [str , io .IOBase ]],
108+ data : Tuple [Tuple [Optional [ int ], Optional [ str ], Optional [ str ] ], Tuple [str , io .IOBase ]],
102109 * ,
103110 decoder : Optional [Callable [[io .IOBase ], torch .Tensor ]],
104111 ) -> Dict [str , Any ]:
@@ -108,7 +115,7 @@ def _collate_and_decode_sample(
108115 return dict (
109116 path = path ,
110117 image = decoder (buffer ) if decoder else buffer ,
111- label = torch . tensor ( label ) ,
118+ label = label ,
112119 category = category ,
113120 wnid = wnid ,
114121 )
@@ -129,7 +136,7 @@ def _make_datapipe(
129136 dp = TarArchiveReader (images_dp )
130137 dp = Shuffler (dp , buffer_size = INFINITE_BUFFER_SIZE )
131138 dp = Mapper (dp , self ._collate_train_data )
132- else :
139+ elif config . split == "val" :
133140 devkit_dp = TarArchiveReader (devkit_dp )
134141 devkit_dp = Filter (devkit_dp , path_comparator ("name" , "ILSVRC2012_validation_ground_truth.txt" ))
135142 devkit_dp = LineReader (devkit_dp , return_path = False )
@@ -141,10 +148,13 @@ def _make_datapipe(
141148 devkit_dp ,
142149 images_dp ,
143150 key_fn = getitem (0 ),
144- ref_key_fn = self ._val_image_key ,
151+ ref_key_fn = self ._val_test_image_key ,
145152 buffer_size = INFINITE_BUFFER_SIZE ,
146153 )
147154 dp = Mapper (dp , self ._collate_val_data )
155+ else : # config.split == "test"
156+ dp = Shuffler (images_dp , buffer_size = INFINITE_BUFFER_SIZE )
157+ dp = Mapper (dp , self ._collate_test_data )
148158
149159 return Mapper (dp , self ._collate_and_decode_sample , fn_kwargs = dict (decoder = decoder ))
150160
0 commit comments