Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 37 additions & 39 deletions torchtext/datasets/multi30k.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from torchtext._internal.module_utils import is_module_available
from typing import Union, Tuple

if is_module_available("torchdata"):
from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper

import os
from torchtext.data.datasets_utils import (
_download_extract_validate,
_RawTextIterableDataset,
_wrap_split_argument,
_create_dataset_directory,
_read_text_iterator,
)

URL = {
Expand All @@ -19,28 +22,10 @@
'test': '0681be16a532912288a91ddd573594fbdd57c0fbb81486eff7c55247e35326c2',
}

_EXTRACTED_FILES_INFO = {
'train': {
'file_prefix': 'train',
'md5': {
'de': '695df46f6fd14567e69970408a2c129a50e778a910ecb1585a92eb25b2c7accc',
'en': '4b4d37e774976ef44fecca1738cdeb3b3ba384851a59a755b9c5e6aa7d87b13c',
},
},
'valid': {
'file_prefix': 'val',
'md5': {
'de': 'fd0fc009db2446cfc12d96a382aff0d3122cb47577b352d0f7e0bb3a38e2e552',
'en': '40cd20974079d9afb0e3d27c659a8e059cc2fcf850b4bc23ede13fc36dd8a865',
},
},
'test': {
'file_prefix': 'test',
'md5': {
'de': 'c1d2f544471a7387e37d15f1adf075ff7d6fe57a30840bb969281ae102d24cb1',
'en': '399a4382932c1aadd3ceb9bef1008d388a64c76d4ae4e9d4728c6f4301cac182',
},
},
# Maps the public split name to the filename prefix used inside the
# downloaded tar archives (note: the 'valid' split's files are named 'val.*').
_PREFIX = {
    'train': 'train',
    'valid': 'val',
    'test': 'test',
}

NUM_LINES = {
Expand All @@ -53,8 +38,8 @@


@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "valid", "test"))
def Multi30k(root: str, split: Union[Tuple[str], str], language_pair: Tuple[str] = ('de', 'en')):
    """Multi30k dataset

    Reference: http://www.statmt.org/wmt16/multimodal-task.html#task1

    Args:
        root: Directory where the dataset archives are downloaded and cached.
        split: split or splits to be returned. Can be a string or a tuple of
            strings. One of ('train', 'valid', 'test').
        language_pair: tuple or list containing src and tgt language.
            Must be a permutation of ('de', 'en'). Default: ('de', 'en').

    Returns:
        A datapipe yielding (src_sentence, tgt_sentence) string pairs,
        ordered according to ``language_pair``.

    Raises:
        ModuleNotFoundError: if the `torchdata` package is not installed.
    """
    assert (len(language_pair) == 2), 'language_pair must contain only 2 elements: src and tgt language respectively'
    assert (tuple(sorted(language_pair)) == ('de', 'en')), "language_pair must be either ('de','en') or ('en', 'de')"

    if not is_module_available("torchdata"):
        raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`")

    url_dp = IterableWrapper([URL[split]])

    # Download the compressed archive once, verified against its sha256 hash,
    # and keep it on disk under `root`.
    cache_compressed_dp = url_dp.on_disk_cache(
        filepath_fn=lambda x: os.path.join(root, os.path.basename(URL[split])),
        hash_dict={os.path.join(root, os.path.basename(URL[split])): MD5[split]},
        hash_type="sha256"
    )
    cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)

    # Extract only the source-language member (e.g. 'train.de') from the tar
    # archive and cache the decompressed file on disk.
    src_cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
        filepath_fn=lambda x: os.path.join(root, f"{_PREFIX[split]}.{language_pair[0]}"))
    src_cache_decompressed_dp = FileOpener(src_cache_decompressed_dp, mode="b").read_from_tar().filter(
        lambda x: f"{_PREFIX[split]}.{language_pair[0]}" in x[0])
    src_cache_decompressed_dp = src_cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

    # Same extraction/caching for the target-language member (e.g. 'train.en').
    tgt_cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
        filepath_fn=lambda x: os.path.join(root, f"{_PREFIX[split]}.{language_pair[1]}"))
    tgt_cache_decompressed_dp = FileOpener(tgt_cache_decompressed_dp, mode="b").read_from_tar().filter(
        lambda x: f"{_PREFIX[split]}.{language_pair[1]}" in x[0])
    tgt_cache_decompressed_dp = tgt_cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

    # Read line-by-line; keep trailing newlines to preserve the historical
    # behavior of the raw-text iterator this implementation replaced.
    src_data_dp = FileOpener(src_cache_decompressed_dp, mode="b").readlines(decode=True, return_path=False, strip_newline=False)
    tgt_data_dp = FileOpener(tgt_cache_decompressed_dp, mode="b").readlines(decode=True, return_path=False, strip_newline=False)

    return src_data_dp.zip(tgt_data_dp)