diff --git a/test/datasets/common.py b/test/datasets/common.py index 81edc565bf..f9ead3d3b3 100644 --- a/test/datasets/common.py +++ b/test/datasets/common.py @@ -10,8 +10,8 @@ class TestShuffleShardDatasetWrapper(TorchtextTestCase): # Note that for order i.e shuffle before sharding, TorchData will provide linter warning # Modify this test when linter warning is available - @parameterized.expand(list(DATASETS.items())) - def test_shuffle_shard_wrapper(self, dataset_name, dataset_fn): + @parameterized.expand([(f,) for f in DATASETS.values()]) + def test_shuffle_shard_wrapper(self, dataset_fn): dp = dataset_fn() if type(dp) == tuple: dp = list(dp)