diff --git a/torchtext/data/datasets_utils.py b/torchtext/data/datasets_utils.py
index d6daf3a38f..fa363150ac 100644
--- a/torchtext/data/datasets_utils.py
+++ b/torchtext/data/datasets_utils.py
@@ -11,6 +11,7 @@
     unicode_csv_reader,
 )
 from torch.utils.data import functional_datapipe, IterDataPipe
+from torch.utils.data.datapipes.utils.common import StreamWrapper
 import codecs
 try:
     import defusedxml.ElementTree as ET
@@ -33,6 +34,25 @@ def _clean_xml_file(f_xml):
                 fd_txt.write(e.text.strip() + '\n')
 
 
+def _clean_inner_xml_file(outfile, stream):
+    """Accepts an output filename and a stream of the byte contents of an XML file
+    and writes the cleaned contents to a new file on disk.
+
+    Args:
+        outfile: the path to which the modified stream should be written
+        stream: the byte datapipe of the contents of the XML file
+
+    Returns: the path to the newly-written file and the new StreamWrapper for appropriate caching
+    """
+    os.makedirs(os.path.dirname(outfile), exist_ok=True)
+    with codecs.open(outfile, mode='w', encoding='utf-8') as fd_txt:
+        root = ET.fromstring(stream.read().decode("utf-8"))[0]
+        for doc in root.findall('doc'):
+            for e in doc.findall('seg'):
+                fd_txt.write(e.text.strip() + '\n')
+    return outfile, StreamWrapper(open(outfile, "rb"))
+
+
 def _clean_tags_file(f_orig):
     xml_tags = [

[... portion of the diff not captured in this excerpt; it resumes inside torchtext/datasets/iwslt2016.py ...]

         >>> from torchtext.datasets import IWSLT2016
         >>> train_iter, valid_iter, test_iter = IWSLT2016()
-        >>> src_sentence, tgt_sentence = next(train_iter)
+        >>> src_sentence, tgt_sentence = next(iter(train_iter))
     """
-    num_lines_set_identifier = {
-        'train': 'train',
-        'valid': valid_set,
-        'test': test_set
-    }
+    if not is_module_available("torchdata"):
+        raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`")
 
     if not isinstance(language_pair, list) and not isinstance(language_pair, tuple):
         raise ValueError("language_pair must be list or tuple but got {} instead".format(type(language_pair)))
@@ -225,50 +215,85 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de
     src_eval, tgt_eval = valid_filenames
     src_test, tgt_test = test_filenames
 
-    extracted_files = []  # list of paths to the extracted files
-    dataset_tar = download_from_url(SUPPORTED_DATASETS['URL'], root=root, hash_value=SUPPORTED_DATASETS['MD5'],
-                                    path=os.path.join(root, SUPPORTED_DATASETS['_PATH']), hash_type='md5')
-    extracted_dataset_tar = extract_archive(dataset_tar)
-    # IWSLT dataset's url downloads a multilingual tgz.
-    # We need to take an extra step to pick out the specific language pair from it.
-    src_language = train_filenames[0].split(".")[-1]
-    tgt_language = train_filenames[1].split(".")[-1]
+    uncleaned_train_filenames = ('train.tags.{}-{}.{}'.format(src_language, tgt_language, src_language),
+                                 'train.tags.{}-{}.{}'.format(src_language, tgt_language, tgt_language))
+    uncleaned_valid_filenames = ('IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, src_language),
+                                 'IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, tgt_language))
+    uncleaned_test_filenames = ('IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, src_language),
+                                'IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, tgt_language))
+
+    uncleaned_src_train, uncleaned_tgt_train = uncleaned_train_filenames
+    uncleaned_src_eval, uncleaned_tgt_eval = uncleaned_valid_filenames
+    uncleaned_src_test, uncleaned_tgt_test = uncleaned_test_filenames
+
+    url_dp = IterableWrapper([URL])
+    cache_compressed_dp = url_dp.on_disk_cache(
+        filepath_fn=lambda x: os.path.join(root, _PATH),
+        hash_dict={os.path.join(root, _PATH): MD5},
+        hash_type="md5"
+    )
+    cache_compressed_dp = GDriveReader(cache_compressed_dp)
+    cache_compressed_dp = cache_compressed_dp.end_caching(mode="wb", same_filepath_fn=True)
+
     languages = "-".join([src_language, tgt_language])
-    iwslt_tar = '{}/{}/texts/{}/{}/{}.tgz'
-    iwslt_tar = iwslt_tar.format(
-        root, SUPPORTED_DATASETS['_PATH'].split(".")[0], src_language, tgt_language, languages)
-    extracted_dataset_tar = extract_archive(iwslt_tar)
-    extracted_files.extend(extracted_dataset_tar)
-
-    # Clean the xml and tag file in the archives
-    file_archives = []
-    for fname in extracted_files:
-        if 'xml' in fname:
-            _clean_xml_file(fname)
-            file_archives.append(os.path.splitext(fname)[0])
-        elif "tags" in fname:
-            _clean_tags_file(fname)
-            file_archives.append(fname.replace('.tags', ''))
-        else:
-            file_archives.append(fname)
-
-    data_filenames = {
-        "train": _construct_filepaths(file_archives, src_train, tgt_train),
-        "valid": _construct_filepaths(file_archives, src_eval, tgt_eval),
-        "test": _construct_filepaths(file_archives, src_test, tgt_test)
+    # We create the whole filepath here, but only check for the literal filename in the filter
+    # because we're lazily extracting from the outer tarfile. Thus,
+    # /root/2016-01/texts/.../src-tgt.tgz will never be in /root/2016-01.tgz/texts/.../src-tgt.tgz
+    inner_iwslt_tar = os.path.join(root, os.path.splitext(_PATH)[0], "texts", src_language, tgt_language, languages) + ".tgz"
+
+    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=lambda x: inner_iwslt_tar)
+    cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_tar().filter(lambda x: os.path.basename(inner_iwslt_tar) in x[0])
+    cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
+
+    file_path_by_lang_and_split = {
+        src_language: {
+            "train": src_train,
+            "valid": src_eval,
+            "test": src_test,
+        },
+        tgt_language: {
+            "train": tgt_train,
+            "valid": tgt_eval,
+            "test": tgt_test,
+        }
     }
 
-    for key in data_filenames.keys():
-        if len(data_filenames[key]) == 0 or data_filenames[key] is None:
-            raise FileNotFoundError(
-                "Files are not found for data type {}".format(key))
+    uncleaned_filenames = {
+        src_language: {
+            "train": uncleaned_src_train,
+            "valid": uncleaned_src_eval,
+            "test": uncleaned_src_test,
+        },
+        tgt_language: {
+            "train": uncleaned_tgt_train,
+            "valid": uncleaned_tgt_eval,
+            "test": uncleaned_tgt_test,
+        }
+    }
+
+    src_filename = file_path_by_lang_and_split[src_language][split]
+    uncleaned_src_filename = uncleaned_filenames[src_language][split]
+
+    # We create the whole filepath here, but only check for the literal filename in the filter
+    # because we're lazily extracting from the outer tarfile.
+    full_src_filepath = os.path.join(root, "2016-01/texts/", src_language, tgt_language, languages, src_filename)
+
+    cache_inner_src_decompressed_dp = _filter_clean_cache(cache_decompressed_dp, full_src_filepath, uncleaned_src_filename)
+
+    tgt_filename = file_path_by_lang_and_split[tgt_language][split]
+    uncleaned_tgt_filename = uncleaned_filenames[tgt_language][split]
+
+    # We create the whole filepath here, but only check for the literal filename in the filter
+    # because we're lazily extracting from the outer tarfile.
+    full_tgt_filepath = os.path.join(root, "2016-01/texts/", src_language, tgt_language, languages, tgt_filename)
+
+    cache_inner_tgt_decompressed_dp = _filter_clean_cache(cache_decompressed_dp, full_tgt_filepath, uncleaned_tgt_filename)
 
-    src_data_iter = _read_text_iterator(data_filenames[split][0])
-    tgt_data_iter = _read_text_iterator(data_filenames[split][1])
+    tgt_data_dp = FileOpener(cache_inner_tgt_decompressed_dp, mode="r")
+    src_data_dp = FileOpener(cache_inner_src_decompressed_dp, mode="r")
 
-    def _iter(src_data_iter, tgt_data_iter):
-        for item in zip(src_data_iter, tgt_data_iter):
-            yield item
+    src_lines = src_data_dp.readlines(return_path=False, strip_newline=False)
+    tgt_lines = tgt_data_dp.readlines(return_path=False, strip_newline=False)
 
-    return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split][num_lines_set_identifier[split]][tuple(sorted(language_pair))], _iter(src_data_iter, tgt_data_iter))
+    return src_lines.zip(tgt_lines)
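The hunks above call a `_filter_clean_cache` helper that falls in the part of the diff not shown in this excerpt. The following is only a rough sketch of the pattern those call sites imply, not the actual implementation: it assumes the helper keys an on-disk cache on the cleaned file path, pulls the single uncleaned member out of the language-pair tarball, cleans it, and caches the result. The direct use of `_clean_inner_xml_file` is an assumption for illustration; a real version would also have to handle the plain-text `train.tags.*` files.

import os

from torchdata.datapipes.iter import FileOpener


def _filter_clean_cache(cache_decompressed_dp, full_filepath, uncleaned_filename):
    # Cache on the path of the *cleaned* file we ultimately want; the steps below
    # only run when that file is not already present on disk.
    cache_inner_dp = cache_decompressed_dp.on_disk_cache(filepath_fn=lambda x: full_filepath)
    # Lazily extract just the requested uncleaned member from the language-pair tarball.
    cache_inner_dp = FileOpener(cache_inner_dp, mode="b").read_from_tar()
    cache_inner_dp = cache_inner_dp.filter(lambda x: os.path.basename(uncleaned_filename) in x[0])
    # Strip the markup and write the cleaned text to full_filepath. Using
    # _clean_inner_xml_file here is an assumption; the .tags train files would
    # need the tags-stripping equivalent instead.
    cache_inner_dp = cache_inner_dp.map(lambda x: _clean_inner_xml_file(full_filepath, x[1]))
    return cache_inner_dp.end_caching(mode="wb", same_filepath_fn=True)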
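For completeness, a minimal sketch of how the reworked dataset is meant to be consumed, assuming `torchdata` is installed; the language pair and the number of printed pairs are arbitrary illustration, not part of this change.

from torchtext.datasets import IWSLT2016

# Each returned split is now a datapipe that yields (src_sentence, tgt_sentence)
# string pairs, produced by zipping the source and target line readers.
train_dp, valid_dp, test_dp = IWSLT2016(root=".data", language_pair=("de", "en"))

for i, (src_sentence, tgt_sentence) in enumerate(train_dp):
    print(src_sentence.strip(), "->", tgt_sentence.strip())
    if i == 2:  # just show a few pairs
        break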