3535
3636NAME = "imagenet"
3737
# The built-in categories file is read once at import time. Each row unpacks into a
# (category, wnid) pair; transposing the rows with ``zip(*...)`` gives two parallel tuples.
CATEGORIES, WNIDS = cast(
    Tuple[Tuple[str, ...], Tuple[str, ...]],
    zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories")),
)
# Reverse lookup from WordNet ID to the category name stored at the same position.
WNID_TO_CATEGORY = dict(zip(WNIDS, CATEGORIES))
43+
3844
@register_info(NAME)
def _info() -> Dict[str, Any]:
    """Return the dataset info: the category names and their WordNet IDs.

    Both values are the module-level tuples loaded from the built-in categories file.
    """
    return {"categories": CATEGORIES, "wnids": WNIDS}
4348
4449
class ImageNetResource(ManualDownloadResource):
    """Manually downloaded resource that points the user to image-net.org."""

    def __init__(self, **kwargs: Any) -> None:
        # The notice is forwarded as the first positional argument of the base class;
        # all keyword arguments are passed through untouched.
        notice = "Register on https://image-net.org/ and follow the instructions there."
        super().__init__(notice, **kwargs)
4853
4954
# Training image file names start with the WordNet ID ("n" followed by eight digits),
# then an underscore, a running number and the ".JPEG" extension.
TRAIN_IMAGE_NAME_PATTERN = re.compile(r"(?P<wnid>n\d{8})_\d+[.]JPEG")
56+
57+
def prepare_train_data(data: Tuple[str, BinaryIO]) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]:
    """Derive the label and WordNet ID of a training sample from its file name.

    Args:
        data: ``(path, buffer)`` pair of a training image.

    Returns:
        ``((label, wnid), data)`` where ``data`` is passed through unchanged.
    """
    file_name = pathlib.Path(data[0]).name
    # The cast only silences the type checker: a file name that does not match
    # still fails on the subscript below.
    name_match = cast(Match[str], TRAIN_IMAGE_NAME_PATTERN.match(file_name))
    wnid = name_match["wnid"]
    category = WNID_TO_CATEGORY[wnid]
    return (Label.from_category(category, categories=CATEGORIES), wnid), data
63+
64+
def prepare_test_data(data: Tuple[str, BinaryIO]) -> Tuple[None, Tuple[str, BinaryIO]]:
    """Pair a test sample with ``None`` in the label slot and pass ``data`` through unchanged."""
    return (None, data)
67+
68+
def classifiy_devkit(data: Tuple[str, BinaryIO]) -> Optional[int]:
    """Classify a devkit entry by the file name of its path.

    Returns:
        ``0`` for ``meta.mat``, ``1`` for the validation ground-truth file and ``None``
        for every other file name.
    """
    file_name = pathlib.Path(data[0]).name
    if file_name == "meta.mat":
        return 0
    if file_name == "ILSVRC2012_validation_ground_truth.txt":
        return 1
    return None
74+
75+
# WordNet IDs (wnids) are unique, but the category names attached to them are not.
# For example, n02012849 and n03126707 are both labeled 'crane': the first is the bird,
# the second is the construction equipment. The entries below give such wnids an
# explicit category name.
WNID_MAP = {
    "n03126707": "construction crane",
    "n03710721": "tank suit",
}
82+
83+
def extract_categories_and_wnids(data: Tuple[str, BinaryIO]) -> List[Tuple[str, str]]:
    """Read the ``synsets`` table from a ``.mat`` buffer and list ``(category, wnid)`` pairs.

    Args:
        data: ``(path, buffer)`` pair; only the buffer is read.

    Returns:
        One ``(category, wnid)`` tuple per synset without children, in table order.
    """
    synsets = read_mat(data[1], squeeze_me=True)["synsets"]

    pairs = []
    for _, wnid, category, _, num_children, *_ in synsets:
        # if num_children > 0, we are looking at a superclass that has no direct instance
        if num_children != 0:
            continue
        # Default to the first comma-separated name unless WNID_MAP overrides this wnid.
        first_name = category.split(",", 1)[0]
        pairs.append((WNID_MAP.get(wnid, first_name), wnid))
    return pairs
92+
93+
def imagenet_label_to_wnid(imagenet_label: str) -> str:
    """Look up the WordNet ID for a numeric label given as a string.

    The label is parsed as an integer and shifted down by one before indexing ``WNIDS``.
    """
    index = int(imagenet_label) - 1
    return WNIDS[index]
96+
97+
# Validation and test image file names carry an eight-digit running ID.
VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG")


def val_test_image_key(path: pathlib.Path) -> int:
    """Return the numeric image ID embedded in a validation or test image file name."""
    name_match = VAL_TEST_IMAGE_NAME_PATTERN.match(path.name)
    # A non-matching name leaves ``name_match`` as None, so the subscript raises.
    image_id = name_match["id"]  # type: ignore[index]
    return int(image_id)
103+
104+
def prepare_val_data(
    data: Tuple[Tuple[int, str], Tuple[str, BinaryIO]]
) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]:
    """Turn a validation ``(label data, image data)`` pair into ``((label, wnid), image data)``.

    The first element of the label data is discarded; only the wnid is used to build the label.
    """
    (_, wnid), image_data = data
    category = WNID_TO_CATEGORY[wnid]
    return (Label.from_category(category, categories=CATEGORIES), wnid), image_data
112+
113+
def prepare_sample(data: Tuple[Optional[Tuple[Label, str]], Tuple[str, BinaryIO]]) -> Dict[str, Any]:
    """Assemble the sample dictionary with the keys ``label``, ``wnid``, ``path`` and ``image``.

    When no label data is given, ``label`` and ``wnid`` are both ``None``.
    """
    label_data, (path, buffer) = data

    # Falsy label data (e.g. None for the test split) falls back to a pair of Nones.
    sample: Dict[str, Any] = dict(zip(("label", "wnid"), label_data or (None, None)))
    sample.update(path=path, image=EncodedImage.from_file(buffer))
    return sample
122+
123+
50124@register_dataset (NAME )
51125class ImageNet (Dataset2 ):
52126 def __init__ (self , root : Union [str , pathlib .Path ], * , split : str = "train" ) -> None :
@@ -83,67 +157,6 @@ def _resources(self) -> List[OnlineResource]:
83157
84158 return resources
85159
86- _TRAIN_IMAGE_NAME_PATTERN = re .compile (r"(?P<wnid>n\d{8})_\d+[.]JPEG" )
87-
88- def _prepare_train_data (self , data : Tuple [str , BinaryIO ]) -> Tuple [Tuple [Label , str ], Tuple [str , BinaryIO ]]:
89- path = pathlib .Path (data [0 ])
90- wnid = cast (Match [str ], self ._TRAIN_IMAGE_NAME_PATTERN .match (path .name ))["wnid" ]
91- label = Label .from_category (self ._wnid_to_category [wnid ], categories = self ._categories )
92- return (label , wnid ), data
93-
94- def _prepare_test_data (self , data : Tuple [str , BinaryIO ]) -> Tuple [None , Tuple [str , BinaryIO ]]:
95- return None , data
96-
97- def _classifiy_devkit (self , data : Tuple [str , BinaryIO ]) -> Optional [int ]:
98- return {
99- "meta.mat" : 0 ,
100- "ILSVRC2012_validation_ground_truth.txt" : 1 ,
101- }.get (pathlib .Path (data [0 ]).name )
102-
103- # Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849
104- # and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment
105- _WNID_MAP = {
106- "n03126707" : "construction crane" ,
107- "n03710721" : "tank suit" ,
108- }
109-
110- def _extract_categories_and_wnids (self , data : Tuple [str , BinaryIO ]) -> List [Tuple [str , str ]]:
111- synsets = read_mat (data [1 ], squeeze_me = True )["synsets" ]
112- return [
113- (self ._WNID_MAP .get (wnid , category .split ("," , 1 )[0 ]), wnid )
114- for _ , wnid , category , _ , num_children , * _ in synsets
115- # if num_children > 0, we are looking at a superclass that has no direct instance
116- if num_children == 0
117- ]
118-
119- def _imagenet_label_to_wnid (self , imagenet_label : str ) -> str :
120- return self ._wnids [int (imagenet_label ) - 1 ]
121-
122- _VAL_TEST_IMAGE_NAME_PATTERN = re .compile (r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG" )
123-
124- def _val_test_image_key (self , path : pathlib .Path ) -> int :
125- return int (self ._VAL_TEST_IMAGE_NAME_PATTERN .match (path .name )["id" ]) # type: ignore[index]
126-
127- def _prepare_val_data (
128- self , data : Tuple [Tuple [int , str ], Tuple [str , BinaryIO ]]
129- ) -> Tuple [Tuple [Label , str ], Tuple [str , BinaryIO ]]:
130- label_data , image_data = data
131- _ , wnid = label_data
132- label = Label .from_category (self ._wnid_to_category [wnid ], categories = self ._categories )
133- return (label , wnid ), image_data
134-
135- def _prepare_sample (
136- self ,
137- data : Tuple [Optional [Tuple [Label , str ]], Tuple [str , BinaryIO ]],
138- ) -> Dict [str , Any ]:
139- label_data , (path , buffer ) = data
140-
141- return dict (
142- dict (zip (("label" , "wnid" ), label_data if label_data else (None , None ))),
143- path = path ,
144- image = EncodedImage .from_file (buffer ),
145- )
146-
147160 def _datapipe (self , resource_dps : List [IterDataPipe ]) -> IterDataPipe [Dict [str , Any ]]:
148161 if self ._split in {"train" , "test" }:
149162 dp = resource_dps [0 ]
@@ -154,19 +167,19 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str,
154167
155168 dp = hint_shuffling (dp )
156169 dp = hint_sharding (dp )
157- dp = Mapper (dp , self . _prepare_train_data if self ._split == "train" else self . _prepare_test_data )
170+ dp = Mapper (dp , prepare_train_data if self ._split == "train" else prepare_test_data )
158171 else : # config.split == "val":
159172 images_dp , devkit_dp = resource_dps
160173
161174 meta_dp , label_dp = Demultiplexer (
162- devkit_dp , 2 , self . _classifiy_devkit , drop_none = True , buffer_size = INFINITE_BUFFER_SIZE
175+ devkit_dp , 2 , classifiy_devkit , drop_none = True , buffer_size = INFINITE_BUFFER_SIZE
163176 )
164177
165- meta_dp = Mapper (meta_dp , self . _extract_categories_and_wnids )
178+ meta_dp = Mapper (meta_dp , extract_categories_and_wnids )
166179 _ , wnids = zip (* next (iter (meta_dp )))
167180
168181 label_dp = LineReader (label_dp , decode = True , return_path = False )
169- label_dp = Mapper (label_dp , self . _imagenet_label_to_wnid )
182+ label_dp = Mapper (label_dp , imagenet_label_to_wnid )
170183 label_dp : IterDataPipe [Tuple [int , str ]] = Enumerator (label_dp , 1 )
171184 label_dp = hint_shuffling (label_dp )
172185 label_dp = hint_sharding (label_dp )
@@ -175,12 +188,12 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str,
175188 label_dp ,
176189 images_dp ,
177190 key_fn = getitem (0 ),
178- ref_key_fn = path_accessor (self . _val_test_image_key ),
191+ ref_key_fn = path_accessor (val_test_image_key ),
179192 buffer_size = INFINITE_BUFFER_SIZE ,
180193 )
181- dp = Mapper (dp , self . _prepare_val_data )
194+ dp = Mapper (dp , prepare_val_data )
182195
183- return Mapper (dp , self . _prepare_sample )
196+ return Mapper (dp , prepare_sample )
184197
185198 def __len__ (self ) -> int :
186199 return {
@@ -195,7 +208,7 @@ def _generate_categories(self) -> List[Tuple[str, ...]]:
195208
196209 devkit_dp = resources [1 ].load (self ._root )
197210 meta_dp = Filter (devkit_dp , path_comparator ("name" , "meta.mat" ))
198- meta_dp = Mapper (meta_dp , self . _extract_categories_and_wnids )
211+ meta_dp = Mapper (meta_dp , extract_categories_and_wnids )
199212
200213 categories_and_wnids = cast (List [Tuple [str , ...]], next (iter (meta_dp )))
201214 categories_and_wnids .sort (key = lambda category_and_wnid : category_and_wnid [1 ])
0 commit comments