diff --git a/test/torchtext_unittest/datasets/common.py b/test/torchtext_unittest/datasets/common.py index 10071bd73b..0fa3ae2b00 100644 --- a/test/torchtext_unittest/datasets/common.py +++ b/test/torchtext_unittest/datasets/common.py @@ -1,7 +1,7 @@ import pickle from parameterized import parameterized -from torch.utils.data.graph import traverse +from torch.utils.data.graph import traverse_dps from torch.utils.data.graph_settings import get_all_graph_pipes from torchdata.dataloader2.linter import _check_shuffle_before_sharding from torchdata.datapipes.iter import Shuffler, ShardingFilter @@ -37,7 +37,7 @@ def test_shuffle_shard_wrapper(self, dataset_fn): for dp_split in dp: _check_shuffle_before_sharding(dp_split) - dp_graph = get_all_graph_pipes(traverse(dp_split)) + dp_graph = get_all_graph_pipes(traverse_dps(dp_split)) for annotation_dp_type in [Shuffler, ShardingFilter]: if not any(isinstance(dp, annotation_dp_type) for dp in dp_graph): raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.")