17 changes: 1 addition & 16 deletions test/data/test_builtin_datasets.py
@@ -38,20 +38,6 @@ def test_raw_ag_news(self):
        self._helper_test_func(len(test_iter), 7600, next(test_iter)[1][:25], 'Fears for T N pension aft')
        del train_iter, test_iter

-    @parameterized.expand(
Contributor:
    A n00b question: why did we remove this, and how does this change relate to the migration of AmazonReviewPolarity to datapipes?

Contributor (@erip, Jan 10, 2022):
    The __str__ method under test here was coming from _RawTextIterableDataset, so it's somewhat overcome by events.

-        load_params('raw_datasets.jsonl'),
-        name_func=_raw_text_custom_name_func)
-    def test_raw_text_name_property(self, info):
-        dataset_name = info['dataset_name']
-        split = info['split']
-
-        if dataset_name == 'WMT14':
-            return
-        else:
-            data_iter = torchtext.datasets.DATASETS[dataset_name](split=split)
-
-        self.assertEqual(str(data_iter), dataset_name)

    @parameterized.expand(
        load_params('raw_datasets.jsonl'),
        name_func=_raw_text_custom_name_func)
@@ -63,8 +49,7 @@ def test_raw_text_classification(self, info):
            return
        else:
            data_iter = torchtext.datasets.DATASETS[dataset_name](split=split)
-            self.assertEqual(len(data_iter), info['NUM_LINES'])
-            self.assertEqual(hashlib.md5(json.dumps(next(data_iter), sort_keys=True).encode('utf-8')).hexdigest(), info['first_line'])
+            self.assertEqual(hashlib.md5(json.dumps(next(iter(data_iter)), sort_keys=True).encode('utf-8')).hexdigest(), info['first_line'])
        if dataset_name == "AG_NEWS":
            self.assertEqual(torchtext.datasets.URLS[dataset_name][split], info['URL'])
            self.assertEqual(torchtext.datasets.MD5[dataset_name][split], info['MD5'])
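
The switch from next(data_iter) to next(iter(data_iter)) above is the crux of the migration for this test: a datapipe is an iterable rather than an iterator, so calling next() on it directly raises a TypeError, whereas the old _RawTextIterableDataset was its own iterator. A minimal sketch of the distinction, using a toy iterable in place of a real datapipe:

# Toy stand-in for an IterDataPipe: it implements __iter__ but not
# __next__, so next() must be applied to iter(pipe), not to the pipe.
class ToyDataPipe:
    def __iter__(self):
        yield (3, "Fears for T N pension after talks")
        yield (4, "Second Private Team Sets Launch Date")

pipe = ToyDataPipe()
try:
    next(pipe)  # fails: the pipe is iterable, not an iterator
except TypeError as err:
    print(err)
print(next(iter(pipe)))  # works: (3, 'Fears for T N pension after talks')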
14 changes: 6 additions & 8 deletions torchtext/data/datasets_utils.py
@@ -209,8 +209,7 @@ def _wrap_split_argument_with_fn(fn, splits):
            argspec.args[1] == "split" and
            argspec.varargs is None and
            argspec.varkw is None and
-            len(argspec.kwonlyargs) == 0 and
-            len(argspec.annotations) == 0
+            len(argspec.kwonlyargs) == 0
            ):
        raise ValueError("Internal Error: Given function {} did not adhere to standard signature.".format(fn))

@@ -240,23 +239,22 @@ def new_fn(fn):


def _create_dataset_directory(dataset_name):
-    def decorator(func):
-        argspec = inspect.getfullargspec(func)
+    def decorator(fn):
+        argspec = inspect.getfullargspec(fn)
        if not (argspec.args[0] == "root" and
                argspec.args[1] == "split" and
                argspec.varargs is None and
                argspec.varkw is None and
-                len(argspec.kwonlyargs) == 0 and
-                len(argspec.annotations) == 0
+                len(argspec.kwonlyargs) == 0
                ):
            raise ValueError("Internal Error: Given function {} did not adhere to standard signature.".format(fn))

-        @functools.wraps(func)
+        @functools.wraps(fn)
        def wrapper(root=_CACHE_DIR, *args, **kwargs):
            new_root = os.path.join(root, dataset_name)
            if not os.path.exists(new_root):
                os.makedirs(new_root)
-            return func(root=new_root, *args, **kwargs)
+            return fn(root=new_root, *args, **kwargs)

        return wrapper
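
The dropped len(argspec.annotations) == 0 clause connects directly to the new dataset signature below: AmazonReviewPolarity now annotates its parameters, which the old check would have rejected. A minimal sketch of what getfullargspec reports (the function name here is illustrative):

import inspect

def annotated_dataset(root: str, split: str):
    pass

argspec = inspect.getfullargspec(annotated_dataset)
print(argspec.annotations)  # {'root': <class 'str'>, 'split': <class 'str'>}
# The old guard `len(argspec.annotations) == 0` would raise
# "Internal Error: ... did not adhere to standard signature." for any
# annotated dataset function; the relaxed check accepts it.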

36 changes: 21 additions & 15 deletions torchtext/datasets/amazonreviewpolarity.py
@@ -1,13 +1,15 @@
from torchtext._internal.module_utils import is_module_available
+from typing import Union, Tuple
+if is_module_available("torchdata"):
+    from torchdata.datapipes.iter import FileOpener, GDriveReader, IterableWrapper

Contributor:
    I wonder if we also need to import HttpReader so that we can use this dataset internally within FB by swapping out public URLs with internal URLs. This is what I'm doing for the SST2 dataset implementation (https://github.com/Nayef211/text/blob/main/torchtext/experimental/datasets/sst2.py#L15-L17).

Contributor Author:
    That's a good catch @Nayef211. Note that in this case we need GDriveReader. Since this is not yet implemented internally, let's fix this part once we import this PR into fbcode.
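
A hypothetical sketch of the reader swap discussed above; the selection logic and helper name are illustrative assumptions, not code from this PR:

from torchdata.datapipes.iter import GDriveReader, HttpReader, IterableWrapper

def _open_remote(url):
    # Google Drive links need GDriveReader to handle the download redirect;
    # an internal HTTP(S) mirror could go through HttpReader instead.
    url_dp = IterableWrapper([url])
    if "drive.google.com" in url:
        return GDriveReader(url_dp)
    return HttpReader(url_dp)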

from torchtext.data.datasets_utils import (
-    _RawTextIterableDataset,
    _wrap_split_argument,
    _add_docstring_header,
-    _download_extract_validate,
    _create_dataset_directory,
-    _create_data_from_csv,
)

import os
-import logging

URL = 'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbaW12WVVZS2drcnM'

@@ -25,20 +27,24 @@
    'test': f'{os.sep}'.join(['amazon_review_polarity_csv', 'test.csv']),
}

-_EXTRACTED_FILES_MD5 = {
-    'train': "520937107c39a2d1d1f66cd410e9ed9e",
-    'test': "f4c8bded2ecbde5f996b675db6228f16"
-}

DATASET_NAME = "AmazonReviewPolarity"


@_add_docstring_header(num_lines=NUM_LINES, num_classes=2)
@_create_dataset_directory(dataset_name=DATASET_NAME)
-@_wrap_split_argument(('train', 'test'))
-def AmazonReviewPolarity(root, split):
-    path = _download_extract_validate(root, URL, MD5, os.path.join(root, _PATH), os.path.join(root, _EXTRACTED_FILES[split]),
-                                      _EXTRACTED_FILES_MD5[split], hash_type="md5")
-    logging.info('Creating {} data'.format(split))
-    return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split],
-                                   _create_data_from_csv(path))
+@_wrap_split_argument(("train", "test"))
+def AmazonReviewPolarity(root: str, split: Union[Tuple[str], str]):
+    # TODO Remove this after removing conditional dependency
+    if not is_module_available("torchdata"):
+        raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`")
Contributor:
    Codecov shows missing coverage here. Is there any way we could mock is_module_available and test that the exception is thrown?

Contributor:
    I think once there's a stable release of torchdata published, this check will disappear. Is it still worth the test since it'll be removed soon anyway?

Contributor Author:
    I agree, this is temporary. We can probably skip adding a test.

Contributor:
    Sounds reasonable to me.
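
For the record, a hedged sketch of the test the thread decided to skip: patch is_module_available in the namespace the function reads at call time, then assert the error. The patch target path and test names are assumptions:

import unittest
from unittest import mock

import torchtext.datasets

class TestConditionalTorchdataDependency(unittest.TestCase):
    def test_missing_torchdata_raises(self):
        # Patching the module-level name makes the guard inside
        # AmazonReviewPolarity see torchdata as unavailable.
        with mock.patch(
            "torchtext.datasets.amazonreviewpolarity.is_module_available",
            return_value=False,
        ):
            with self.assertRaises(ModuleNotFoundError):
                torchtext.datasets.AmazonReviewPolarity(split="test")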


+    url_dp = IterableWrapper([URL])
+    cache_dp = url_dp.on_disk_cache(
+        filepath_fn=lambda x: os.path.join(root, _PATH), hash_dict={os.path.join(root, _PATH): MD5}, hash_type="md5"
+    )
+    cache_dp = GDriveReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)
+    cache_dp = FileOpener(cache_dp, mode="b")
+    extracted_files = cache_dp.read_from_tar()
+    filter_extracted_files = extracted_files.filter(lambda x: split in x[0])
Contributor:
    @parmeet I didn't catch this earlier, but I strongly recommend updating this logic to include the extension of the split (i.e. .csv or .tsv). This caused bugs in the SST2 implementation in the past. Look at #1444 for reference.

Contributor:
    An alternative would be to do lambda x: _EXTRACTED_FILES[split] in x[0], which is what @erip is doing in #1499.

Contributor Author:
    Thanks @Nayef211 for catching this. Yes, I remember this breakage earlier, and I completely forgot to take care of it. Will do a PR to fix it.

+    return filter_extracted_files.parse_csv().map(fn=lambda t: (int(t[0]), ' '.join(t[1:])))
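
A small sketch of the substring pitfall flagged in the thread above, with illustrative tar member names; the latest.csv member is hypothetical:

import os

_EXTRACTED_FILES = {
    'train': os.path.join('amazon_review_polarity_csv', 'train.csv'),
    'test': os.path.join('amazon_review_polarity_csv', 'test.csv'),
}

members = [
    'amazon_review_polarity_csv/train.csv',
    'amazon_review_polarity_csv/test.csv',
    'amazon_review_polarity_csv/latest.csv',  # hypothetical extra member
]

split = 'test'
# Current PR logic: bare substring match on the split name.
loose = [m for m in members if split in m]
# Suggested fix: match on the expected extracted file path instead.
# (os.path.join yields '/' on POSIX, matching tar member names.)
strict = [m for m in members if _EXTRACTED_FILES[split] in m]

print(loose)   # both test.csv and latest.csv ('test' is a substring of 'latest')
print(strict)  # ['amazon_review_polarity_csv/test.csv']

For reference, once that filter is tightened, the migrated dataset is consumed like any iterable: next(iter(AmazonReviewPolarity(split='test'))) yields a (label, text) tuple, which is exactly what the updated test asserts against.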