71 changes: 71 additions & 0 deletions torchtext/data/datasets_utils.py
@@ -11,6 +11,7 @@
unicode_csv_reader,
)
from torch.utils.data import functional_datapipe, IterDataPipe
from torch.utils.data.datapipes.utils.common import StreamWrapper
import codecs
try:
import defusedxml.ElementTree as ET
@@ -33,6 +34,25 @@ def _clean_xml_file(f_xml):
fd_txt.write(e.text.strip() + '\n')


def _clean_inner_xml_file(outfile, stream):
"""Accepts an output filename and a stream of the byte contents of an XML file
and writes the cleaned contents to a new file on disk.

Args:
outfile: the path to which the modified stream should be written
stream: the byte datapipe of the contents of the XML file

Returns: the path to the newly-written file and the new StreamWrapper for appropriate caching
"""
os.makedirs(os.path.dirname(outfile), exist_ok=True)
with codecs.open(outfile, mode='w', encoding='utf-8') as fd_txt:
root = ET.fromstring(stream.read().decode("utf-8"))[0]
for doc in root.findall('doc'):
for e in doc.findall('seg'):
fd_txt.write(e.text.strip() + '\n')
return outfile, StreamWrapper(open(outfile, "rb"))
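
For orientation, a minimal sketch of the XML layout this cleaner assumes: a top-level element whose first child holds <doc> elements with <seg> children. The tag names <mteval> and <srcset> below are illustrative stand-ins for the IWSLT evaluation format; the traversal is the same one _clean_inner_xml_file performs.

import xml.etree.ElementTree as ET

# Made-up snippet mimicking the expected nesting.
sample = b"""<mteval>
  <srcset setid="example" srclang="de">
    <doc docid="1">
      <seg id="1"> Guten Morgen </seg>
      <seg id="2"> Wie geht es dir? </seg>
    </doc>
  </srcset>
</mteval>"""

root = ET.fromstring(sample.decode("utf-8"))[0]  # first child, as above
for doc in root.findall('doc'):
    for seg in doc.findall('seg'):
        print(seg.text.strip())  # "Guten Morgen", "Wie geht es dir?"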


def _clean_tags_file(f_orig):
xml_tags = [
'<url', '<keywords', '<talkid', '<description', '<reviewer',
@@ -50,6 +70,57 @@ def _clean_tags_file(f_orig):
fd_txt.write(line.strip() + '\n')


def _clean_inner_tags_file(outfile, stream):
"""Accepts an output filename and a stream of the byte contents of a tags file
and writes the cleaned contents to a new file on disk.

Args:
outfile: the path to which the modified stream should be written
stream: the byte datapipe of the contents of the tags file

Returns: the path to the newly-written file and the new StreamWrapper for appropriate caching
"""
xml_tags = [
'<url', '<keywords', '<talkid', '<description', '<reviewer',
'<translator', '<title', '<speaker', '<doc', '</doc'
]
os.makedirs(os.path.dirname(outfile), exist_ok=True)
with codecs.open(outfile, mode='w', encoding='utf-8') as fd_txt:
for line in stream.readlines():
if not any(tag in line.decode("utf-8") for tag in xml_tags):
# TODO: Fix utf-8 next line mark
# fd_txt.write(l.strip() + '\n')
# fd_txt.write(l.strip() + u"\u0085")
# fd_txt.write(l.lstrip())
fd_txt.write(line.decode("utf-8").strip() + '\n')
return outfile, StreamWrapper(open(outfile, "rb"))


def _rewrite_text_file(outfile, stream):
"""Accepts an output filename and a stream of the byte contents of a text file
and writes the cleaned contents to a new file on disk.

Args:
outfile: the path to which the modified stream should be written
stream: the byte datapipe of the contents of the text file

Returns: the path to the newly-written file and the new StreamWrapper for appropriate caching
"""
os.makedirs(os.path.dirname(outfile), exist_ok=True)
with open(outfile, 'w', encoding='utf-8') as f:
for line in stream.readlines():
f.write(line.decode("utf-8") + "\n")
return outfile, StreamWrapper(open(outfile, "rb"))


def _clean_files(outfile, fname, stream):
if 'xml' in fname:
return _clean_inner_xml_file(outfile, stream)
elif "tags" in fname:
return _clean_inner_tags_file(outfile, stream)
return _rewrite_text_file(outfile, stream)
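
To make the dispatch concrete, here is a hedged, self-contained sketch that feeds _clean_files three in-memory streams. The member names, sample contents, and /tmp path are made up, and _clean_files is the private helper defined above, so this only runs against a torchtext checkout containing this change.

import io
import os

from torchtext.data.datasets_utils import _clean_files  # private helper added above

# Stand-ins for (member name, byte stream) pairs read from the inner language-pair tar.
members = {
    "train.tags.de-en.de": io.BytesIO(b"<url>http://example.test</url>\nGuten Morgen\n"),
    "IWSLT16.TED.tst2013.de-en.de.xml": io.BytesIO(
        b"<mteval><srcset><doc><seg> Hallo Welt </seg></doc></srcset></mteval>"),
    "train.de-en.de": io.BytesIO(b"plain text line\n"),
}

for fname, stream in members.items():
    outfile = os.path.join("/tmp/iwslt_clean_demo", fname)
    # 'xml' in the name -> XML cleaner, 'tags' -> tag filter, otherwise plain rewrite.
    cleaned_path, cleaned_stream = _clean_files(outfile, fname, stream)
    print(cleaned_path, cleaned_stream.read())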


def _create_data_from_json(data_path):
with open(data_path) as json_file:
raw_json_data = json.load(json_file)['data']
177 changes: 101 additions & 76 deletions torchtext/datasets/iwslt2016.py
@@ -1,19 +1,22 @@
from torchtext._internal.module_utils import is_module_available

if is_module_available("torchdata"):
from torchdata.datapipes.iter import FileOpener, GDriveReader, IterableWrapper

import os
from torchtext.utils import (download_from_url, extract_archive)
from torchtext.data.datasets_utils import (
_RawTextIterableDataset,
_wrap_split_argument,
_clean_xml_file,
_clean_tags_file,
_read_text_iterator,
_clean_files,
_create_dataset_directory,
)
from torchtext.data.datasets_utils import _create_dataset_directory

URL = 'https://drive.google.com/uc?id=1l5y6Giag9aRPwGtuZHswh3w5v3qEz8D8'

_PATH = '2016-01.tgz'

MD5 = 'c393ed3fc2a1b0f004b3331043f615ae'

SUPPORTED_DATASETS = {
'URL': 'https://drive.google.com/uc?id=1l5y6Giag9aRPwGtuZHswh3w5v3qEz8D8',
'_PATH': '2016-01.tgz',
'MD5': 'c393ed3fc2a1b0f004b3331043f615ae',
'valid_test': ['dev2010', 'tst2010', 'tst2011', 'tst2012', 'tst2013', 'tst2014'],
'language_pair': {
'en': ['ar', 'de', 'fr', 'cs'],
@@ -23,12 +26,8 @@
'cs': ['en'],
},
'year': 16,

}

URL = SUPPORTED_DATASETS['URL']
MD5 = SUPPORTED_DATASETS['MD5']

NUM_LINES = {
'train': {
'train': {
@@ -125,28 +124,22 @@
('cs', 'en'): ['tst2014']
}


def _construct_filenames(filename, languages):
filenames = []
for lang in languages:
filenames.append(filename + "." + lang)
return filenames


def _construct_filepaths(paths, src_filename, tgt_filename):
src_path = None
tgt_path = None
for p in paths:
src_path = p if src_filename in p else src_path
tgt_path = p if tgt_filename in p else tgt_path
return (src_path, tgt_path)
DATASET_NAME = "IWSLT2016"


DATASET_NAME = "IWSLT2016"
def _filter_clean_cache(cache_decompressed_dp, full_filepath, uncleaned_filename):
cache_inner_decompressed_dp = cache_decompressed_dp.on_disk_cache(filepath_fn=lambda x: full_filepath)
cache_inner_decompressed_dp = FileOpener(cache_inner_decompressed_dp, mode="b").read_from_tar()
cache_inner_decompressed_dp = cache_inner_decompressed_dp.filter(
lambda x: os.path.basename(uncleaned_filename) in x[0])
cache_inner_decompressed_dp = cache_inner_decompressed_dp.map(
lambda x: _clean_files(full_filepath, x[0], x[1]))
cache_inner_decompressed_dp = cache_inner_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
return cache_inner_decompressed_dp
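
_filter_clean_cache chains the same on_disk_cache / end_caching idiom used for the outer archive: if the cleaned file already exists on disk, the tar is never re-read; otherwise the member is extracted, cleaned via _clean_files, and the result is registered with the cache. Below is a stripped-down, self-contained sketch of that idiom (no tar, just a local file), assuming the torchdata snapshot this PR targets; all paths and the cleaning function are illustrative.

import os
from torchdata.datapipes.iter import IterableWrapper, FileOpener
from torch.utils.data.datapipes.utils.common import StreamWrapper

raw = "/tmp/dp_cache_demo/raw.txt"      # illustrative input
clean = "/tmp/dp_cache_demo/clean.txt"  # illustrative cache target
os.makedirs("/tmp/dp_cache_demo", exist_ok=True)
with open(raw, "w", encoding="utf-8") as f:
    f.write("<url>boilerplate</url>\nkeep this line\n")

def _strip_tag_lines(item):
    # Mirrors the cleaners above: write the cleaned file, then hand back
    # (path, stream) so end_caching can finalize the cache entry.
    _, stream = item
    with open(clean, "w", encoding="utf-8") as out:
        for line in stream.readlines():
            text = line.decode("utf-8")
            if not text.startswith("<"):
                out.write(text)
    return clean, StreamWrapper(open(clean, "rb"))

dp = IterableWrapper([raw]).on_disk_cache(filepath_fn=lambda x: clean)
dp = FileOpener(dp, mode="b").map(_strip_tag_lines)
dp = dp.end_caching(mode="wb", same_filepath_fn=True)
print(list(dp))  # should yield the cached path; reruns skip the cleaning step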


@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(('train', 'valid', 'test'))
@_wrap_split_argument(("train", "valid", "test"))
def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de', 'en'), valid_set='tst2013', test_set='tst2014'):
"""IWSLT2016 dataset

@@ -182,14 +175,11 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de
Examples:
>>> from torchtext.datasets import IWSLT2016
>>> train_iter, valid_iter, test_iter = IWSLT2016()
>>> src_sentence, tgt_sentence = next(train_iter)
>>> src_sentence, tgt_sentence = next(iter(train_iter))

"""
num_lines_set_identifier = {
'train': 'train',
'valid': valid_set,
'test': test_set
}
if not is_module_available("torchdata"):
raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`")

if not isinstance(language_pair, list) and not isinstance(language_pair, tuple):
raise ValueError("language_pair must be list or tuple but got {} instead".format(type(language_pair)))
@@ -225,50 +215,85 @@ def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de
src_eval, tgt_eval = valid_filenames
src_test, tgt_test = test_filenames

extracted_files = [] # list of paths to the extracted files
dataset_tar = download_from_url(SUPPORTED_DATASETS['URL'], root=root, hash_value=SUPPORTED_DATASETS['MD5'],
path=os.path.join(root, SUPPORTED_DATASETS['_PATH']), hash_type='md5')
extracted_dataset_tar = extract_archive(dataset_tar)
# IWSLT dataset's url downloads a multilingual tgz.
# We need to take an extra step to pick out the specific language pair from it.
src_language = train_filenames[0].split(".")[-1]
tgt_language = train_filenames[1].split(".")[-1]
uncleaned_train_filenames = ('train.tags.{}-{}.{}'.format(src_language, tgt_language, src_language),
'train.tags.{}-{}.{}'.format(src_language, tgt_language, tgt_language))
uncleaned_valid_filenames = ('IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, src_language),
'IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, tgt_language))
uncleaned_test_filenames = ('IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, src_language),
'IWSLT{}.TED.{}.{}-{}.{}.xml'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, tgt_language))

uncleaned_src_train, uncleaned_tgt_train = uncleaned_train_filenames
uncleaned_src_eval, uncleaned_tgt_eval = uncleaned_valid_filenames
uncleaned_src_test, uncleaned_tgt_test = uncleaned_test_filenames

url_dp = IterableWrapper([URL])
cache_compressed_dp = url_dp.on_disk_cache(
filepath_fn=lambda x: os.path.join(root, _PATH),
hash_dict={os.path.join(root, _PATH): MD5},
hash_type="md5"
)
cache_compressed_dp = GDriveReader(cache_compressed_dp)
cache_compressed_dp = cache_compressed_dp.end_caching(mode="wb", same_filepath_fn=True)
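
The hash_dict / hash_type pair makes the cache self-validating: a previously downloaded 2016-01.tgz is reused only if its MD5 matches, otherwise the GDrive download runs again. Conceptually the check is the following (standard library only; this is the equivalent verification, not the torchdata internal):

import hashlib

def md5_matches(path, expected_md5, chunk_size=1 << 20):
    # Stream the file so large archives do not need to fit in memory.
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest() == expected_md5

# e.g. md5_matches(os.path.join(root, _PATH), MD5)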

languages = "-".join([src_language, tgt_language])

iwslt_tar = '{}/{}/texts/{}/{}/{}.tgz'
iwslt_tar = iwslt_tar.format(
root, SUPPORTED_DATASETS['_PATH'].split(".")[0], src_language, tgt_language, languages)
extracted_dataset_tar = extract_archive(iwslt_tar)
extracted_files.extend(extracted_dataset_tar)

# Clean the xml and tag file in the archives
file_archives = []
for fname in extracted_files:
if 'xml' in fname:
_clean_xml_file(fname)
file_archives.append(os.path.splitext(fname)[0])
elif "tags" in fname:
_clean_tags_file(fname)
file_archives.append(fname.replace('.tags', ''))
else:
file_archives.append(fname)

data_filenames = {
"train": _construct_filepaths(file_archives, src_train, tgt_train),
"valid": _construct_filepaths(file_archives, src_eval, tgt_eval),
"test": _construct_filepaths(file_archives, src_test, tgt_test)
# We create the whole filepath here, but only check for the literal filename in the filter
# because we're lazily extracting from the outer tarfile: the on-disk path
# /root/2016-01/texts/.../src-tgt.tgz never appears as a member name inside /root/2016-01.tgz.
inner_iwslt_tar = os.path.join(root, os.path.splitext(_PATH)[0], "texts", src_language, tgt_language, languages) + ".tgz"

cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=lambda x: inner_iwslt_tar)
cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_tar().filter(lambda x: os.path.basename(inner_iwslt_tar) in x[0])
cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
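
The comment above is the key subtlety: read_from_tar() reports member names as they appear inside the archive, never as the on-disk extraction path built in inner_iwslt_tar, so the filter can only match on the basename. A small illustration with made-up member names (the '2016-01/texts/...' layout follows the archive structure assumed elsewhere in this file):

import os

inner_iwslt_tar = "/home/user/.data/2016-01/texts/de/en/de-en.tgz"  # illustrative on-disk target

# Names as read_from_tar() would report them: relative paths inside 2016-01.tgz.
members = [
    "2016-01/texts/de/en/de-en.tgz",
    "2016-01/texts/fr/en/fr-en.tgz",
    "2016-01/texts/cs/en/cs-en.tgz",
]

selected = [m for m in members if os.path.basename(inner_iwslt_tar) in m]
print(selected)  # ['2016-01/texts/de/en/de-en.tgz']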

file_path_by_lang_and_split = {
src_language: {
"train": src_train,
"valid": src_eval,
"test": src_test,
},
tgt_language: {
"train": tgt_train,
"valid": tgt_eval,
"test": tgt_test,
}
}

for key in data_filenames.keys():
if len(data_filenames[key]) == 0 or data_filenames[key] is None:
raise FileNotFoundError(
"Files are not found for data type {}".format(key))
uncleaned_filenames = {
src_language: {
"train": uncleaned_src_train,
"valid": uncleaned_src_eval,
"test": uncleaned_src_test,
},
tgt_language: {
"train": uncleaned_tgt_train,
"valid": uncleaned_tgt_eval,
"test": uncleaned_tgt_test,
}
}

src_filename = file_path_by_lang_and_split[src_language][split]
uncleaned_src_filename = uncleaned_filenames[src_language][split]

# We create the whole filepath here, but only check for the literal filename in the filter
# because we're lazily extracting from the outer tarfile.
full_src_filepath = os.path.join(root, "2016-01/texts/", src_language, tgt_language, languages, src_filename)

cache_inner_src_decompressed_dp = _filter_clean_cache(cache_decompressed_dp, full_src_filepath, uncleaned_src_filename)

tgt_filename = file_path_by_lang_and_split[tgt_language][split]
uncleaned_tgt_filename = uncleaned_filenames[tgt_language][split]

# We create the whole filepath here, but only check for the literal filename in the filter
# because we're lazily extracting from the outer tarfile.
full_tgt_filepath = os.path.join(root, "2016-01/texts/", src_language, tgt_language, languages, tgt_filename)

cache_inner_tgt_decompressed_dp = _filter_clean_cache(cache_decompressed_dp, full_tgt_filepath, uncleaned_tgt_filename)

src_data_iter = _read_text_iterator(data_filenames[split][0])
tgt_data_iter = _read_text_iterator(data_filenames[split][1])
tgt_data_dp = FileOpener(cache_inner_tgt_decompressed_dp, mode="r")
src_data_dp = FileOpener(cache_inner_src_decompressed_dp, mode="r")

def _iter(src_data_iter, tgt_data_iter):
for item in zip(src_data_iter, tgt_data_iter):
yield item
src_lines = src_data_dp.readlines(return_path=False, strip_newline=False)
tgt_lines = tgt_data_dp.readlines(return_path=False, strip_newline=False)

return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split][num_lines_set_identifier[split]][tuple(sorted(language_pair))], _iter(src_data_iter, tgt_data_iter))
return src_lines.zip(tgt_lines)
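
End to end, each split is now a datapipe of (src, tgt) line pairs, so the docstring example above still holds. A quick usage sketch, assuming torchdata is installed and the default de-en pair:

from torchtext.datasets import IWSLT2016

train_iter, valid_iter, test_iter = IWSLT2016(language_pair=("de", "en"))
src_sentence, tgt_sentence = next(iter(train_iter))
print(src_sentence)
print(tgt_sentence)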