1- import  functools 
21import  pathlib 
32import  re 
43from  collections  import  OrderedDict 
5- from  typing  import  Any , Dict , List , Optional , Tuple , cast , BinaryIO 
4+ from  collections  import  defaultdict 
5+ from  typing  import  Any , Dict , List , Optional , Tuple , cast , BinaryIO , Union 
66
77import  torch 
88from  torchdata .datapipes .iter  import  (
1616    UnBatcher ,
1717)
1818from  torchvision .prototype .datasets .utils  import  (
19-     Dataset ,
20-     DatasetConfig ,
2119    DatasetInfo ,
2220    HttpResource ,
2321    OnlineResource ,
22+     Dataset2 ,
2423)
2524from  torchvision .prototype .datasets .utils ._internal  import  (
2625    MappingIterator ,
3231    hint_shuffling ,
3332)
3433from  torchvision .prototype .features  import  BoundingBox , Label , _Feature , EncodedImage 
35- from  torchvision .prototype .utils ._internal  import  FrozenMapping 
36- 
37- 
38- class  Coco (Dataset ):
39-     def  _make_info (self ) ->  DatasetInfo :
40-         name  =  "coco" 
41-         categories , super_categories  =  zip (* DatasetInfo .read_categories_file (BUILTIN_DIR  /  f"{ name }  ))
42- 
43-         return  DatasetInfo (
44-             name ,
45-             dependencies = ("pycocotools" ,),
46-             categories = categories ,
47-             homepage = "https://cocodataset.org/" ,
48-             valid_options = dict (
49-                 split = ("train" , "val" ),
50-                 year = ("2017" , "2014" ),
51-                 annotations = (* self ._ANN_DECODERS .keys (), None ),
52-             ),
53-             extra = dict (category_to_super_category = FrozenMapping (zip (categories , super_categories ))),
34+ 
35+ from  .._api  import  register_dataset , register_info 
36+ 
37+ 
38+ NAME  =  "coco" 
39+ 
40+ 
41+ @register_info (NAME ) 
42+ def  _info () ->  Dict [str , Any ]:
43+     categories , super_categories  =  zip (* DatasetInfo .read_categories_file (BUILTIN_DIR  /  f"{ NAME }  ))
44+     return  dict (categories = categories , super_categories = super_categories )
45+ 
46+ 
47+ @register_dataset (NAME ) 
48+ class  Coco (Dataset2 ):
49+     """ 
50+     - **homepage**: https://cocodataset.org/ 
51+     - **dependencies**: 
52+         - <pycocotools `https://github.com/cocodataset/cocoapi`>_ 
53+     """ 
54+ 
55+     def  __init__ (
56+         self ,
57+         root : Union [str , pathlib .Path ],
58+         * ,
59+         split : str  =  "train" ,
60+         year : str  =  "2017" ,
61+         annotations : Optional [str ] =  "instances" ,
62+         skip_integrity_check : bool  =  False ,
63+     ) ->  None :
64+         self ._split  =  self ._verify_str_arg (split , "split" , {"train" , "val" })
65+         self ._year  =  self ._verify_str_arg (year , "year" , {"2017" , "2014" })
66+         self ._annotations  =  (
67+             self ._verify_str_arg (annotations , "annotations" , self ._ANN_DECODERS .keys ())
68+             if  annotations  is  not None 
69+             else  None 
5470        )
5571
72+         info  =  _info ()
73+         categories , super_categories  =  info ["categories" ], info ["super_categories" ]
74+         self ._categories  =  categories 
75+         self ._category_to_super_category  =  dict (zip (categories , super_categories ))
76+ 
77+         super ().__init__ (root , dependencies = ("pycocotools" ,), skip_integrity_check = skip_integrity_check )
78+ 
5679    _IMAGE_URL_BASE  =  "http://images.cocodataset.org/zips" 
5780
5881    _IMAGES_CHECKSUMS  =  {
@@ -69,14 +92,14 @@ def _make_info(self) -> DatasetInfo:
6992        "2017" : "113a836d90195ee1f884e704da6304dfaaecff1f023f49b6ca93c4aaae470268" ,
7093    }
7194
72-     def  resources (self ,  config :  DatasetConfig ) ->  List [OnlineResource ]:
95+     def  _resources (self ) ->  List [OnlineResource ]:
7396        images  =  HttpResource (
74-             f"{ self ._IMAGE_URL_BASE } { config . split } { config . year }  ,
75-             sha256 = self ._IMAGES_CHECKSUMS [(config . year ,  config . split )],
97+             f"{ self ._IMAGE_URL_BASE } { self . _split } { self . _year }  ,
98+             sha256 = self ._IMAGES_CHECKSUMS [(self . _year ,  self . _split )],
7699        )
77100        meta  =  HttpResource (
78-             f"{ self ._META_URL_BASE } { config . year }  ,
79-             sha256 = self ._META_CHECKSUMS [config . year ],
101+             f"{ self ._META_URL_BASE } { self . _year }  ,
102+             sha256 = self ._META_CHECKSUMS [self . _year ],
80103        )
81104        return  [images , meta ]
82105
@@ -110,10 +133,8 @@ def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[st
110133                format = "xywh" ,
111134                image_size = image_size ,
112135            ),
113-             labels = Label (labels , categories = self .categories ),
114-             super_categories = [
115-                 self .info .extra .category_to_super_category [self .info .categories [label ]] for  label  in  labels 
116-             ],
136+             labels = Label (labels , categories = self ._categories ),
137+             super_categories = [self ._category_to_super_category [self ._categories [label ]] for  label  in  labels ],
117138            ann_ids = [ann ["id" ] for  ann  in  anns ],
118139        )
119140
@@ -134,9 +155,14 @@ def _decode_captions_ann(self, anns: List[Dict[str, Any]], image_meta: Dict[str,
134155        fr"(?P<annotations>({ '|' .join (_ANN_DECODERS .keys ())}  
135156    )
136157
137-     def  _filter_meta_files (self , data : Tuple [str , Any ],  * ,  split :  str ,  year :  str ,  annotations :  str ) ->  bool :
158+     def  _filter_meta_files (self , data : Tuple [str , Any ]) ->  bool :
138159        match  =  self ._META_FILE_PATTERN .match (pathlib .Path (data [0 ]).name )
139-         return  bool (match  and  match ["split" ] ==  split  and  match ["year" ] ==  year  and  match ["annotations" ] ==  annotations )
160+         return  bool (
161+             match 
162+             and  match ["split" ] ==  self ._split 
163+             and  match ["year" ] ==  self ._year 
164+             and  match ["annotations" ] ==  self ._annotations 
165+         )
140166
141167    def  _classify_meta (self , data : Tuple [str , Any ]) ->  Optional [int ]:
142168        key , _  =  data 
@@ -157,38 +183,26 @@ def _prepare_image(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
157183    def  _prepare_sample (
158184        self ,
159185        data : Tuple [Tuple [List [Dict [str , Any ]], Dict [str , Any ]], Tuple [str , BinaryIO ]],
160-         * ,
161-         annotations : str ,
162186    ) ->  Dict [str , Any ]:
163187        ann_data , image_data  =  data 
164188        anns , image_meta  =  ann_data 
165189
166190        sample  =  self ._prepare_image (image_data )
191+         # this method is only called if we have annotations 
192+         annotations  =  cast (str , self ._annotations )
167193        sample .update (self ._ANN_DECODERS [annotations ](self , anns , image_meta ))
168194        return  sample 
169195
170-     def  _make_datapipe (
171-         self ,
172-         resource_dps : List [IterDataPipe ],
173-         * ,
174-         config : DatasetConfig ,
175-     ) ->  IterDataPipe [Dict [str , Any ]]:
196+     def  _datapipe (self , resource_dps : List [IterDataPipe ]) ->  IterDataPipe [Dict [str , Any ]]:
176197        images_dp , meta_dp  =  resource_dps 
177198
178-         if  config . annotations  is  None :
199+         if  self . _annotations  is  None :
179200            dp  =  hint_shuffling (images_dp )
180201            dp  =  hint_sharding (dp )
202+             dp  =  hint_shuffling (dp )
181203            return  Mapper (dp , self ._prepare_image )
182204
183-         meta_dp  =  Filter (
184-             meta_dp ,
185-             functools .partial (
186-                 self ._filter_meta_files ,
187-                 split = config .split ,
188-                 year = config .year ,
189-                 annotations = config .annotations ,
190-             ),
191-         )
205+         meta_dp  =  Filter (meta_dp , self ._filter_meta_files )
192206        meta_dp  =  JsonParser (meta_dp )
193207        meta_dp  =  Mapper (meta_dp , getitem (1 ))
194208        meta_dp : IterDataPipe [Dict [str , Dict [str , Any ]]] =  MappingIterator (meta_dp )
@@ -216,26 +230,31 @@ def _make_datapipe(
216230            ref_key_fn = getitem ("id" ),
217231            buffer_size = INFINITE_BUFFER_SIZE ,
218232        )
219- 
220233        dp  =  IterKeyZipper (
221234            anns_dp ,
222235            images_dp ,
223236            key_fn = getitem (1 , "file_name" ),
224237            ref_key_fn = path_accessor ("name" ),
225238            buffer_size = INFINITE_BUFFER_SIZE ,
226239        )
240+         return  Mapper (dp , self ._prepare_sample )
241+ 
242+     def  __len__ (self ) ->  int :
243+         return  {
244+             ("train" , "2017" ): defaultdict (lambda : 118_287 , instances = 117_266 ),
245+             ("train" , "2014" ): defaultdict (lambda : 82_783 , instances = 82_081 ),
246+             ("val" , "2017" ): defaultdict (lambda : 5_000 , instances = 4_952 ),
247+             ("val" , "2014" ): defaultdict (lambda : 40_504 , instances = 40_137 ),
248+         }[(self ._split , self ._year )][
249+             self ._annotations   # type: ignore[index] 
250+         ]
227251
228-         return  Mapper (dp , functools .partial (self ._prepare_sample , annotations = config .annotations ))
229- 
230-     def  _generate_categories (self , root : pathlib .Path ) ->  Tuple [Tuple [str , str ]]:
231-         config  =  self .default_config 
232-         resources  =  self .resources (config )
252+     def  _generate_categories (self ) ->  Tuple [Tuple [str , str ]]:
253+         self ._annotations  =  "instances" 
254+         resources  =  self ._resources ()
233255
234-         dp  =  resources [1 ].load (root )
235-         dp  =  Filter (
236-             dp ,
237-             functools .partial (self ._filter_meta_files , split = config .split , year = config .year , annotations = "instances" ),
238-         )
256+         dp  =  resources [1 ].load (self ._root )
257+         dp  =  Filter (dp , self ._filter_meta_files )
239258        dp  =  JsonParser (dp )
240259
241260        _ , meta  =  next (iter (dp ))
0 commit comments