From e0a0d368c14628c69232389dcdca38a1e69c18a8 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 13 Feb 2023 14:23:10 +0100 Subject: [PATCH 1/7] allow subclasses in dataset wrappers --- torchvision/prototype/datapoints/_dataset_wrapper.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/torchvision/prototype/datapoints/_dataset_wrapper.py b/torchvision/prototype/datapoints/_dataset_wrapper.py index 1159261054b..0fc4710d2dd 100644 --- a/torchvision/prototype/datapoints/_dataset_wrapper.py +++ b/torchvision/prototype/datapoints/_dataset_wrapper.py @@ -39,7 +39,16 @@ def decorator(wrapper_factory): class VisionDatasetDatapointWrapper(Dataset): def __init__(self, dataset): dataset_cls = type(dataset) + # We test for exact dataset class matches first. If we don't find one, we check if the dataset subclasses any + # of the known ones. wrapper_factory = WRAPPER_FACTORIES.get(dataset_cls) + if wrapper_factory is None: + with contextlib.suppress(StopIteration): + wrapper_factory = next( + wrapper_factory + for dataset_supercls_candidate, wrapper_factory in WRAPPER_FACTORIES.items() + if issubclass(dataset_cls, dataset_supercls_candidate) + ) if wrapper_factory is None: # TODO: If we have documentation on how to do that, put a link in the error message. msg = f"No wrapper exist for dataset class {dataset_cls.__name__}. Please wrap the output yourself." From 143217b174ab165a11e0e17a626b94e22dfb445a Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 13 Feb 2023 20:50:25 +0100 Subject: [PATCH 2/7] support CocoCaptions --- torchvision/prototype/datapoints/_dataset_wrapper.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/torchvision/prototype/datapoints/_dataset_wrapper.py b/torchvision/prototype/datapoints/_dataset_wrapper.py index 9be68899e8e..ae5f3ca6e3d 100644 --- a/torchvision/prototype/datapoints/_dataset_wrapper.py +++ b/torchvision/prototype/datapoints/_dataset_wrapper.py @@ -107,6 +107,10 @@ def identity(item): return item +def identity_wrapper_factory(dataset): + return identity + + def pil_image_to_mask(pil_image): return datapoints.Mask(pil_image) @@ -134,7 +138,7 @@ def wrap_target_by_type(target, *, target_types, type_wrappers): def classification_wrapper_factory(dataset): - return identity + return identity_wrapper_factory(dataset) for dataset_cls in [ @@ -240,6 +244,9 @@ def wrapper(sample): return wrapper +WRAPPER_FACTORIES.register(datasets.CocoCaptions)(identity_wrapper_factory) + + VOC_DETECTION_CATEGORIES = [ "__background__", "aeroplane", From 0741ce0cf3acc4082167ac615cecadba222338c0 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 13 Feb 2023 21:04:12 +0100 Subject: [PATCH 3/7] refactor to walk MRO --- .../prototype/datapoints/_dataset_wrapper.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/torchvision/prototype/datapoints/_dataset_wrapper.py b/torchvision/prototype/datapoints/_dataset_wrapper.py index ae5f3ca6e3d..eb751fdceee 100644 --- a/torchvision/prototype/datapoints/_dataset_wrapper.py +++ b/torchvision/prototype/datapoints/_dataset_wrapper.py @@ -3,6 +3,7 @@ from __future__ import annotations import contextlib +import itertools from collections import defaultdict import torch @@ -39,17 +40,13 @@ def decorator(wrapper_factory): class VisionDatasetDatapointWrapper(Dataset): def __init__(self, dataset): dataset_cls = type(dataset) - # We test for exact dataset class matches first. If we don't find one, we check if the dataset subclasses any - # of the known ones. - wrapper_factory = WRAPPER_FACTORIES.get(dataset_cls) - if wrapper_factory is None: - with contextlib.suppress(StopIteration): - wrapper_factory = next( - wrapper_factory - for dataset_supercls_candidate, wrapper_factory in WRAPPER_FACTORIES.items() - if issubclass(dataset_cls, dataset_supercls_candidate) - ) - if wrapper_factory is None: + for cls in itertools.takewhile( + lambda dataset_cls: dataset_cls is not datasets.VisionDataset, dataset_cls.mro() + ): + if cls in WRAPPER_FACTORIES: + wrapper_factory = WRAPPER_FACTORIES[cls] + break + else: # TODO: If we have documentation on how to do that, put a link in the error message. msg = f"No wrapper exist for dataset class {dataset_cls.__name__}. Please wrap the output yourself." if dataset_cls in datasets.__dict__.values(): From 01011d3904ab6cb91c450a5e3bc9839632cbadd6 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 13 Feb 2023 23:54:36 +0100 Subject: [PATCH 4/7] add test for subclass wrapping --- test/test_prototype_datapoints.py | 12 ++++++++++++ torchvision/prototype/datapoints/_dataset_wrapper.py | 1 + 2 files changed, 13 insertions(+) diff --git a/test/test_prototype_datapoints.py b/test/test_prototype_datapoints.py index 4663cdac3da..a96012f2f3d 100644 --- a/test/test_prototype_datapoints.py +++ b/test/test_prototype_datapoints.py @@ -2,6 +2,8 @@ import torch from PIL import Image + +from torchvision import datasets from torchvision.prototype import datapoints @@ -159,3 +161,13 @@ def test_bbox_instance(data, format): if isinstance(format, str): format = datapoints.BoundingBoxFormat.from_str(format.upper()) assert bboxes.format == format + + +def test_dataset_wrapper_subclass(): + class MyFakeData(datasets.FakeData): + pass + + dataset = MyFakeData() + wrapped_dataset = datapoints.wrap_dataset_for_transforms_v2(dataset) + + assert wrapped_dataset[0] is not None diff --git a/torchvision/prototype/datapoints/_dataset_wrapper.py b/torchvision/prototype/datapoints/_dataset_wrapper.py index db5efe2306e..d3bcc1b9d05 100644 --- a/torchvision/prototype/datapoints/_dataset_wrapper.py +++ b/torchvision/prototype/datapoints/_dataset_wrapper.py @@ -151,6 +151,7 @@ def classification_wrapper_factory(dataset): datasets.GTSRB, datasets.DatasetFolder, datasets.ImageFolder, + datasets.FakeData, ]: WRAPPER_FACTORIES.register(dataset_cls)(classification_wrapper_factory) From 0766a68faccab0a8b0fe116a084c926f97ba1bca Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 14 Feb 2023 09:55:21 +0100 Subject: [PATCH 5/7] simplify subclass detection --- test/test_prototype_datapoints.py | 18 +++++++---- .../prototype/datapoints/_dataset_wrapper.py | 30 +++++++++++-------- 2 files changed, 29 insertions(+), 19 deletions(-) diff --git a/test/test_prototype_datapoints.py b/test/test_prototype_datapoints.py index a96012f2f3d..37c05a2a8c1 100644 --- a/test/test_prototype_datapoints.py +++ b/test/test_prototype_datapoints.py @@ -163,11 +163,17 @@ def test_bbox_instance(data, format): assert bboxes.format == format -def test_dataset_wrapper_subclass(): - class MyFakeData(datasets.FakeData): - pass +class TestDatasetWrapper: + def test_unknown_type(self): + unknown_object = object() + with pytest.raises(TypeError, match=type(unknown_object).__name__): + datapoints.wrap_dataset_for_transforms_v2(unknown_object) - dataset = MyFakeData() - wrapped_dataset = datapoints.wrap_dataset_for_transforms_v2(dataset) + def test_subclass(self): + class MyFakeData(datasets.FakeData): + pass - assert wrapped_dataset[0] is not None + dataset = MyFakeData() + wrapped_dataset = datapoints.wrap_dataset_for_transforms_v2(dataset) + + assert wrapped_dataset[0] is not None diff --git a/torchvision/prototype/datapoints/_dataset_wrapper.py b/torchvision/prototype/datapoints/_dataset_wrapper.py index d3bcc1b9d05..b0c14112ada 100644 --- a/torchvision/prototype/datapoints/_dataset_wrapper.py +++ b/torchvision/prototype/datapoints/_dataset_wrapper.py @@ -3,7 +3,6 @@ from __future__ import annotations import contextlib -import itertools from collections import defaultdict import torch @@ -40,21 +39,26 @@ def decorator(wrapper_factory): class VisionDatasetDatapointWrapper(Dataset): def __init__(self, dataset): dataset_cls = type(dataset) - for cls in itertools.takewhile( - lambda dataset_cls: dataset_cls is not datasets.VisionDataset, dataset_cls.mro() - ): + + if not isinstance(dataset, datasets.VisionDataset): + raise TypeError( + f"This wrapper is meant for subclasses of `torchvision.datasets.VisionDataset`, " + f"but got a '{dataset_cls.__name__}' instead." + ) + + for cls in dataset_cls.mro(): if cls in WRAPPER_FACTORIES: wrapper_factory = WRAPPER_FACTORIES[cls] break - else: - # TODO: If we have documentation on how to do that, put a link in the error message. - msg = f"No wrapper exist for dataset class {dataset_cls.__name__}. Please wrap the output yourself." - if dataset_cls in datasets.__dict__.values(): - msg = ( - f"{msg} If an automated wrapper for this dataset would be useful for you, " - f"please open an issue at https://github.com/pytorch/vision/issues." - ) - raise TypeError(msg) + elif cls is datasets.VisionDataset: + # TODO: If we have documentation on how to do that, put a link in the error message. + msg = f"No wrapper exist for dataset class {dataset_cls.__name__}. Please wrap the output yourself." + if dataset_cls in datasets.__dict__.values(): + msg = ( + f"{msg} If an automated wrapper for this dataset would be useful for you, " + f"please open an issue at https://github.com/pytorch/vision/issues." + ) + raise TypeError(msg) self._dataset = dataset self._wrapper = wrapper_factory(dataset) From 782f2fe498f620e2077ad46bd6c69b78c8733733 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 14 Feb 2023 10:26:15 +0100 Subject: [PATCH 6/7] address comments --- test/test_prototype_datapoints.py | 32 +++++++++++++++++-- .../prototype/datapoints/_dataset_wrapper.py | 3 +- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/test/test_prototype_datapoints.py b/test/test_prototype_datapoints.py index 37c05a2a8c1..c2cc0986b71 100644 --- a/test/test_prototype_datapoints.py +++ b/test/test_prototype_datapoints.py @@ -1,3 +1,5 @@ +import re + import pytest import torch @@ -166,14 +168,38 @@ def test_bbox_instance(data, format): class TestDatasetWrapper: def test_unknown_type(self): unknown_object = object() - with pytest.raises(TypeError, match=type(unknown_object).__name__): + with pytest.raises( + TypeError, match=re.escape("is meant for subclasses of `torchvision.datasets.VisionDataset`") + ): datapoints.wrap_dataset_for_transforms_v2(unknown_object) - def test_subclass(self): + def test_unknown_dataset(self): + class MyVisionDataset(datasets.VisionDataset): + pass + + dataset = MyVisionDataset("root") + + with pytest.raises(TypeError, match="No wrapper exist"): + datapoints.wrap_dataset_for_transforms_v2(dataset) + + def test_missing_wrapper(self): + dataset = datasets.FakeData() + + with pytest.raises(TypeError, match="please open an issue"): + datapoints.wrap_dataset_for_transforms_v2(dataset) + + def test_subclass(self, mocker): + sentinel = object() + mocker.patch.dict( + datapoints._dataset_wrapper.WRAPPER_FACTORIES, + clear=False, + values={datasets.FakeData: lambda dataset: lambda idx, sample: sentinel}, + ) + class MyFakeData(datasets.FakeData): pass dataset = MyFakeData() wrapped_dataset = datapoints.wrap_dataset_for_transforms_v2(dataset) - assert wrapped_dataset[0] is not None + assert wrapped_dataset[0] is sentinel diff --git a/torchvision/prototype/datapoints/_dataset_wrapper.py b/torchvision/prototype/datapoints/_dataset_wrapper.py index b0c14112ada..74f83095177 100644 --- a/torchvision/prototype/datapoints/_dataset_wrapper.py +++ b/torchvision/prototype/datapoints/_dataset_wrapper.py @@ -52,7 +52,7 @@ def __init__(self, dataset): break elif cls is datasets.VisionDataset: # TODO: If we have documentation on how to do that, put a link in the error message. - msg = f"No wrapper exist for dataset class {dataset_cls.__name__}. Please wrap the output yourself." + msg = f"No wrapper exists for dataset class {dataset_cls.__name__}. Please wrap the output yourself." if dataset_cls in datasets.__dict__.values(): msg = ( f"{msg} If an automated wrapper for this dataset would be useful for you, " @@ -155,7 +155,6 @@ def classification_wrapper_factory(dataset): datasets.GTSRB, datasets.DatasetFolder, datasets.ImageFolder, - datasets.FakeData, ]: WRAPPER_FACTORIES.register(dataset_cls)(classification_wrapper_factory) From 21a1204238ff3fecae4f14978a9533736e81c8fd Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 14 Feb 2023 11:04:59 +0100 Subject: [PATCH 7/7] fix tests --- test/datasets_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/datasets_utils.py b/test/datasets_utils.py index 598d4408b76..c02ffeb0d68 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -596,7 +596,7 @@ def test_transforms_v2_wrapper(self, config): wrapped_sample = wrapped_dataset[0] assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample) except TypeError as error: - if str(error).startswith(f"No wrapper exist for dataset class {type(dataset).__name__}"): + if str(error).startswith(f"No wrapper exists for dataset class {type(dataset).__name__}"): return raise error except RuntimeError as error: