Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .circleci/config.yml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions .circleci/config.yml.in
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,9 @@ commands:
file_or_dir:
type: string
steps:
- run:
name: Install test utilities
command: pip install --progress-bar=off pytest pytest-mock
- pip_install:
args: pytest pytest-mock
descr: Install test utilities
- run:
name: Run tests
command: pytest --junitxml=test-results/junit.xml -v --durations 20 <<parameters.file_or_dir>>
Expand Down
35 changes: 30 additions & 5 deletions test/test_prototype_builtin_datasets.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import functools
import io
import pickle
import os
from pathlib import Path

import pytest
import torch
from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS
from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair
from torch.utils.data.dataloader_experimental import DataLoader2
from torch.utils.data.graph import traverse
from torch.utils.data.graph_settings import get_all_graph_pipes
from torchdata.datapipes.iter import IterDataPipe, Shuffler, ShardingFilter
Expand Down Expand Up @@ -116,13 +117,37 @@ def test_transformable(self, test_home, dataset_mock, config):

next(iter(dataset.map(transforms.Identity())))

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_serializable(self, test_home, dataset_mock, config):
@pytest.mark.parametrize("parallelism_mode", [None, "mp", "thread"])
@parametrize_dataset_mocks({name: mock for name, mock in DATASET_MOCKS.items() if name not in {"qmnist", "voc"}})
def test_pipeline(self, test_home, dataset_mock, config, parallelism_mode):
    """Smoke-test loading, transforming, batching and iterating a mocked dataset.

    The dataset is loaded from its mock, pickled once, mapped through a
    decode + resize transform, batched in pairs and then fully consumed
    through ``DataLoader2``. The test passes if nothing raises.

    Args:
        test_home: Directory handed to ``dataset_mock.prepare``.
        dataset_mock: Mock of the dataset under test; "qmnist" and "voc" are
            filtered out by the parametrization (the reason is not stated here).
        config: Keyword arguments forwarded to ``datasets.load``.
        parallelism_mode: ``None`` for in-process loading (zero workers), or
            ``"mp"`` / ``"thread"``, passed through to ``DataLoader2``.
    """
    dataset_mock.prepare(test_home, config)

    dataset = datasets.load(dataset_mock.name, **config)

    # The freshly loaded dataset has to be picklable before it is extended.
    pickle.dumps(dataset)
    transform = transforms.Compose(transforms.DecodeImage(), transforms.Resize([3, 3]))

    # TODO: add a .collate() here as soon as https://github.com/pytorch/vision/pull/5233 is resolved
    # Incomplete trailing batches are dropped only in "thread" mode.
    dp = dataset.map(transform).batch(2, drop_last=parallelism_mode == "thread")

    if parallelism_mode:
        # Maybe we can make this a static method of the data_loader?
        try:
            # Number of CPUs this process is allowed to run on.
            num_workers = len(os.sched_getaffinity(0))
        except (AttributeError, OSError):
            # sched_getaffinity is not available on every platform and may fail
            # at the OS level; fall back to the total CPU count. Anything else
            # is a genuine error and should surface instead of being swallowed.
            num_workers = os.cpu_count() or 1
    else:
        num_workers = 0

    dl = DataLoader2(
        dp,
        batch_size=None,  # batching is already part of the datapipe above
        shuffle=True,
        num_workers=num_workers,
        parallelism_mode=parallelism_mode,
        # A timeout is only passed when workers are in use.
        timeout=5 if num_workers > 0 else 0,
    )

    # Exhaust the loader; the batches themselves are not inspected.
    for _ in dl:
        pass

# TODO: we need to enforce not only that both a Shuffler and a ShardingFilter are part of the datapipe, but also
# that the Shuffler comes before the ShardingFilter. Early commits in https://github.com/pytorch/vision/pull/5680
Expand Down