|
20 | 20 | torchvision.disable_beta_transforms_warning() |
21 | 21 |
|
22 | 22 | from torchvision import datapoints |
| 23 | +from torchvision.transforms.v2 import functional as F |
23 | 24 |
|
24 | 25 |
|
25 | 26 | ######################################################################################################################## |
|
93 | 94 | # built-in datasets. That means if your custom dataset subclasses from a built-in one and returns the same output
94 | 95 | # type, you don't have to wrap manually either.
95 | 96 | # |
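# For example, a minimal sketch of such a subclass (the ``wrap_dataset_for_transforms_v2`` helper and the
# ``TinyCocoDetection`` class are illustrative assumptions, not part of this tutorial's reference code):

from torchvision import datasets


class TinyCocoDetection(datasets.CocoDetection):
    # hypothetical subclass: it only limits the number of samples and leaves the
    # output structure of the built-in ``CocoDetection`` untouched
    def __len__(self):
        return min(10, super().__len__())


# the subclass can then be wrapped exactly like the built-in dataset, e.g.
# dataset = datasets.wrap_dataset_for_transforms_v2(TinyCocoDetection(...))

########################################################################################################################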
| 97 | +# If you have a custom dataset, for example the ``PennFudanDataset`` from |
| 98 | +# `this tutorial <https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html>`_, you have two options: |
| 99 | +# |
| 100 | +# 1. Perform the wrapping inside ``__getitem__``: |
| 101 | + |
| 102 | +class PennFudanDataset(torch.utils.data.Dataset): |
| 103 | +    ...
| 104 | +
| 105 | +    def __getitem__(self, item):
| 106 | +        ...
| 107 | + |
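        # wrap the boxes and masks into datapoints so that the v2 transforms know how to handle them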
| 108 | +        target["boxes"] = datapoints.BoundingBox(
| 109 | +            boxes,
| 110 | +            format=datapoints.BoundingBoxFormat.XYXY,
| 111 | +            spatial_size=F.get_spatial_size(img),
| 112 | +        )
| 113 | +        target["labels"] = labels
| 114 | +        target["masks"] = datapoints.Mask(masks)
| 115 | + |
| 116 | +        ...
| 117 | +
| 118 | +        if self.transforms is not None:
| 119 | +            img, target = self.transforms(img, target)
| 120 | +
| 121 | +        ...
| 122 | + |
| 123 | +######################################################################################################################## |
| 124 | +# 2. Perform the wrapping inside a custom transformation at the beginning of your pipeline: |
| 125 | + |
| 126 | + |
| 127 | +class WrapPennFudanDataset: |
| 128 | +    def __call__(self, img, target):
| 129 | +        target["boxes"] = datapoints.BoundingBox(
| 130 | +            target["boxes"],
| 131 | +            format=datapoints.BoundingBoxFormat.XYXY,
| 132 | +            spatial_size=F.get_spatial_size(img),
| 133 | +        )
| 134 | +        target["masks"] = datapoints.Mask(target["masks"])
| 135 | +        return img, target
| 136 | + |
| 137 | + |
| 138 | +... |
| 139 | + |
| 140 | + |
| 141 | +def get_transform(train): |
| 142 | +    transforms = []
| 143 | +    transforms.append(WrapPennFudanDataset())
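    # ``T`` below is assumed to be the helper ``transforms`` module used in the linked detection
    # tutorial (taken from the torchvision detection references)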
| 144 | +    transforms.append(T.PILToTensor())
| 145 | +    ...
| 146 | + |
| 147 | +######################################################################################################################## |
| 148 | +# .. note:: |
| 149 | +# |
| 150 | +#    If both :class:`~torchvision.datapoints.BoundingBox`'es and :class:`~torchvision.datapoints.Mask`'s are included in
| 151 | +#    the sample, ``torchvision.transforms.v2`` will transform them both. That means if you don't need both, dropping, or
| 152 | +#    at least not wrapping, the unused parts can lead to a significant performance boost.
| 153 | +#
| 154 | +#    For example, if you are using the ``PennFudanDataset`` for object detection, not wrapping the masks avoids
| 155 | +#    transforming them over and over again in the pipeline, only to ultimately ignore them. In general, it would be
| 156 | +#    even better not to load the masks at all, but that is not possible in this example, since the bounding boxes are
| 157 | +#    generated from the masks.
| 158 | +# |
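# For example, a hypothetical detection-only variant of the ``WrapPennFudanDataset`` transform above
# (a sketch, not part of the tutorial's reference code) could wrap the boxes but drop the masks entirely:


class WrapPennFudanDatasetDetectionOnly:
    def __call__(self, img, target):
        target["boxes"] = datapoints.BoundingBox(
            target["boxes"],
            format=datapoints.BoundingBoxFormat.XYXY,
            spatial_size=F.get_spatial_size(img),
        )
        # the boxes above were already generated from the masks when the dataset was loaded,
        # so for plain object detection the masks can simply be dropped instead of wrapped
        target.pop("masks", None)
        return img, target

########################################################################################################################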
96 | 159 | # How do the datapoints behave inside a computation? |
97 | 160 | # -------------------------------------------------- |
98 | 161 | # |
|
101 | 164 | # For most operations involving datapoints, it cannot be safely inferred whether the result should retain the
102 | 165 | # datapoint type. Thus, we choose to return a plain tensor instead of a datapoint (this might change, see the note below):
103 | 166 |
|
| 167 | + |
104 | 168 | assert isinstance(image, datapoints.Image) |
105 | 169 |
|
106 | 170 | new_image = image + 0 |
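# the result of the computation is a plain tensor; the ``Image`` datapoint type is not retained
assert isinstance(new_image, torch.Tensor) and not isinstance(new_image, datapoints.Image)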
|