 import urllib
 import urllib.error
 import urllib.request
+import warnings
 import zipfile
 from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple, Iterator
 from urllib.parse import urlparse
     _is_remote_location_available,
 )
 
-
 USER_AGENT = "pytorch/vision"
 
 
-def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
-    with open(filename, "wb") as fh:
-        with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
-            with tqdm(total=response.length) as pbar:
-                for chunk in iter(lambda: response.read(chunk_size), ""):
-                    if not chunk:
-                        break
-                    pbar.update(chunk_size)
-                    fh.write(chunk)
+def _save_response_content(
+    content: Iterator[bytes],
+    destination: str,
+    length: Optional[int] = None,
+) -> None:
+    with open(destination, "wb") as fh, tqdm(total=length) as pbar:
+        for chunk in content:
+            # filter out keep-alive new chunks
+            if not chunk:
+                continue
+
+            fh.write(chunk)
+            pbar.update(len(chunk))
+
+
+def _urlretrieve(url: str, filename: str, chunk_size: int = 1024 * 32) -> None:
+    with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
+        _save_response_content(iter(lambda: response.read(chunk_size), b""), filename, length=response.length)
 
 
 def gen_bar_updater() -> Callable[[int, int, int], None]:
+    warnings.warn("The function `gen_bar_updater` is deprecated since 0.13 and will be removed in 0.15.")
     pbar = tqdm(total=None)
 
     def bar_update(count, block_size, total_size):
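Both the plain-HTTP path (`_urlretrieve`) and the Google Drive path now funnel their chunks through the shared `_save_response_content`. The new `iter(lambda: response.read(chunk_size), b"")` also fixes a subtle sentinel bug: the old loop compared `bytes` chunks against the `str` sentinel `""`, which never matches, so termination relied entirely on the `if not chunk: break` guard, and `pbar.update(chunk_size)` over-counted the final partial chunk. A minimal sketch of the two-argument `iter(callable, sentinel)` idiom, with a placeholder URL and chunk size that are not part of the patch:

```python
# Sketch of the iter(callable, sentinel) streaming pattern, assuming any
# small public URL; 8192 is an arbitrary chunk size, not the patch's value.
import urllib.request

url = "https://example.com/"  # hypothetical target
with urllib.request.urlopen(url) as response, open("page.html", "wb") as fh:
    # iter() keeps calling response.read(8192) and stops once it returns the
    # sentinel b"" (EOF). A str sentinel "" would never equal a bytes chunk.
    for chunk in iter(lambda: response.read(8192), b""):
        fh.write(chunk)
```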
@@ -184,11 +194,20 @@ def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
     return files
 
 
-def _quota_exceeded(first_chunk: bytes) -> bool:
+def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple[Optional[str], Iterator[bytes]]:
+    content = response.iter_content(chunk_size)
+    first_chunk = None
+    # filter out keep-alive new chunks
+    while not first_chunk:
+        first_chunk = next(content)
+    content = itertools.chain([first_chunk], content)
+
     try:
-        return "Google Drive - Quota exceeded" in first_chunk.decode()
+        match = re.search("<title>Google Drive - (?P<api_response>.+?)</title>", first_chunk.decode())
+        api_response = match["api_response"] if match is not None else None
     except UnicodeDecodeError:
-        return False
+        api_response = None
+    return api_response, content
 
 
 def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
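`_extract_gdrive_api_response` has to consume the first chunk of the response to sniff Google Drive's HTML `<title>` interstitial, then re-attaches it with `itertools.chain` so callers still stream the complete payload. A self-contained sketch of that peek-and-reattach idiom (the `peek` helper and sample data are illustrative, not part of the patch):

```python
# Peek at the first item of an iterator without losing it, as
# _extract_gdrive_api_response does with the response body.
import itertools
from typing import Iterator, Tuple

def peek(chunks: Iterator[bytes]) -> Tuple[bytes, Iterator[bytes]]:
    first = next(chunks)                     # consume one chunk to inspect it
    rest = itertools.chain([first], chunks)  # glue it back onto the front
    return first, rest

stream = iter([b"<title>Google Drive - Quota exceeded</title>", b"..."])
first, stream = peek(stream)
assert b"Quota exceeded" in first
assert next(stream) == first  # the peeked chunk is still delivered
```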
@@ -202,70 +221,41 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
     """
     # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
 
-    url = "https://docs.google.com/uc?export=download"
-
     root = os.path.expanduser(root)
     if not filename:
         filename = file_id
     fpath = os.path.join(root, filename)
 
     os.makedirs(root, exist_ok=True)
 
-    if os.path.isfile(fpath) and check_integrity(fpath, md5):
-        print("Using downloaded and verified file: " + fpath)
-    else:
-        session = requests.Session()
-
-        response = session.get(url, params={"id": file_id}, stream=True)
-        token = _get_confirm_token(response)
-
-        if token:
-            params = {"id": file_id, "confirm": token}
-            response = session.get(url, params=params, stream=True)
-
-        # Ideally, one would use response.status_code to check for quota limits, but google drive is not consistent
-        # with their own API, refer https://github.com/pytorch/vision/issues/2992#issuecomment-730614517.
-        # Should this be fixed at some place in future, one could refactor the following to no longer rely on decoding
-        # the first_chunk of the payload
-        response_content_generator = response.iter_content(32768)
-        first_chunk = None
-        while not first_chunk:  # filter out keep-alive new chunks
-            first_chunk = next(response_content_generator)
-
-        if _quota_exceeded(first_chunk):
-            msg = (
-                f"The daily quota of the file {filename} is exceeded and it "
-                f"can't be downloaded. This is a limitation of Google Drive "
-                f"and can only be overcome by trying again later."
-            )
-            raise RuntimeError(msg)
-
-        _save_response_content(itertools.chain((first_chunk,), response_content_generator), fpath)
-        response.close()
+    if check_integrity(fpath, md5):
+        print(f"Using downloaded {'and verified ' if md5 else ''}file: {fpath}")
+        return
 
+    url = "https://drive.google.com/uc"
+    params = dict(id=file_id, export="download")
+    with requests.Session() as session:
+        response = session.get(url, params=params, stream=True)
 
-def _get_confirm_token(response: requests.models.Response) -> Optional[str]:
-    for key, value in response.cookies.items():
-        if key.startswith("download_warning"):
-            return value
+        for key, value in response.cookies.items():
+            if key.startswith("download_warning"):
+                token = value
+                break
+        else:
+            api_response, content = _extract_gdrive_api_response(response)
+            token = "t" if api_response == "Virus scan warning" else None
 
-    return None
+        if token is not None:
+            response = session.get(url, params=dict(params, confirm=token), stream=True)
+            api_response, content = _extract_gdrive_api_response(response)
 
+        if api_response == "Quota exceeded":
+            raise RuntimeError(
+                f"The daily quota of the file {filename} is exceeded and it "
+                f"can't be downloaded. This is a limitation of Google Drive "
+                f"and can only be overcome by trying again later."
+            )
 
-def _save_response_content(
-    response_gen: Iterator[bytes],
-    destination: str,
-) -> None:
-    with open(destination, "wb") as f:
-        pbar = tqdm(total=None)
-        progress = 0
-
-        for chunk in response_gen:
-            if chunk:  # filter out keep-alive new chunks
-                f.write(chunk)
-                progress += len(chunk)
-                pbar.update(progress - pbar.n)
-        pbar.close()
+        _save_response_content(content, fpath)
 
 
 def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:
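The new control flow scans cookies with a `for ... else`: the `else` branch runs only when no `download_warning` cookie is found, falling back to parsing the HTML title for the "Virus scan warning" page. A hedged usage sketch of the refactored downloader; the file id, paths, and checksum below are placeholders, not real assets:

```python
# Hypothetical call into the refactored downloader.
from torchvision.datasets.utils import download_file_from_google_drive

download_file_from_google_drive(
    file_id="0B_EXAMPLE_FILE_ID",  # placeholder Drive file id
    root="~/.cache/example",       # expanded via os.path.expanduser
    filename="archive.zip",        # defaults to the file id when omitted
    md5="d41d8cd98f00b204e9800998ecf8427e",  # placeholder; enables the early-exit cache check
)
```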