11import  csv 
22import  functools 
33import  pathlib 
4- from  typing  import  Any , Dict , List , Optional , Tuple , BinaryIO , Callable 
4+ from  typing  import  Any , Dict , List , Optional , Tuple , BinaryIO , Callable ,  Union 
55
66from  torchdata .datapipes .iter  import  (
77    IterDataPipe ,
1414    CSVDictParser ,
1515)
1616from  torchvision .prototype .datasets .utils  import  (
17-     Dataset ,
18-     DatasetConfig ,
17+     Dataset2 ,
1918    DatasetInfo ,
2019    HttpResource ,
2120    OnlineResource ,
2827    getitem ,
2928    path_comparator ,
3029    path_accessor ,
30+     BUILTIN_DIR ,
3131)
3232from  torchvision .prototype .features  import  Label , BoundingBox , _Feature , EncodedImage 
3333
34+ from  .._api  import  register_dataset , register_info 
35+ 
# CUB-200's annotation files (splits, bounding boxes, class lists) are
# space-separated; register a csv dialect so CSVParser / CSVDictParser
# below can parse them with the stdlib csv machinery.
csv.register_dialect("cub200", delimiter=" ")
3537
3638
37- class  CUB200 (Dataset ):
38-     def  _make_info (self ) ->  DatasetInfo :
39-         return  DatasetInfo (
40-             "cub200" ,
41-             homepage = "http://www.vision.caltech.edu/visipedia/CUB-200-2011.html" ,
42-             dependencies = ("scipy" ,),
43-             valid_options = dict (
44-                 split = ("train" , "test" ),
45-                 year = ("2011" , "2010" ),
46-             ),
NAME = "cub200"

# Category names are read from the pre-generated "<name>.categories" file that
# ships with torchvision; only the first column (the names) is kept.
# NOTE: the scraped diff truncated the f-string here — the ".categories" suffix
# is reconstructed from the sibling prototype datasets' convention.
CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories"))
42+ 
43+ 
@register_info(NAME)
def _info() -> Dict[str, Any]:
    """Static dataset metadata: the list of CUB-200 category names."""
    return {"categories": CATEGORIES}
47+ 
48+ 
49+ @register_dataset (NAME ) 
50+ class  CUB200 (Dataset2 ):
51+     """ 
52+     - **homepage**: http://www.vision.caltech.edu/visipedia/CUB-200.html 
53+     """ 
54+ 
55+     def  __init__ (
56+         self ,
57+         root : Union [str , pathlib .Path ],
58+         * ,
59+         split : str  =  "train" ,
60+         year : str  =  "2011" ,
61+         skip_integrity_check : bool  =  False ,
62+     ) ->  None :
63+         self ._split  =  self ._verify_str_arg (split , "split" , ("train" , "test" ))
64+         self ._year  =  self ._verify_str_arg (year , "year" , ("2010" , "2011" ))
65+ 
66+         self ._categories  =  _info ()["categories" ]
67+ 
68+         super ().__init__ (
69+             root ,
70+             # TODO: this will only be available after https://github.com/pytorch/vision/pull/5473 
71+             # dependencies=("scipy",), 
72+             skip_integrity_check = skip_integrity_check ,
4773        )
4874
49-     def  resources (self ,  config :  DatasetConfig ) ->  List [OnlineResource ]:
50-         if  config . year  ==  "2011" :
75+     def  _resources (self ) ->  List [OnlineResource ]:
76+         if  self . _year  ==  "2011" :
5177            archive  =  HttpResource (
5278                "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz" ,
5379                sha256 = "0c685df5597a8b24909f6a7c9db6d11e008733779a671760afef78feb49bf081" ,
@@ -59,7 +85,7 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
5985                preprocess = "decompress" ,
6086            )
6187            return  [archive , segmentations ]
62-         else :  # config.year  == "2010" 
88+         else :  # self._year  == "2010" 
6389            split  =  HttpResource (
6490                "http://www.vision.caltech.edu/visipedia-data/CUB-200/lists.tgz" ,
6591                sha256 = "aeacbd5e3539ae84ea726e8a266a9a119c18f055cd80f3836d5eb4500b005428" ,
@@ -90,12 +116,12 @@ def _2011_classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
90116        else :
91117            return  None 
92118
93-     def  _2011_filter_split (self , row : List [str ],  * ,  split :  str ) ->  bool :
119+     def  _2011_filter_split (self , row : List [str ]) ->  bool :
94120        _ , split_id  =  row 
95121        return  {
96122            "0" : "test" ,
97123            "1" : "train" ,
98-         }[split_id ] ==  split 
124+         }[split_id ] ==  self . _split 
99125
100126    def  _2011_segmentation_key (self , data : Tuple [str , Any ]) ->  str :
101127        path  =  pathlib .Path (data [0 ])
@@ -149,17 +175,12 @@ def _prepare_sample(
149175        return  dict (
150176            prepare_ann_fn (anns_data , image .image_size ),
151177            image = image ,
152-             label = Label (int (pathlib .Path (path ).parent .name .rsplit ("." , 1 )[0 ]), categories = self .categories ),
178+             label = Label (int (pathlib .Path (path ).parent .name .rsplit ("." , 1 )[0 ]), categories = self ._categories ),
153179        )
154180
155-     def  _make_datapipe (
156-         self ,
157-         resource_dps : List [IterDataPipe ],
158-         * ,
159-         config : DatasetConfig ,
160-     ) ->  IterDataPipe [Dict [str , Any ]]:
181+     def  _datapipe (self , resource_dps : List [IterDataPipe ]) ->  IterDataPipe [Dict [str , Any ]]:
161182        prepare_ann_fn : Callable 
162-         if  config . year  ==  "2011" :
183+         if  self . _year  ==  "2011" :
163184            archive_dp , segmentations_dp  =  resource_dps 
164185            images_dp , split_dp , image_files_dp , bounding_boxes_dp  =  Demultiplexer (
165186                archive_dp , 4 , self ._2011_classify_archive , drop_none = True , buffer_size = INFINITE_BUFFER_SIZE 
@@ -171,7 +192,7 @@ def _make_datapipe(
171192            )
172193
173194            split_dp  =  CSVParser (split_dp , dialect = "cub200" )
174-             split_dp  =  Filter (split_dp , functools . partial ( self ._2011_filter_split ,  split = config . split ) )
195+             split_dp  =  Filter (split_dp , self ._2011_filter_split )
175196            split_dp  =  Mapper (split_dp , getitem (0 ))
176197            split_dp  =  Mapper (split_dp , image_files_map .get )
177198
@@ -188,10 +209,10 @@ def _make_datapipe(
188209            )
189210
190211            prepare_ann_fn  =  self ._2011_prepare_ann 
191-         else :  # config.year  == "2010" 
212+         else :  # self._year  == "2010" 
192213            split_dp , images_dp , anns_dp  =  resource_dps 
193214
194-             split_dp  =  Filter (split_dp , path_comparator ("name" , f"{ config . split }  ))
215+             split_dp  =  Filter (split_dp , path_comparator ("name" , f"{ self . _split }  ))
195216            split_dp  =  LineReader (split_dp , decode = True , return_path = False )
196217            split_dp  =  Mapper (split_dp , self ._2010_split_key )
197218
@@ -217,11 +238,19 @@ def _make_datapipe(
217238        )
218239        return  Mapper (dp , functools .partial (self ._prepare_sample , prepare_ann_fn = prepare_ann_fn ))
219240
220-     def  _generate_categories (self , root : pathlib .Path ) ->  List [str ]:
221-         config  =  self .info .make_config (year = "2011" )
222-         resources  =  self .resources (config )
241+     def  __len__ (self ) ->  int :
242+         return  {
243+             ("train" , "2010" ): 3_000 ,
244+             ("test" , "2010" ): 3_033 ,
245+             ("train" , "2011" ): 5_994 ,
246+             ("test" , "2011" ): 5_794 ,
247+         }[(self ._split , self ._year )]
248+ 
249+     def  _generate_categories (self ) ->  List [str ]:
250+         self ._year  =  "2011" 
251+         resources  =  self ._resources ()
223252
224-         dp  =  resources [0 ].load (root )
253+         dp  =  resources [0 ].load (self . _root )
225254        dp  =  Filter (dp , path_comparator ("name" , "classes.txt" ))
226255        dp  =  CSVDictParser (dp , fieldnames = ("label" , "category" ), dialect = "cub200" )
227256
0 commit comments