test/datasets/test_iwslt2016.py (83 additions, 0 deletions)
@@ -0,0 +1,83 @@
import os
import random
import string
from collections import defaultdict
from unittest.mock import patch

from parameterized import parameterized
from torchtext.datasets.iwslt2016 import IWSLT2016
from torchtext.data.datasets_utils import _generate_iwslt_files_for_lang_and_split

from ..common.case_utils import TempDirMixin, zip_equal
from ..common.torchtext_test_case import TorchtextTestCase


def _get_mock_dataset(root_dir, split, src, tgt):
    """
    root_dir: directory to which the mocked dataset is written
    """
    temp_dataset_dir = os.path.join(root_dir, f"IWSLT2016/2016-01/texts/{src}/{tgt}/{src}-{tgt}/")
    os.makedirs(temp_dataset_dir, exist_ok=True)

    # `seed` sets the number of random letters per string and grows with every line
    seed = 1
    mocked_data = defaultdict(lambda: defaultdict(list))
    valid_set = "tst2013"
    test_set = "tst2014"

    files_for_split, _ = _generate_iwslt_files_for_lang_and_split(16, src, tgt, valid_set, test_set)
    src_file = files_for_split[src][split]
    tgt_file = files_for_split[tgt][split]
    for file_name in (src_file, tgt_file):
        txt_file = os.path.join(temp_dataset_dir, file_name)
        with open(txt_file, "w") as f:
            # Get the file extension (i.e., the language) without the "." prefix (.en -> en)
            lang = os.path.splitext(file_name)[1][1:]
            for _ in range(5):
                rand_string = " ".join(
                    random.choice(string.ascii_letters) for _ in range(seed)
                )
                dataset_line = f"{rand_string} {rand_string}\n"
                # Append the line to the correct dataset split
                mocked_data[split][lang].append(dataset_line)
                f.write(dataset_line)
                seed += 1

    return list(zip(mocked_data[split][src], mocked_data[split][tgt]))
Comment on lines +15 to +45

parmeet (Contributor) commented on Feb 3, 2022:
Is there a reason we are not creating a download archive 2016-01.tgz like we are doing for the other datasets?

Edit: I think it's quite important to start from the download archive; otherwise we can run into hard-to-find bugs, especially when the compression pattern is as complex as it is here.

Contributor replied:
@erip just wanted to check whether you plan to follow up on this as well?

erip (Contributor, Author) replied:
Yes, I can follow up on this. It will take a lot more thought since, as you mention, the cleanup is quite involved. That said, I think it should be doable.

Contributor replied:
Sure, thanks @erip!
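
For context, a minimal sketch of what starting from the download archive could look like, assuming the mock text files have already been written under the 2016-01/ tree. The helper name `_create_mock_archive` is hypothetical, and the real IWSLT archive additionally nests per-language-pair .tgz files inside the outer one, which is the involved cleanup erip refers to; this sketch only covers the outer layer:

```python
import os
import tarfile


def _create_mock_archive(root_dir):
    """Hypothetical helper: package the mocked 2016-01/ tree into
    IWSLT2016/2016-01.tgz so the test exercises the dataset's real
    extraction path instead of reading pre-extracted files."""
    base_dir = os.path.join(root_dir, "IWSLT2016")
    archive_path = os.path.join(base_dir, "2016-01.tgz")
    with tarfile.open(archive_path, "w:gz") as tar:
        # arcname preserves the expected "2016-01/" prefix inside the archive
        tar.add(os.path.join(base_dir, "2016-01"), arcname="2016-01")
    return archive_path
```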



class TestIWSLT2016(TempDirMixin, TorchtextTestCase):
    root_dir = None
    patcher = None

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.root_dir = cls.get_base_temp_dir()
        # Bypass the on-disk cache check so the dataset reads the mocked
        # files instead of attempting a real download
        cls.patcher = patch(
            "torchdata.datapipes.iter.util.cacheholder.OnDiskCacheHolderIterDataPipe._cache_check_fn", return_value=True
        )
        cls.patcher.start()

    @classmethod
    def tearDownClass(cls):
        cls.patcher.stop()
        super().tearDownClass()

    @parameterized.expand([("train", "de", "en"), ("valid", "de", "en")])
Contributor commented:
IWSLT2016 also has a test split, so ideally we should include it in the testing as well.

erip (Contributor, Author) replied on Feb 3, 2022:
Oops, yes. I had made a change and forgot to re-incorporate the test split. I can cut a PR to fix it in the morning.
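
The follow-up erip mentions would presumably amount to extending the expansion, as in this sketch; it assumes the mock generation handles the tst2014 files, which the `_generate_iwslt_files_for_lang_and_split` call above already names via `test_set`:

```python
from parameterized import parameterized


class TestIWSLT2016(TempDirMixin, TorchtextTestCase):
    # Sketch of the follow-up fix: exercise the test split too
    @parameterized.expand(
        [
            ("train", "de", "en"),
            ("valid", "de", "en"),
            ("test", "de", "en"),
        ]
    )
    def test_iwslt2016(self, split, src, tgt):
        ...
```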

    def test_iwslt2016(self, split, src, tgt):
        expected_samples = _get_mock_dataset(self.root_dir, split, src, tgt)
Contributor commented:
Is there any specific reason why we don't want to generate all the mocked data within the setUpClass method and store it in self.samples, like we do in the other tests?


        dataset = IWSLT2016(root=self.root_dir, split=split)

        samples = list(dataset)

        for sample, expected_sample in zip_equal(samples, expected_samples):
Comment on lines +68 to +74

Contributor commented:
Nit: can we organize samples and expected_samples similarly to what we do in the SST2 PR, for consistency?

erip (Contributor, Author) replied:
Unfortunately it's not quite as straightforward unless we want to hardcode the language pairs in setUpClass; otherwise there's no good way to parameterize them. This is required because the expected file name for caching is a function of (src_lang, tgt_lang, split), which is somewhat unique.

erip (Contributor, Author) replied on Feb 2, 2022:
One thing we could do is make samples a function which accepts *args, and each dataset could handle them the same way using this pattern... The order matters here, though, because _get_mock_dataset needs to create the temp dir before IWSLT2016 reads it.

Contributor replied:
Gotcha, I think it's okay to keep your current implementation. I didn't catch the fact that the ordering mattered here.
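
A rough sketch of the *args-style pattern erip floats, assuming a factory stored on the class; `make_samples` is a hypothetical name, and `functools.partial` is used so that instance access does not rebind the callable as a method:

```python
from functools import partial


class TestIWSLT2016(TempDirMixin, TorchtextTestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.root_dir = cls.get_base_temp_dir()
        # Hypothetical factory instead of an eagerly built cls.samples: the
        # mock files must be written before IWSLT2016 reads them, and the
        # language pair/split are only known inside each parameterized case.
        cls.make_samples = partial(_get_mock_dataset, cls.root_dir)

    @parameterized.expand([("train", "de", "en"), ("valid", "de", "en")])
    def test_iwslt2016(self, split, src, tgt):
        # Ordering matters: this call writes the mock files first...
        expected_samples = self.make_samples(split, src, tgt)
        # ...and only then does the dataset read them back
        dataset = IWSLT2016(root=self.root_dir, split=split)
        for sample, expected in zip_equal(list(dataset), expected_samples):
            self.assertEqual(sample, expected)
```

This would keep the per-test ordering explicit while matching the setUpClass shape of the other dataset tests.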

            self.assertEqual(sample, expected_sample)

    @parameterized.expand(["train", "valid"])
    def test_iwslt2016_split_argument(self, split):
        dataset1 = IWSLT2016(root=self.root_dir, split=split)
        (dataset2,) = IWSLT2016(root=self.root_dir, split=(split,))

        for d1, d2 in zip_equal(dataset1, dataset2):
            self.assertEqual(d1, d2)