Skip to content

Commit b78d98b

Browse files
authored
add example for v2 wrapping for custom datasets (#7514)
1 parent fc377d0 commit b78d98b

File tree

2 files changed

+67
-1
lines changed

2 files changed

+67
-1
lines changed

gallery/plot_datapoints.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
torchvision.disable_beta_transforms_warning()
2121

2222
from torchvision import datapoints
23+
from torchvision.transforms.v2 import functional as F
2324

2425

2526
########################################################################################################################
@@ -93,6 +94,68 @@
9394
# built-in datasets. Meaning, if your custom dataset subclasses from a built-in one and the output type is the same, you
9495
# also don't have to wrap manually.
9596
#
97+
# If you have a custom dataset, for example the ``PennFudanDataset`` from
98+
# `this tutorial <https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html>`_, you have two options:
99+
#
100+
# 1. Perform the wrapping inside ``__getitem__``:
101+
102+
class PennFudanDataset(torch.utils.data.Dataset):
103+
...
104+
105+
def __getitem__(self, item):
106+
...
107+
108+
target["boxes"] = datapoints.BoundingBox(
109+
boxes,
110+
format=datapoints.BoundingBoxFormat.XYXY,
111+
spatial_size=F.get_spatial_size(img),
112+
)
113+
target["labels"] = labels
114+
target["masks"] = datapoints.Mask(masks)
115+
116+
...
117+
118+
if self.transforms is not None:
119+
img, target = self.transforms(img, target)
120+
121+
...
122+
123+
########################################################################################################################
124+
# 2. Perform the wrapping inside a custom transformation at the beginning of your pipeline:
125+
126+
127+
class WrapPennFudanDataset:
128+
def __call__(self, img, target):
129+
target["boxes"] = datapoints.BoundingBox(
130+
target["boxes"],
131+
format=datapoints.BoundingBoxFormat.XYXY,
132+
spatial_size=F.get_spatial_size(img),
133+
)
134+
target["masks"] = datapoints.Mask(target["masks"])
135+
return img, target
136+
137+
138+
...
139+
140+
141+
def get_transform(train):
142+
transforms = []
143+
transforms.append(WrapPennFudanDataset())
144+
transforms.append(T.PILToTensor())
145+
...
146+
147+
########################################################################################################################
148+
# .. note::
149+
#
150+
# If both :class:`~torchvision.datapoints.BoundingBox`'es and :class:`~torchvision.datapoints.Mask`'s are included in
151+
# the sample, ``torchvision.transforms.v2`` will transform them both. Meaning, if you don't need both, dropping or
152+
# at least not wrapping the obsolete parts can lead to a significant performance boost.
153+
#
154+
# For example, if you are using the ``PennFudanDataset`` for object detection, not wrapping the masks avoids
155+
# transforming them over and over again in the pipeline just to ultimately ignore them. In general, it would be
156+
# even better to not load the masks at all, but this is not possible in this example, since the bounding boxes are
157+
# generated from the masks.
158+
#
96159
# How do the datapoints behave inside a computation?
97160
# --------------------------------------------------
98161
#
@@ -101,6 +164,7 @@
101164
# Since for most operations involving datapoints, it cannot be safely inferred whether the result should retain the
102165
# datapoint type, we choose to return a plain tensor instead of a datapoint (this might change, see note below):
103166

167+
104168
assert isinstance(image, datapoints.Image)
105169

106170
new_image = image + 0

torchvision/datapoints/_dataset_wrapper.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,9 @@ def __init__(self, dataset, target_keys):
124124
if not isinstance(dataset, datasets.VisionDataset):
125125
raise TypeError(
126126
f"This wrapper is meant for subclasses of `torchvision.datasets.VisionDataset`, "
127-
f"but got a '{dataset_cls.__name__}' instead."
127+
f"but got a '{dataset_cls.__name__}' instead.\n"
128+
f"For an example of how to perform the wrapping for custom datasets, see\n\n"
129+
"https://pytorch.org/vision/main/auto_examples/plot_datapoints.html#do-i-have-to-wrap-the-output-of-the-datasets-myself"
128130
)
129131

130132
for cls in dataset_cls.mro():

0 commit comments

Comments
 (0)