|
| 1 | +from torchtext._internal.module_utils import is_module_available |
| 2 | + |
| 3 | +if is_module_available("torchdata"): |
| 4 | + from torchdata.datapipes.iter import FileOpener, GDriveReader, IterableWrapper |
| 5 | + |
1 | 6 | from torchtext.data.datasets_utils import ( |
2 | | - _RawTextIterableDataset, |
3 | 7 | _wrap_split_argument, |
4 | 8 | _add_docstring_header, |
5 | | - _download_extract_validate, |
6 | 9 | _create_dataset_directory, |
7 | | - _create_data_from_csv, |
8 | 10 | ) |
| 11 | + |
9 | 12 | import os |
10 | | -import logging |
11 | 13 |
|
12 | 14 | URL = 'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbZVhsUnRWRDhETzA' |
13 | 15 |
|
|
35 | 37 |
|
@_add_docstring_header(num_lines=NUM_LINES, num_classes=5)
@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "test"))
def AmazonReviewFull(root, split):
    """Yield ``(label, text)`` pairs for the Amazon Review Full dataset.

    The Google-Drive-hosted tar archive is downloaded into ``root`` (validated
    against its MD5 and skipped when already cached on disk), opened in binary
    mode, and the CSV file belonging to the requested ``split`` is extracted
    and parsed lazily. The first CSV column is converted to an ``int`` label;
    the remaining columns are joined with a space into the text field.
    """
    archive_path = os.path.join(root, _PATH)

    # Download the archive once; on_disk_cache short-circuits the GDrive
    # fetch when a file with the expected MD5 is already present.
    url_wrapper = IterableWrapper([URL])
    download_cache = url_wrapper.on_disk_cache(
        filepath_fn=lambda _: archive_path,
        hash_dict={archive_path: MD5},
        hash_type="md5",
    )
    download_cache = GDriveReader(download_cache).end_caching(mode="wb", same_filepath_fn=True)

    # Open the cached tar and keep only the CSV member for this split.
    archive_stream = FileOpener(download_cache, mode="b")
    split_csv = archive_stream.read_from_tar().filter(
        lambda item: _EXTRACTED_FILES[split] in item[0]
    )

    return split_csv.parse_csv().map(fn=lambda row: (int(row[0]), " ".join(row[1:])))
0 commit comments