diff --git a/test/datasets/test_amazonreviewpolarity.py b/test/datasets/test_amazonreviewpolarity.py
index 1cef5b550f..11c95ae785 100644
--- a/test/datasets/test_amazonreviewpolarity.py
+++ b/test/datasets/test_amazonreviewpolarity.py
@@ -8,7 +8,7 @@
 from parameterized import parameterized
 from torchtext.datasets.amazonreviewpolarity import AmazonReviewPolarity
 
-from ..common.case_utils import TempDirMixin
+from ..common.case_utils import TempDirMixin, zip_equal
 from ..common.torchtext_test_case import TorchtextTestCase
 
 
@@ -55,28 +55,29 @@ def setUpClass(cls):
         super().setUpClass()
         cls.root_dir = cls.get_base_temp_dir()
         cls.samples = _get_mock_dataset(cls.root_dir)
+        cls.patcher = patch(
+            "torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True
+        )
+        cls.patcher.start()
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.patcher.stop()
+        super().tearDownClass()
 
     @parameterized.expand(["train", "test"])
     def test_amazon_review_polarity(self, split):
-        with patch(
-            "torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True
-        ):
-            dataset = AmazonReviewPolarity(root=self.root_dir, split=split)
-            n_iter = 0
-            for i, (label, text) in enumerate(dataset):
-                expected_sample = self.samples[split][i]
-                assert label == expected_sample[0]
-                assert text == expected_sample[1]
-                n_iter += 1
-            assert n_iter == len(self.samples[split])
+        dataset = AmazonReviewPolarity(root=self.root_dir, split=split)
 
-    @parameterized.expand([("train", ("train",)), ("test", ("test",))])
-    def test_amazon_review_polarity_split_argument(self, split1, split2):
-        with patch(
-            "torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True
-        ):
-            dataset1 = AmazonReviewPolarity(root=self.root_dir, split=split1)
-            (dataset2,) = AmazonReviewPolarity(root=self.root_dir, split=split2)
+        samples = list(dataset)
+        expected_samples = self.samples[split]
+        for sample, expected_sample in zip_equal(samples, expected_samples):
+            self.assertEqual(sample, expected_sample)
+
+    @parameterized.expand(["train", "test"])
+    def test_amazon_review_polarity_split_argument(self, split):
+        dataset1 = AmazonReviewPolarity(root=self.root_dir, split=split)
+        (dataset2,) = AmazonReviewPolarity(root=self.root_dir, split=(split,))
 
-            for d1, d2 in zip(dataset1, dataset2):
-                self.assertEqual(d1, d2)
+        for d1, d2 in zip_equal(dataset1, dataset2):
+            self.assertEqual(d1, d2)