Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.
84 changes: 68 additions & 16 deletions torchtext/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import six
import requests
import csv
import shutil
import os
from tqdm import tqdm


Expand All @@ -24,37 +26,87 @@ def inner(b=1, bsize=1, tsize=None):
return inner


def download_from_url(url, path):
def download_from_url(url, destination):
"""Download file, with logic (from tensor2tensor) for Google Drive"""
def process_response(r):
chunk_size = 16 * 1024
total_size = int(r.headers.get('Content-length', 0))
with open(path, "wb") as file:
with tqdm(total=total_size, unit='B',
unit_scale=1, desc=path.split('/')[-1]) as t:
for chunk in r.iter_content(chunk_size):
if chunk:
file.write(chunk)
t.update(len(chunk))
def process_response(r, first_byte):

# Check if the requested url is ok, i.e. 200 <= status_code < 400
head = requests.head(url)
if not head.ok:
head.raise_for_status()

# Since requests doesn't support local file reading
# we check if protocol is file://
if url.startswith('file://'):
url_no_protocol = url.replace('file://', '', count=1)
if os.path.exists(url_no_protocol):
print('File already exists, no need to download')
return
else:
raise Exception('File not found at %s' % url_no_protocol)

# Don't download if the file exists
if os.path.exists(os.path.expanduser(destination)):
print('File already exists, no need to download')
return

tmp_file = destination + '.part'
first_byte = os.path.getsize(tmp_file) if os.path.exists(tmp_file) else 0
chunk_size = 1024 ** 2 # 1 MB
file_mode = 'ab' if first_byte else 'wb'

# Set headers to resume download from where we've left
headers = {"Range": "bytes=%s-" % first_byte}
r = requests.get(url, headers=headers, stream=True)
file_size = int(r.headers.get('Content-length', -1))
if file_size >= 0:
# Content-length set
file_size += first_byte
total = file_size
else:
# Content-length not set
print('Cannot retrieve Content-length from server')
total = None

print('Download from ' + url)
print('Starting download at %.1fMB' % (first_byte / (10 ** 6)))
print('File size is %.1fMB' % (file_size / (10 ** 6)))

with tqdm(initial=first_byte, total=total, unit_scale=True) as pbar:
with open(tmp_file, file_mode) as f:
for chunk in r.iter_content(chunk_size=chunk_size):
if chunk: # filter out keep-alive new chunks
f.write(chunk)
pbar.update(len(chunk))

# Rename the temp download file to the correct name if fully downloaded
shutil.move(tmp_file, destination)

tmp_file_path = destination + '.part'
first_byte = os.path.getsize(tmp_file_path) if os.path.exists(tmp_file_path) else 0

# Set headers: this will tell the server to start download from the specified byte
headers = {"Range": "bytes=%s-" % first_byte}

if 'drive.google.com' not in url:
response = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True)
process_response(response)
headers.update({'User-Agent': 'Mozilla/5.0'})
response = requests.get(url, headers=headers, stream=True)
process_response(response, first_byte)
return

print('downloading from Google Drive; may take a few minutes')
confirm_token = None
session = requests.Session()
response = session.get(url, stream=True)
response = session.get(url, headers=headers, stream=True)
for k, v in response.cookies.items():
if k.startswith("download_warning"):
confirm_token = v

if confirm_token:
url = url + "&confirm=" + confirm_token
response = session.get(url, stream=True)
response = session.get(url, headers=headers, stream=True)

process_response(response)
process_response(response, first_byte)


def unicode_csv_reader(unicode_csv_data, **kwargs):
Expand Down