55from collections import defaultdict
66from unittest .mock import patch
77
8- from parameterized import parameterized
8+ from torchtext . datasets . amazonreviewfull import AmazonReviewFull
99from torchtext .datasets .amazonreviewpolarity import AmazonReviewPolarity
1010
1111from ..common .case_utils import TempDirMixin , zip_equal
12+ from ..common .parameterized_utils import nested_params
1213from ..common .torchtext_test_case import TorchtextTestCase
1314
1415
15- def _get_mock_dataset (root_dir ):
16+ def _get_mock_dataset (root_dir , base_dir_name ):
1617 """
1718 root_dir: directory to the mocked dataset
19+ base_dir_name: AmazonReviewFull or AmazonReviewPolarity
1820 """
19- base_dir = os .path .join (root_dir , "AmazonReviewPolarity" )
21+ base_dir = os .path .join (root_dir , base_dir_name )
2022 temp_dataset_dir = os .path .join (base_dir , "temp_dataset_dir" )
2123 os .makedirs (temp_dataset_dir , exist_ok = True )
2224
@@ -26,7 +28,10 @@ def _get_mock_dataset(root_dir):
2628 txt_file = os .path .join (temp_dataset_dir , file_name )
2729 with open (txt_file , "w" ) as f :
2830 for i in range (5 ):
29- label = seed % 2 + 1
31+ if base_dir_name == AmazonReviewFull .__name__ :
32+ label = seed % 5 + 1
33+ else :
34+ label = seed % 2 + 1
3035 rand_string = " " .join (
3136 random .choice (string .ascii_letters ) for i in range (seed )
3237 )
@@ -36,25 +41,27 @@ def _get_mock_dataset(root_dir):
3641 f .write (f'"{ label } ","{ rand_string } ","{ rand_string } "\n ' )
3742 seed += 1
3843
39- compressed_dataset_path = os .path .join (
40- base_dir , "amazon_review_polarity_csv.tar.gz"
41- )
44+ if base_dir_name == AmazonReviewFull .__name__ :
45+ archive_file_name = "amazon_review_full_csv"
46+ else :
47+ archive_file_name = "amazon_review_polarity_csv"
48+
49+ compressed_dataset_path = os .path .join (base_dir , f"{ archive_file_name } .tar.gz" )
4250 # create tar file from dataset folder
4351 with tarfile .open (compressed_dataset_path , "w:gz" ) as tar :
44- tar .add (temp_dataset_dir , arcname = "amazon_review_polarity_csv" )
52+ tar .add (temp_dataset_dir , arcname = archive_file_name )
4553
4654 return mocked_data
4755
4856
49- class TestAmazonReviewPolarity (TempDirMixin , TorchtextTestCase ):
57+ class TestAmazonReviews (TempDirMixin , TorchtextTestCase ):
5058 root_dir = None
5159 samples = []
5260
5361 @classmethod
5462 def setUpClass (cls ):
5563 super ().setUpClass ()
5664 cls .root_dir = cls .get_base_temp_dir ()
57- cls .samples = _get_mock_dataset (cls .root_dir )
5865 cls .patcher = patch (
5966 "torchdata.datapipes.iter.util.cacheholder._hash_check" , return_value = True
6067 )
@@ -65,19 +72,24 @@ def tearDownClass(cls):
6572 cls .patcher .stop ()
6673 super ().tearDownClass ()
6774
68- @parameterized .expand (["train" , "test" ])
69- def test_amazon_review_polarity (self , split ):
70- dataset = AmazonReviewPolarity (root = self .root_dir , split = split )
71-
75+ @nested_params ([AmazonReviewFull , AmazonReviewPolarity ], ["train" , "test" ])
76+ def test_amazon_reviews (self , amazon_review_dataset , split ):
77+ expected_samples = _get_mock_dataset (
78+ self .root_dir , amazon_review_dataset .__name__
79+ )[split ]
80+ dataset = amazon_review_dataset (root = self .root_dir , split = split )
7281 samples = list (dataset )
73- expected_samples = self . samples [ split ]
82+
7483 for sample , expected_sample in zip_equal (samples , expected_samples ):
7584 self .assertEqual (sample , expected_sample )
7685
77- @parameterized .expand (["train" , "test" ])
78- def test_amazon_review_polarity_split_argument (self , split ):
79- dataset1 = AmazonReviewPolarity (root = self .root_dir , split = split )
80- (dataset2 ,) = AmazonReviewPolarity (root = self .root_dir , split = (split ,))
86+ @nested_params ([AmazonReviewFull , AmazonReviewPolarity ], ["train" , "test" ])
87+ def test_amazon_reviews_split_argument (self , amazon_review_dataset , split ):
88+ # call `_get_mock_dataset` to create mock dataset files
89+ _ = _get_mock_dataset (self .root_dir , amazon_review_dataset .__name__ )
90+
91+ dataset1 = amazon_review_dataset (root = self .root_dir , split = split )
92+ (dataset2 ,) = amazon_review_dataset (root = self .root_dir , split = (split ,))
8193
8294 for d1 , d2 in zip_equal (dataset1 , dataset2 ):
8395 self .assertEqual (d1 , d2 )
0 commit comments