36 changes: 36 additions & 0 deletions torchtext/data/data_pipes.py
@@ -0,0 +1,36 @@

import csv
import json

from torch.utils.data import IterDataPipe, functional_datapipe
from torchtext._download_hooks import _get_response_from_google_drive, _stream_response


@functional_datapipe('parse_json_files')
class JSONParserIterDataPipe(IterDataPipe):
    def __init__(self, source_datapipe):
        self.source_datapipe = source_datapipe

    def __iter__(self):
        for _, stream in self.source_datapipe:
            raw_json_data = json.load(stream)['data']
            for layer1 in raw_json_data:
                for layer2 in layer1['paragraphs']:
                    for layer3 in layer2['qas']:
                        _context, _question = layer2['context'], layer3['question']
                        _answers = [item['text'] for item in layer3['answers']]
                        _answer_start = [item['answer_start'] for item in layer3['answers']]
                        if len(_answers) == 0:
                            _answers = [""]
                            _answer_start = [-1]
                        yield (_context, _question, _answers, _answer_start)


class GDriveReaderDataPipe(IterDataPipe):
    def __init__(self, source_datapipe):
        self.source_datapipe = source_datapipe

    def __iter__(self):
        for url in self.source_datapipe:
            response, filename = _get_response_from_google_drive(url)
            yield (filename, response.raw)
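
For orientation, a minimal sketch of how the new functional datapipe could be exercised on a SQuAD-format JSON file already on disk. The file path is a placeholder and the wiring through LoadFilesFromDisk is an assumption inferred from the dataset code below, not something this file ships:

# Hypothetical usage sketch (not part of this diff). LoadFilesFromDisk
# yields (filename, stream) pairs, the exact shape __iter__ consumes.
import torchtext.data.data_pipes  # noqa: F401 -- importing registers parse_json_files
from datapipes.iter import IterableAsDataPipe
from torch.utils.data.datapipes.iter import LoadFilesFromDisk

json_dp = LoadFilesFromDisk(IterableAsDataPipe(['train-v1.1.json'])).parse_json_files()
for context, question, answers, answer_starts in json_dp:
    print(question, answers[0], answer_starts[0])
    break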
4 changes: 3 additions & 1 deletion torchtext/datasets/__init__.py
@@ -19,6 +19,7 @@
from .yelpreviewfull import YelpReviewFull
from .yelpreviewpolarity import YelpReviewPolarity
from .multi30k import Multi30k
from .sst2 import SST2

DATASETS = {
    'AG_NEWS': AG_NEWS,
@@ -40,7 +41,8 @@
    'YahooAnswers': YahooAnswers,
    'YelpReviewFull': YelpReviewFull,
    'YelpReviewPolarity': YelpReviewPolarity,
    'Multi30k': Multi30k
    'Multi30k': Multi30k,
    'SST2': SST2
}

URLS = {}
19 changes: 8 additions & 11 deletions torchtext/datasets/ag_news.py
@@ -1,13 +1,14 @@
from torchtext.utils import (
    download_from_url,
)
from torchtext.data.datasets_utils import (
    _RawTextIterableDataset,
    _wrap_split_argument,
    _add_docstring_header,
    _create_dataset_directory,
    _create_data_from_csv,
)

from datapipes.iter import (
    CSVParser,
    HttpReader
)

import os

URL = {
@@ -32,9 +33,5 @@
@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(('train', 'test'))
def AG_NEWS(root, split):
    path = download_from_url(URL[split], root=root,
                             path=os.path.join(root, split + ".csv"),
                             hash_value=MD5[split],
                             hash_type='md5')
    return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split],
                                   _create_data_from_csv(path))
    # TODO Caching mechanism
    return HttpReader([URL[split]]).parse_csv_files().map(lambda t: (int(t[1]), ' '.join(t[2:])))
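
As a quick smoke test, the rewritten dataset behaves like any other iterable datapipe; a sketch, assuming network access to the AG_NEWS CSVs:

# Hypothetical usage (not part of this diff). Each element is a
# (label, text) tuple produced by the final .map(...) above.
from torchtext.datasets import AG_NEWS

train_dp = AG_NEWS(split='train')
label, text = next(iter(train_dp))
print(label, text[:80])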
10 changes: 5 additions & 5 deletions torchtext/datasets/amazonreviewfull.py
@@ -8,7 +8,9 @@
)
import os
import logging
from torchtext.data.data_pipes import GDriveReaderDataPipe as GDriveReader

from torch.utils.data.datapipes.iter import LoadFilesFromDisk
URL = 'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbZVhsUnRWRDhETzA'

MD5 = '57d28bd5d930e772930baddf36641c7c'
@@ -37,8 +39,6 @@
@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(('train', 'test'))
def AmazonReviewFull(root, split):
    path = _download_extract_validate(root, URL, MD5, os.path.join(root, _PATH), os.path.join(root, _EXTRACTED_FILES[split]),
                                      _EXTRACTED_FILES_MD5[split], hash_type="md5")
    logging.info('Creating {} data'.format(split))
    return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split],
                                   _create_data_from_csv(path))
    saver_dp = GDriveReader([URL]).map(lambda x: (x[0], x[1].read())).save_to_disk(filepath_fn=lambda x: os.path.join(root, x))
    extracted_files = LoadFilesFromDisk(saver_dp).read_from_tar()
    return extracted_files.filter(lambda x: split in x[0]).parse_csv_files().map(lambda t: (int(t[1]), ' '.join(t[2:])))
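
The download-then-extract shape here (GDriveReader -> save_to_disk -> LoadFilesFromDisk -> read_from_tar) is the template every archive-backed dataset in this PR follows; a condensed sketch with a placeholder URL and root, for reference:

# Generic form of the pipeline above (illustrative; url/root are placeholders).
import os
from torchtext.data.data_pipes import GDriveReaderDataPipe as GDriveReader
from torch.utils.data.datapipes.iter import LoadFilesFromDisk

root = '/tmp/data'
url = 'https://drive.google.com/uc?export=download&id=<file-id>'

# 1. Stream the archive from Google Drive and persist the raw bytes.
saver_dp = GDriveReader([url]).map(lambda x: (x[0], x[1].read())) \
                              .save_to_disk(filepath_fn=lambda x: os.path.join(root, x))
# 2. Re-open the saved archive; each member comes out as a (name, stream) pair.
members = LoadFilesFromDisk(saver_dp).read_from_tar()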
22 changes: 14 additions & 8 deletions torchtext/datasets/amazonreviewpolarity.py
@@ -1,13 +1,20 @@
from torchtext.data.datasets_utils import (
    _RawTextIterableDataset,
    _wrap_split_argument,
    _add_docstring_header,
    _download_extract_validate,
    _create_dataset_directory,
    _create_data_from_csv,
)
import os
import logging
from datapipes.iter import (
    CSVParser,
    ReadFilesFromTar,
    HttpReader,
)

from torchtext.data.data_pipes import GDriveReaderDataPipe as GDriveReader

from torch.utils.data.datapipes.iter import LoadFilesFromDisk


URL = 'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbaW12WVVZS2drcnM'

@@ -37,8 +44,7 @@
@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(('train', 'test'))
def AmazonReviewPolarity(root, split):
    path = _download_extract_validate(root, URL, MD5, os.path.join(root, _PATH), os.path.join(root, _EXTRACTED_FILES[split]),
                                      _EXTRACTED_FILES_MD5[split], hash_type="md5")
    logging.info('Creating {} data'.format(split))
    return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split],
                                   _create_data_from_csv(path))
    saver_dp = GDriveReader([URL]).map(lambda x: (x[0], x[1].read())).save_to_disk(filepath_fn=lambda x: os.path.join(root, x))
    extracted_files = LoadFilesFromDisk(saver_dp).read_from_tar()
    return extracted_files.filter(lambda x: split in x[0]).parse_csv_files().map(lambda t: (int(t[1]), ' '.join(t[2:])))

29 changes: 14 additions & 15 deletions torchtext/datasets/imdb.py
@@ -1,11 +1,18 @@
from torchtext.utils import download_from_url, extract_archive
from torchtext.data.datasets_utils import _RawTextIterableDataset
from torchtext.utils import download_from_url
from torchtext.data.datasets_utils import _wrap_split_argument
from torchtext.data.datasets_utils import _add_docstring_header
from torchtext.data.datasets_utils import _create_dataset_directory
import io
import os
from pathlib import Path

from datapipes.iter import (
    IterableAsDataPipe,
    ReadFilesFromTar,
    HttpReader,
    Saver,
)

from torch.utils.data.datapipes.iter import LoadFilesFromDisk
URL = 'http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz'

MD5 = '7c2ac02c03563afcf9b574c7e56c153a'
@@ -24,15 +31,7 @@
@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(('train', 'test'))
def IMDB(root, split):
    def generate_imdb_data(key, extracted_files):
        for fname in extracted_files:
            *_, split, label, file = Path(fname).parts

            if key == split and (label in ['pos', 'neg']):
                with io.open(fname, encoding="utf8") as f:
                    yield label, f.read()
    dataset_tar = download_from_url(URL, root=root,
                                    hash_value=MD5, hash_type='md5')
    extracted_files = extract_archive(dataset_tar)
    iterator = generate_imdb_data(split, extracted_files)
    return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split], iterator)
    saver_dp = HttpReader([URL]).map(lambda x: (x[0], x[1].read())).save_to_disk(filepath_fn=lambda x: os.path.join(root, os.path.basename(x)))
    extracted_files = LoadFilesFromDisk(saver_dp).read_from_tar()
    return extracted_files.filter(lambda x: Path(x[0]).parts[-3] == split and Path(x[0]).parts[-2]
                                  in ['pos', 'neg']).map(lambda x: (Path(x[0]).parts[-2], x[1].read().decode('utf-8')))
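
The filter leans on the aclImdb archive layout, where every member path ends in <split>/<label>/<file>; a tiny illustration with an invented path:

# How the Path-based predicate above classifies a tar member
# (example path invented for illustration).
from pathlib import Path

parts = Path('aclImdb/train/pos/0_9.txt').parts
print(parts[-3], parts[-2])  # 'train' 'pos' -> kept when split == 'train'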
15 changes: 9 additions & 6 deletions torchtext/datasets/squad1.py
@@ -1,11 +1,15 @@
from torchtext.utils import download_from_url
from torchtext.data.datasets_utils import (
    _RawTextIterableDataset,
    _wrap_split_argument,
    _add_docstring_header,
    _create_dataset_directory,
    _create_data_from_json,
)
import os
from torchtext.data.data_pipes import JSONParserIterDataPipe
from datapipes.iter import (
    HttpReader,
)
from torch.utils.data.datapipes.iter import LoadFilesFromDisk

URL = {
'train': "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json",
'dev': "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json",
@@ -29,6 +33,5 @@
@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(('train', 'dev'))
def SQuAD1(root, split):
    extracted_files = download_from_url(URL[split], root=root, hash_value=MD5[split], hash_type='md5')
    return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split],
                                   _create_data_from_json(extracted_files))
    saver_dp = HttpReader([URL[split]]).map(lambda x: (x[0], x[1].read())).save_to_disk(filepath_fn=lambda x: os.path.join(root, os.path.basename(x)))
    return LoadFilesFromDisk(saver_dp).parse_json_files()
14 changes: 8 additions & 6 deletions torchtext/datasets/squad2.py
@@ -1,11 +1,14 @@
from torchtext.utils import download_from_url
from torchtext.data.datasets_utils import (
    _RawTextIterableDataset,
    _wrap_split_argument,
    _add_docstring_header,
    _create_dataset_directory,
    _create_data_from_json,
)
import os
from datapipes.iter import (
    HttpReader,
)
from torchtext.data.data_pipes import JSONParserIterDataPipe
from torch.utils.data.datapipes.iter import LoadFilesFromDisk
URL = {
'train': "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json",
'dev': "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json",
@@ -29,6 +32,5 @@
@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(('train', 'dev'))
def SQuAD2(root, split):
    extracted_files = download_from_url(URL[split], root=root, hash_value=MD5[split], hash_type='md5')
    return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split],
                                   _create_data_from_json(extracted_files))
    saver_dp = HttpReader([URL[split]]).map(lambda x: (x[0], x[1].read())).save_to_disk(filepath_fn=lambda x: os.path.join(root, os.path.basename(x)))
    return LoadFilesFromDisk(saver_dp).parse_json_files()
29 changes: 29 additions & 0 deletions torchtext/datasets/sst2.py
@@ -0,0 +1,29 @@
from torchtext.data.datasets_utils import (
    _create_dataset_directory,
    _wrap_split_argument,
)
import os
from datapipes.iter import (
    CSVParser,
    ReadFilesFromZip,
    HttpReader,
    IterableAsDataPipe,
)


from torch.utils.data.datapipes.iter import LoadFilesFromDisk
NUM_LINES = {}
MD5 = {}
URL = "https://dl.fbaipublicfiles.com/glue/data/SST-2.zip"

DATASET_NAME = "SST2"


@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(('train', 'dev', 'test'))
def SST2(root, split):
    # TODO: not working: cache_dp = IterableAsDataPipe([URL]).on_disk_cache(HttpReader, filepath_fn=lambda x: os.path.join(root, os.path.basename(x)))
    saver_dp = HttpReader([URL]).map(lambda x: (x[0], x[1].read())).save_to_disk(filepath_fn=lambda x: os.path.join(root, os.path.basename(x)))
    extracted_files = LoadFilesFromDisk(saver_dp).read_from_zip()
    return extracted_files.filter(lambda x: split in x[0]).parse_csv_files(skip_header=True, delimiter='\t').map(lambda x: (x[1], int(x[2])))
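
End to end, the new pipeline can be smoke-tested like any torchtext dataset; a sketch, assuming network access and a writable default root:

# Hypothetical check (not part of this diff). Each element is a
# (sentence, label) pair per the final .map(lambda x: (x[1], int(x[2]))).
from torchtext.datasets import SST2

dev_dp = SST2(split='dev')
sentence, label = next(iter(dev_dp))
print(label, sentence[:60])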