From 270f4f040f5e9c2904a67d8e4f06e74bbb7205dd Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 2 Aug 2023 15:30:57 +0100 Subject: [PATCH 01/11] Use syntax for gallery examples --- gallery/plot_datapoints.py | 18 ++++---- gallery/plot_optical_flow.py | 16 +++---- gallery/plot_repurposing_annotations.py | 20 ++++----- gallery/plot_scripted_tensor_transforms.py | 12 ++--- gallery/plot_transforms.py | 52 +++++++++++----------- gallery/plot_transforms_v2.py | 8 ++-- gallery/plot_transforms_v2_e2e.py | 12 ++--- gallery/plot_video_api.py | 34 +++++++------- gallery/plot_visualization_utils.py | 48 ++++++++++---------- 9 files changed, 110 insertions(+), 110 deletions(-) diff --git a/gallery/plot_datapoints.py b/gallery/plot_datapoints.py index fef282ae091..57e29bd86eb 100644 --- a/gallery/plot_datapoints.py +++ b/gallery/plot_datapoints.py @@ -23,7 +23,7 @@ from torchvision.transforms.v2 import functional as F -######################################################################################################################## +# %% # What are datapoints? # -------------------- # @@ -36,7 +36,7 @@ assert image.data_ptr() == tensor.data_ptr() -######################################################################################################################## +# %% # Under the hood, they are needed in :mod:`torchvision.transforms.v2` to correctly dispatch to the appropriate function # for the input data. # @@ -59,7 +59,7 @@ print(image) -######################################################################################################################## +# %% # Similar to other PyTorch creations ops, the constructor also takes the ``dtype``, ``device``, and ``requires_grad`` # parameters. @@ -67,14 +67,14 @@ print(float_image) -######################################################################################################################## +# %% # In addition, :class:`~torchvision.datapoints.Image` and :class:`~torchvision.datapoints.Mask` also take a # :class:`PIL.Image.Image` directly: image = datapoints.Image(PIL.Image.open("assets/astronaut.jpg")) print(image.shape, image.dtype) -######################################################################################################################## +# %% # In general, the datapoints can also store additional metadata that complements the underlying tensor. For example, # :class:`~torchvision.datapoints.BoundingBoxes` stores the coordinate format as well as the spatial size of the # corresponding image alongside the actual values: @@ -85,7 +85,7 @@ print(bounding_box) -######################################################################################################################## +# %% # Do I have to wrap the output of the datasets myself? # ---------------------------------------------------- # @@ -120,7 +120,7 @@ def __getitem__(self, item): ... -######################################################################################################################## +# %% # 2. Perform the wrapping inside a custom transformation at the beginning of your pipeline: @@ -144,7 +144,7 @@ def get_transform(train): transforms.append(T.PILToTensor()) ... -######################################################################################################################## +# %% # .. 
note:: # # If both :class:`~torchvision.datapoints.BoundingBoxes`'es and :class:`~torchvision.datapoints.Mask`'s are included in @@ -171,7 +171,7 @@ def get_transform(train): assert isinstance(new_image, torch.Tensor) and not isinstance(new_image, datapoints.Image) -######################################################################################################################## +# %% # .. note:: # # This "unwrapping" behaviour is something we're actively seeking feedback on. If you find this surprising or if you diff --git a/gallery/plot_optical_flow.py b/gallery/plot_optical_flow.py index 835ce330180..499f8c66398 100644 --- a/gallery/plot_optical_flow.py +++ b/gallery/plot_optical_flow.py @@ -42,7 +42,7 @@ def plot(imgs, **imshow_kwargs): plt.tight_layout() -################################### +# %% # Reading Videos Using Torchvision # -------------------------------- # We will first read a video using :func:`~torchvision.io.read_video`. @@ -62,7 +62,7 @@ def plot(imgs, **imshow_kwargs): video_path = Path(tempfile.mkdtemp()) / "basketball.mp4" _ = urlretrieve(video_url, video_path) -######################### +# %% # :func:`~torchvision.io.read_video` returns the video frames, audio frames and # the metadata associated with the video. In our case, we only need the video # frames. @@ -79,7 +79,7 @@ def plot(imgs, **imshow_kwargs): plot(img1_batch) -######################### +# %% # The RAFT model accepts RGB images. We first get the frames from # :func:`~torchvision.io.read_video` and resize them to ensure their dimensions # are divisible by 8. Note that we explicitly use ``antialias=False``, because @@ -104,7 +104,7 @@ def preprocess(img1_batch, img2_batch): print(f"shape = {img1_batch.shape}, dtype = {img1_batch.dtype}") -#################################### +# %% # Estimating Optical flow using RAFT # ---------------------------------- # We will use our RAFT implementation from @@ -125,7 +125,7 @@ def preprocess(img1_batch, img2_batch): print(f"type = {type(list_of_flows)}") print(f"length = {len(list_of_flows)} = number of iterations of the model") -#################################### +# %% # The RAFT model outputs lists of predicted flows where each entry is a # (N, 2, H, W) batch of predicted flows that corresponds to a given "iteration" # in the model. 
For more details on the iterative nature of the model, please @@ -144,7 +144,7 @@ def preprocess(img1_batch, img2_batch): print(f"min = {predicted_flows.min()}, max = {predicted_flows.max()}") -#################################### +# %% # Visualizing predicted flows # --------------------------- # Torchvision provides the :func:`~torchvision.utils.flow_to_image` utlity to @@ -166,7 +166,7 @@ def preprocess(img1_batch, img2_batch): grid = [[img1, flow_img] for (img1, flow_img) in zip(img1_batch, flow_imgs)] plot(grid) -#################################### +# %% # Bonus: Creating GIFs of predicted flows # --------------------------------------- # In the example above we have only shown the predicted flows of 2 pairs of @@ -187,7 +187,7 @@ def preprocess(img1_batch, img2_batch): # output_folder = "/tmp/" # Update this to the folder of your choice # write_jpeg(flow_img, output_folder + f"predicted_flow_{i}.jpg") -#################################### +# %% # Once the .jpg flow images are saved, you can convert them into a video or a # GIF using ffmpeg with e.g.: # diff --git a/gallery/plot_repurposing_annotations.py b/gallery/plot_repurposing_annotations.py index 7bb68617a17..99f75f03fc1 100644 --- a/gallery/plot_repurposing_annotations.py +++ b/gallery/plot_repurposing_annotations.py @@ -36,7 +36,7 @@ def show(imgs): axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) -#################################### +# %% # Masks # ----- # In tasks like instance and panoptic segmentation, masks are commonly defined, and are defined by this package, @@ -53,7 +53,7 @@ def show(imgs): # A nice property of masks is that they can be easily repurposed to be used in methods to solve a variety of object # localization tasks. -#################################### +# %% # Converting Masks to Bounding Boxes # ----------------------------------------------- # For example, the :func:`~torchvision.ops.masks_to_boxes` operation can be used to @@ -70,7 +70,7 @@ def show(imgs): mask = read_image(mask_path) -######################### +# %% # Here the masks are represented as a PNG Image, with floating point values. # Each pixel is encoded as different colors, with 0 being background. # Notice that the spatial dimensions of image and mask match. @@ -79,7 +79,7 @@ def show(imgs): print(img.size()) print(mask) -############################ +# %% # We get the unique colors, as these would be the object ids. obj_ids = torch.unique(mask) @@ -91,7 +91,7 @@ def show(imgs): # Note that this snippet would work as well if the masks were float values instead of ints. masks = mask == obj_ids[:, None, None] -######################## +# %% # Now the masks are a boolean tensor. # The first dimension in this case 3 and denotes the number of instances: there are 3 people in the image. # The other two dimensions are height and width, which are equal to the dimensions of the image. @@ -101,7 +101,7 @@ def show(imgs): print(masks.size()) print(masks) -#################################### +# %% # Let us visualize an image and plot its corresponding segmentation masks. # We will use the :func:`~torchvision.utils.draw_segmentation_masks` to draw the segmentation masks. @@ -113,7 +113,7 @@ def show(imgs): show(drawn_masks) -#################################### +# %% # To convert the boolean masks into bounding boxes. # We will use the :func:`~torchvision.ops.masks_to_boxes` from the torchvision.ops module # It returns the boxes in ``(xmin, ymin, xmax, ymax)`` format. 
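# %%
# The conversion itself is a single call (the exact line falls outside the hunks
# shown here). A minimal sketch, assuming ``masks`` is the boolean ``(N, H, W)``
# tensor built above:

from torchvision.ops import masks_to_boxes

boxes = masks_to_boxes(masks)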
@@ -124,7 +124,7 @@ def show(imgs): print(boxes.size()) print(boxes) -#################################### +# %% # As the shape denotes, there are 3 boxes and in ``(xmin, ymin, xmax, ymax)`` format. # These can be visualized very easily with :func:`~torchvision.utils.draw_bounding_boxes` utility # provided in :ref:`torchvision.utils `. @@ -134,7 +134,7 @@ def show(imgs): drawn_boxes = draw_bounding_boxes(img, boxes, colors="red") show(drawn_boxes) -################################### +# %% # These boxes can now directly be used by detection models in torchvision. # Here is demo with a Faster R-CNN model loaded from # :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` @@ -153,7 +153,7 @@ def show(imgs): detection_outputs = model(img.unsqueeze(0), [target]) -#################################### +# %% # Converting Segmentation Dataset to Detection Dataset # ---------------------------------------------------- # diff --git a/gallery/plot_scripted_tensor_transforms.py b/gallery/plot_scripted_tensor_transforms.py index b0851217e50..e803da7799e 100644 --- a/gallery/plot_scripted_tensor_transforms.py +++ b/gallery/plot_scripted_tensor_transforms.py @@ -45,7 +45,7 @@ def show(imgs): axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) -#################################### +# %% # The :func:`~torchvision.io.read_image` function allows to read an image and # directly load it as a tensor @@ -53,7 +53,7 @@ def show(imgs): dog2 = read_image(str(Path('assets') / 'dog2.jpg')) show([dog1, dog2]) -#################################### +# %% # Transforming images on GPU # -------------------------- # Most transforms natively support tensors on top of PIL images (to visualize @@ -76,7 +76,7 @@ def show(imgs): transformed_dog2 = transforms(dog2) show([transformed_dog1, transformed_dog2]) -#################################### +# %% # Scriptable transforms for easier deployment via torchscript # ----------------------------------------------------------- # We now show how to combine image transformations and a model forward pass, @@ -103,7 +103,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return y_pred.argmax(dim=1) -#################################### +# %% # Now, let's define scripted and non-scripted instances of ``Predictor`` and # apply it on multiple tensor images of the same size @@ -115,7 +115,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: res = predictor(batch) res_scripted = scripted_predictor(batch) -#################################### +# %% # We can verify that the prediction of the scripted and non-scripted models are # the same: @@ -128,7 +128,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: assert pred == pred_scripted print(f"Prediction for Dog {i + 1}: {labels[str(pred.item())]}") -#################################### +# %% # Since the model is scripted, it can be easily dumped on disk and re-used import tempfile diff --git a/gallery/plot_transforms.py b/gallery/plot_transforms.py index 2330dc0f967..ac6e50a397e 100644 --- a/gallery/plot_transforms.py +++ b/gallery/plot_transforms.py @@ -50,7 +50,7 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): plt.tight_layout() -#################################### +# %% # Geometric Transforms # -------------------- # Geometric image transformation refers to the process of altering the geometric properties of an image, @@ -65,7 +65,7 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): padded_imgs = [T.Pad(padding=padding)(orig_img) for padding in (3, 10, 30, 50)] 
plot(padded_imgs) -#################################### +# %% # Resize # ~~~~~~ # The :class:`~torchvision.transforms.Resize` transform @@ -74,7 +74,7 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): resized_imgs = [T.Resize(size=size)(orig_img) for size in (30, 50, 100, orig_img.size)] plot(resized_imgs) -#################################### +# %% # CenterCrop # ~~~~~~~~~~ # The :class:`~torchvision.transforms.CenterCrop` transform @@ -83,7 +83,7 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): center_crops = [T.CenterCrop(size=size)(orig_img) for size in (30, 50, 100, orig_img.size)] plot(center_crops) -#################################### +# %% # FiveCrop # ~~~~~~~~ # The :class:`~torchvision.transforms.FiveCrop` transform @@ -92,7 +92,7 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): (top_left, top_right, bottom_left, bottom_right, center) = T.FiveCrop(size=(100, 100))(orig_img) plot([top_left, top_right, bottom_left, bottom_right, center]) -#################################### +# %% # RandomPerspective # ~~~~~~~~~~~~~~~~~ # The :class:`~torchvision.transforms.RandomPerspective` transform @@ -102,7 +102,7 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): perspective_imgs = [perspective_transformer(orig_img) for _ in range(4)] plot(perspective_imgs) -#################################### +# %% # RandomRotation # ~~~~~~~~~~~~~~ # The :class:`~torchvision.transforms.RandomRotation` transform @@ -112,7 +112,7 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): rotated_imgs = [rotater(orig_img) for _ in range(4)] plot(rotated_imgs) -#################################### +# %% # RandomAffine # ~~~~~~~~~~~~ # The :class:`~torchvision.transforms.RandomAffine` transform @@ -122,7 +122,7 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): affine_imgs = [affine_transfomer(orig_img) for _ in range(4)] plot(affine_imgs) -#################################### +# %% # ElasticTransform # ~~~~~~~~~~~~~~~~ # The :class:`~torchvision.transforms.ElasticTransform` transform @@ -133,7 +133,7 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): transformed_imgs = [elastic_transformer(orig_img) for _ in range(2)] plot(transformed_imgs) -#################################### +# %% # RandomCrop # ~~~~~~~~~~ # The :class:`~torchvision.transforms.RandomCrop` transform @@ -143,7 +143,7 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): crops = [cropper(orig_img) for _ in range(4)] plot(crops) -#################################### +# %% # RandomResizedCrop # ~~~~~~~~~~~~~~~~~ # The :class:`~torchvision.transforms.RandomResizedCrop` transform @@ -154,7 +154,7 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): resized_crops = [resize_cropper(orig_img) for _ in range(4)] plot(resized_crops) -#################################### +# %% # Photometric Transforms # ---------------------- # Photometric image transformation refers to the process of modifying the photometric properties of an image, @@ -174,7 +174,7 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): gray_img = T.Grayscale()(orig_img) plot([gray_img], cmap='gray') -#################################### +# %% # ColorJitter # ~~~~~~~~~~~ # The :class:`~torchvision.transforms.ColorJitter` transform @@ -183,7 +183,7 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): jitted_imgs = [jitter(orig_img) for _ in range(4)] plot(jitted_imgs) -#################################### +# %% # 
GaussianBlur # ~~~~~~~~~~~~ # The :class:`~torchvision.transforms.GaussianBlur` transform @@ -193,7 +193,7 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): blurred_imgs = [blurrer(orig_img) for _ in range(4)] plot(blurred_imgs) -#################################### +# %% # RandomInvert # ~~~~~~~~~~~~ # The :class:`~torchvision.transforms.RandomInvert` transform @@ -203,7 +203,7 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): invertered_imgs = [inverter(orig_img) for _ in range(4)] plot(invertered_imgs) -#################################### +# %% # RandomPosterize # ~~~~~~~~~~~~~~~ # The :class:`~torchvision.transforms.RandomPosterize` transform @@ -214,7 +214,7 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): posterized_imgs = [posterizer(orig_img) for _ in range(4)] plot(posterized_imgs) -#################################### +# %% # RandomSolarize # ~~~~~~~~~~~~~~ # The :class:`~torchvision.transforms.RandomSolarize` transform @@ -225,7 +225,7 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): solarized_imgs = [solarizer(orig_img) for _ in range(4)] plot(solarized_imgs) -#################################### +# %% # RandomAdjustSharpness # ~~~~~~~~~~~~~~~~~~~~~ # The :class:`~torchvision.transforms.RandomAdjustSharpness` transform @@ -235,7 +235,7 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): sharpened_imgs = [sharpness_adjuster(orig_img) for _ in range(4)] plot(sharpened_imgs) -#################################### +# %% # RandomAutocontrast # ~~~~~~~~~~~~~~~~~~ # The :class:`~torchvision.transforms.RandomAutocontrast` transform @@ -245,7 +245,7 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): autocontrasted_imgs = [autocontraster(orig_img) for _ in range(4)] plot(autocontrasted_imgs) -#################################### +# %% # RandomEqualize # ~~~~~~~~~~~~~~ # The :class:`~torchvision.transforms.RandomEqualize` transform @@ -255,7 +255,7 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): equalized_imgs = [equalizer(orig_img) for _ in range(4)] plot(equalized_imgs) -#################################### +# %% # Augmentation Transforms # ----------------------- # The following transforms are combinations of multiple transforms, @@ -275,7 +275,7 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): row_title = [str(policy).split('.')[-1] for policy in policies] plot(imgs, row_title=row_title) -#################################### +# %% # RandAugment # ~~~~~~~~~~~ # The :class:`~torchvision.transforms.RandAugment` is an alternate version of AutoAugment. @@ -283,7 +283,7 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): imgs = [augmenter(orig_img) for _ in range(4)] plot(imgs) -#################################### +# %% # TrivialAugmentWide # ~~~~~~~~~~~~~~~~~~ # The :class:`~torchvision.transforms.TrivialAugmentWide` is an alternate implementation of AutoAugment. @@ -293,7 +293,7 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): imgs = [augmenter(orig_img) for _ in range(4)] plot(imgs) -#################################### +# %% # AugMix # ~~~~~~ # The :class:`~torchvision.transforms.AugMix` transform interpolates between augmented versions of an image. 
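# %%
# In practice, these policy-based augmenters are typically used as one step of a
# training pipeline rather than in isolation. A minimal sketch, reusing the ``T``
# alias and ``orig_img`` from this example (the normalization statistics are the
# commonly used ImageNet values, shown only for illustration):

import torch

training_transforms = T.Compose([
    T.RandomResizedCrop(224),
    T.TrivialAugmentWide(),
    T.PILToTensor(),
    T.ConvertImageDtype(torch.float),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
augmented = training_transforms(orig_img)
print(augmented.shape, augmented.dtype)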
@@ -301,7 +301,7 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): imgs = [augmenter(orig_img) for _ in range(4)] plot(imgs) -#################################### +# %% # Randomly-applied Transforms # --------------------------- # @@ -318,7 +318,7 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): transformed_imgs = [hflipper(orig_img) for _ in range(4)] plot(transformed_imgs) -#################################### +# %% # RandomVerticalFlip # ~~~~~~~~~~~~~~~~~~ # The :class:`~torchvision.transforms.RandomVerticalFlip` transform @@ -328,7 +328,7 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): transformed_imgs = [vflipper(orig_img) for _ in range(4)] plot(transformed_imgs) -#################################### +# %% # RandomApply # ~~~~~~~~~~~ # The :class:`~torchvision.transforms.RandomApply` transform diff --git a/gallery/plot_transforms_v2.py b/gallery/plot_transforms_v2.py index 88916ba44f9..b85481ae1a5 100644 --- a/gallery/plot_transforms_v2.py +++ b/gallery/plot_transforms_v2.py @@ -36,7 +36,7 @@ def load_data(): return path, image, bounding_boxes, masks, labels -######################################################################################################################## +# %% # The :mod:`torchvision.transforms.v2` API supports images, videos, bounding boxes, and instance and segmentation # masks. Thus, it offers native support for many Computer Vision tasks, like image and video classification, object # detection or instance and semantic segmentation. Still, the interface is the same, making @@ -55,7 +55,7 @@ def load_data(): ] ) -######################################################################################################################## +# %% # :mod:`torchvision.transforms.v2` natively supports jointly transforming multiple inputs while making sure that # potential random behavior is consistent across all inputs. However, it doesn't enforce a specific input structure or # order. @@ -70,7 +70,7 @@ def load_data(): ) # Instance Segmentation new_image, new_target = transform((image, {"boxes": bounding_boxes, "labels": labels})) # Arbitrary Structure -######################################################################################################################## +# %% # Under the hood, :mod:`torchvision.transforms.v2` relies on :mod:`torchvision.datapoints` for the dispatch to the # appropriate function for the input data: :ref:`sphx_glr_auto_examples_plot_datapoints.py`. Note however, that as # regular user, you likely don't have to touch this yourself. See @@ -84,7 +84,7 @@ def load_data(): assert new_sample["path"] is sample["path"] -######################################################################################################################## +# %% # As stated above, :mod:`torchvision.transforms.v2` is a drop-in replacement for :mod:`torchvision.transforms` and thus # also supports transforming plain :class:`torch.Tensor`'s as image or video if applicable. 
This is achieved with a # simple heuristic: diff --git a/gallery/plot_transforms_v2_e2e.py b/gallery/plot_transforms_v2_e2e.py index 981b1e58832..8a80c78e1f7 100644 --- a/gallery/plot_transforms_v2_e2e.py +++ b/gallery/plot_transforms_v2_e2e.py @@ -47,7 +47,7 @@ def show(sample): import torchvision.transforms.v2 as transforms -######################################################################################################################## +# %% # We start off by loading the :class:`~torchvision.datasets.CocoDetection` dataset to have a look at what it currently # returns, and we'll see how to convert it to a format that is compatible with our new transforms. @@ -67,7 +67,7 @@ def load_example_coco_detection_dataset(**kwargs): print(type(target), type(target[0]), list(target[0].keys())) -######################################################################################################################## +# %% # The dataset returns a two-tuple with the first item being a :class:`PIL.Image.Image` and second one a list of # dictionaries, which each containing the annotations for a single object instance. As is, this format is not compatible # with the ``torchvision.transforms.v2``, nor with the models. To overcome that, we provide the @@ -85,13 +85,13 @@ def load_example_coco_detection_dataset(**kwargs): print(type(target), list(target.keys())) print(type(target["boxes"]), type(target["labels"])) -######################################################################################################################## +# %% # As baseline, let's have a look at a sample without transformations: show(sample) -######################################################################################################################## +# %% # With the dataset properly set up, we can now define the augmentation pipeline. This is done the same way it is done in # ``torchvision.transforms`` v1, but now handles bounding boxes and masks without any extra configuration. @@ -107,7 +107,7 @@ def load_example_coco_detection_dataset(**kwargs): ] ) -######################################################################################################################## +# %% # .. note:: # Although the :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes` transform is a no-op in this example, but it # should be placed at least once at the end of a detection pipeline to remove degenerate bounding boxes as well as @@ -126,7 +126,7 @@ def load_example_coco_detection_dataset(**kwargs): show(sample) -######################################################################################################################## +# %% # We can see that the color of the image was distorted, we zoomed out on it (off center) and flipped it horizontally. # In all of this, the bounding box was transformed accordingly. And without any further ado, we can start training. diff --git a/gallery/plot_video_api.py b/gallery/plot_video_api.py index 76e3590d57c..aa3a620a613 100644 --- a/gallery/plot_video_api.py +++ b/gallery/plot_video_api.py @@ -7,14 +7,14 @@ videos, together with the examples on how to build datasets and more. """ -#################################### +# %% # 1. Introduction: building a new video object and examining the properties # ------------------------------------------------------------------------- # First we select a video to test the object out. For the sake of argument # we're using one from kinetics400 dataset. # To create it, we need to define the path and the stream we want to use. 
-###################################### +# %% # Chosen video statistics: # # - WUzgd7C1pWA.mp4 @@ -42,7 +42,7 @@ ) video_path = "./WUzgd7C1pWA.mp4" -###################################### +# %% # Streams are defined in a similar fashion as torch devices. We encode them as strings in a form # of ``stream_type:stream_id`` where ``stream_type`` is a string and ``stream_id`` a long int. # The constructor accepts passing a ``stream_type`` only, in which case the stream is auto-discovered. @@ -52,7 +52,7 @@ video = torchvision.io.VideoReader(video_path, stream) video.get_metadata() -###################################### +# %% # Here we can see that video has two streams - a video and an audio stream. # Currently available stream types include ['video', 'audio']. # Each descriptor consists of two parts: stream type (e.g. 'video') and a unique stream id @@ -61,7 +61,7 @@ # users can access the one they want. # If only stream type is passed, the decoder auto-detects first stream of that type and returns it. -###################################### +# %% # Let's read all the frames from the video stream. By default, the return value of # ``next(video_reader)`` is a dict containing the following fields. # @@ -85,7 +85,7 @@ print("Approx total number of datapoints we can expect: ", approx_nf) print("Read data size: ", frames[0].size(0) * len(frames)) -###################################### +# %% # But what if we only want to read certain time segment of the video? # That can be done easily using the combination of our ``seek`` function, and the fact that each call # to next returns the presentation timestamp of the returned frame in seconds. @@ -107,7 +107,7 @@ print("Total number of frames: ", len(frames)) -###################################### +# %% # Or if we wanted to read from 2nd to 5th second, # We seek into a second second of the video, # then we utilize the itertools takewhile to get the @@ -125,7 +125,7 @@ print("We can expect approx: ", approx_nf) print("Tensor size: ", frames[0].size()) -#################################### +# %% # 2. Building a sample read_video function # ---------------------------------------------------------------------------------------- # We can utilize the methods above to build the read video function that follows @@ -170,21 +170,21 @@ def example_read_video(video_object, start=0, end=None, read_video=True, read_au vf, af, info, meta = example_read_video(video) print(vf.size(), af.size()) -#################################### +# %% # 3. Building an example randomly sampled dataset (can be applied to training dataset of kinetics400) # ------------------------------------------------------------------------------------------------------- # Cool, so now we can use the same principle to make the sample dataset. # We suggest trying out iterable dataset for this purpose. # Here, we are going to build an example dataset that reads randomly selected 10 frames of video. 
-#################################### +# %% # Make sample dataset import os os.makedirs("./dataset", exist_ok=True) os.makedirs("./dataset/1", exist_ok=True) os.makedirs("./dataset/2", exist_ok=True) -#################################### +# %% # Download the videos from torchvision.datasets.utils import download_url download_url( @@ -212,7 +212,7 @@ def example_read_video(video_object, start=0, end=None, read_video=True, read_au "v_SoccerJuggling_g24_c01.avi" ) -#################################### +# %% # Housekeeping and utilities import os import random @@ -232,7 +232,7 @@ def get_samples(root, extensions=(".mp4", ".avi")): _, class_to_idx = _find_classes(root) return make_dataset(root, class_to_idx, extensions=extensions) -#################################### +# %% # We are going to define the dataset and some basic arguments. # We assume the structure of the FolderDataset, and add the following parameters: # @@ -287,7 +287,7 @@ def __iter__(self): 'end': current_pts} yield output -#################################### +# %% # Given a path of videos in a folder structure, i.e: # # - dataset @@ -309,7 +309,7 @@ def __iter__(self): dataset = RandomDataset("./dataset", epoch_size=None, frame_transform=frame_transform) -#################################### +# %% from torch.utils.data import DataLoader loader = DataLoader(dataset, batch_size=12) data = {"video": [], 'start': [], 'end': [], 'tensorsize': []} @@ -321,7 +321,7 @@ def __iter__(self): data['tensorsize'].append(batch['video'][i].size()) print(data) -#################################### +# %% # 4. Data Visualization # ---------------------------------- # Example of visualized video @@ -334,7 +334,7 @@ def __iter__(self): plt.imshow(batch["video"][0, i, ...].permute(1, 2, 0)) plt.axis("off") -#################################### +# %% # Cleanup the video and dataset: import os import shutil diff --git a/gallery/plot_visualization_utils.py b/gallery/plot_visualization_utils.py index d6350a7a4c4..5e629cb8cb8 100644 --- a/gallery/plot_visualization_utils.py +++ b/gallery/plot_visualization_utils.py @@ -30,7 +30,7 @@ def show(imgs): axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) -#################################### +# %% # Visualizing a grid of images # ---------------------------- # The :func:`~torchvision.utils.make_grid` function can be used to create a @@ -48,7 +48,7 @@ def show(imgs): grid = make_grid(dog_list) show(grid) -#################################### +# %% # Visualizing bounding boxes # -------------------------- # We can use :func:`~torchvision.utils.draw_bounding_boxes` to draw boxes on an @@ -64,7 +64,7 @@ def show(imgs): show(result) -##################################### +# %% # Naturally, we can also plot bounding boxes produced by torchvision detection # models. Here is a demo with a Faster R-CNN model loaded from # :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` @@ -85,7 +85,7 @@ def show(imgs): outputs = model(images) print(outputs) -##################################### +# %% # Let's plot the boxes detected by our model. We will only plot the boxes with a # score greater than a given threshold. 
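# %%
# A sketch of what that filtering can look like, reusing the ``outputs`` computed
# above and assuming ``dog_list`` still holds the two uint8 images (the threshold
# value is arbitrary):

score_threshold = .8
dogs_with_boxes = [
    draw_bounding_boxes(dog_int, boxes=output['boxes'][output['scores'] > score_threshold], width=4)
    for dog_int, output in zip(dog_list, outputs)
]
show(dogs_with_boxes)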
@@ -96,7 +96,7 @@ def show(imgs): ] show(dogs_with_boxes) -##################################### +# %% # Visualizing segmentation masks # ------------------------------ # The :func:`~torchvision.utils.draw_segmentation_masks` function can be used to @@ -125,7 +125,7 @@ def show(imgs): output = model(batch)['out'] print(output.shape, output.min().item(), output.max().item()) -##################################### +# %% # As we can see above, the output of the segmentation model is a tensor of shape # ``(batch_size, num_classes, H, W)``. Each value is a non-normalized score, and # we can normalize them into ``[0, 1]`` by using a softmax. After the softmax, @@ -147,7 +147,7 @@ def show(imgs): show(dog_and_boat_masks) -##################################### +# %% # As expected, the model is confident about the dog class, but not so much for # the boat class. # @@ -162,7 +162,7 @@ def show(imgs): show([m.float() for m in boolean_dog_masks]) -##################################### +# %% # The line above where we define ``boolean_dog_masks`` is a bit cryptic, but you # can read it as the following query: "For which pixels is 'dog' the most likely # class?" @@ -184,7 +184,7 @@ def show(imgs): ] show(dogs_with_masks) -##################################### +# %% # We can plot more than one mask per image! Remember that the model returned as # many masks as there are classes. Let's ask the same query as above, but this # time for *all* classes, not just the dog class: "For each pixel and each class @@ -204,7 +204,7 @@ def show(imgs): dog_with_all_masks = draw_segmentation_masks(dog1_int, masks=dog1_all_classes_masks, alpha=.6) show(dog_with_all_masks) -##################################### +# %% # We can see in the image above that only 2 masks were drawn: the mask for the # background and the mask for the dog. This is because the model thinks that # only these 2 classes are the most likely ones across all the pixels. If the @@ -231,7 +231,7 @@ def show(imgs): show(dogs_with_masks) -##################################### +# %% # .. _instance_seg_output: # # Instance segmentation models @@ -265,7 +265,7 @@ def show(imgs): output = model(images) print(output) -##################################### +# %% # Let's break this down. For each image in the batch, the model outputs some # detections (or instances). The number of detections varies for each input # image. Each instance is described by its bounding box, its label, its score @@ -288,7 +288,7 @@ def show(imgs): print(f"shape = {dog1_masks.shape}, dtype = {dog1_masks.dtype}, " f"min = {dog1_masks.min()}, max = {dog1_masks.max()}") -##################################### +# %% # Here the masks correspond to probabilities indicating, for each pixel, how # likely it is to belong to the predicted label of that instance. Those # predicted labels correspond to the 'labels' element in the same output dict. @@ -297,7 +297,7 @@ def show(imgs): print("For the first dog, the following instances were detected:") print([weights.meta["categories"][label] for label in dog1_output['labels']]) -##################################### +# %% # Interestingly, the model detects two persons in the image. Let's go ahead and # plot those masks. 
Since :func:`~torchvision.utils.draw_segmentation_masks` # expects boolean masks, we need to convert those probabilities into boolean @@ -315,14 +315,14 @@ def show(imgs): show(draw_segmentation_masks(dog1_int, dog1_bool_masks, alpha=0.9)) -##################################### +# %% # The model seems to have properly detected the dog, but it also confused trees # with people. Looking more closely at the scores will help us plot more # relevant masks: print(dog1_output['scores']) -##################################### +# %% # Clearly the model is more confident about the dog detection than it is about # the people detections. That's good news. When plotting the masks, we can ask # for only those that have a good score. Let's use a score threshold of .75 @@ -341,12 +341,12 @@ def show(imgs): ] show(dogs_with_masks) -##################################### +# %% # The two 'people' masks in the first image where not selected because they have # a lower score than the score threshold. Similarly, in the second image, the # instance with class 15 (which corresponds to 'bench') was not selected. -##################################### +# %% # .. _keypoint_output: # # Visualizing keypoints @@ -373,7 +373,7 @@ def show(imgs): outputs = model([person_float]) print(outputs) -##################################### +# %% # As we see the output contains a list of dictionaries. # The output list is of length batch_size. # We currently have just a single image so length of list is 1. @@ -388,7 +388,7 @@ def show(imgs): print(kpts) print(scores) -##################################### +# %% # The KeypointRCNN model detects there are two instances in the image. # If you plot the boxes by using :func:`~draw_bounding_boxes` # you would recognize they are the person and the surfboard. @@ -402,7 +402,7 @@ def show(imgs): print(keypoints) -##################################### +# %% # Great, now we have the keypoints corresponding to the person. # Each keypoint is represented by x, y coordinates and the visibility. # We can now use the :func:`~torchvision.utils.draw_keypoints` function to draw keypoints. @@ -413,7 +413,7 @@ def show(imgs): res = draw_keypoints(person_int, keypoints, colors="blue", radius=3) show(res) -##################################### +# %% # As we see the keypoints appear as colored circles over the image. # The coco keypoints for a person are ordered and represent the following list.\ @@ -424,7 +424,7 @@ def show(imgs): "left_knee", "right_knee", "left_ankle", "right_ankle", ] -##################################### +# %% # What if we are interested in joining the keypoints? # This is especially useful in creating pose detection or action recognition. # We can join the keypoints easily using the `connectivity` parameter. @@ -450,7 +450,7 @@ def show(imgs): (7, 9), (8, 10), (5, 11), (6, 12), (11, 13), (12, 14), (13, 15), (14, 16) ] -##################################### +# %% # We pass the above list to the connectivity parameter to connect the keypoints. 
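# %%
# A minimal sketch of that call, reusing ``person_int`` and ``keypoints`` from
# above and assuming the skeleton list is bound to a name such as
# ``connect_skeleton`` (the name is illustrative):

res = draw_keypoints(
    person_int, keypoints,
    connectivity=connect_skeleton, colors="blue", radius=4, width=3,
)
show(res)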
# From d1d8aa9603e8b82c7b143a0e4eb113b2ca60fbf7 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 2 Aug 2023 17:46:01 +0100 Subject: [PATCH 02/11] Add gallery tuto for custom transforms --- gallery/plot_custom_datapoints.py | 7 ++ gallery/plot_custom_transforms.py | 122 ++++++++++++++++++++++++++++++ gallery/plot_datapoints.py | 90 ++++++++++++++++------ 3 files changed, 194 insertions(+), 25 deletions(-) create mode 100644 gallery/plot_custom_datapoints.py create mode 100644 gallery/plot_custom_transforms.py diff --git a/gallery/plot_custom_datapoints.py b/gallery/plot_custom_datapoints.py new file mode 100644 index 00000000000..1642936150f --- /dev/null +++ b/gallery/plot_custom_datapoints.py @@ -0,0 +1,7 @@ +""" +===================================== +How to write your own Datapoint class +===================================== + +TODO +""" diff --git a/gallery/plot_custom_transforms.py b/gallery/plot_custom_transforms.py new file mode 100644 index 00000000000..9d7e2508a6f --- /dev/null +++ b/gallery/plot_custom_transforms.py @@ -0,0 +1,122 @@ +""" +================================ +How to write your own transforms +================================ + +This guide explains how to write transforms that are compatible with the +torchvision transforms V2 API. +""" + +# %% +import torch +import torchvision + +# We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that +# some APIs may slightly change in the future +torchvision.disable_beta_transforms_warning() + +from torchvision import datapoints +from torchvision.transforms import v2 + + +# %% +# Just create a ``nn.Module`` and override the ``forward`` method +# =============================================================== +# +# In most cases, this is all you're going to need, as long as you already know +# the structure of the input that your transform will expect. For example if +# you're just doing image classification, your transform will typically accept a +# single image as input, or a ``(img, label)`` input. So you can just hard-code +# your ``forward`` method to accept just that, e.g. +# +# .. code:: python +# +# class MyCustomTransform(torch.nn.Module): +# def forward(self, img, label): +# # Do some transformations +# return new_img, new_label +# +# .. note:: +# +# This means that if you have a custom transform that is already compatible +# with the V1 transforms (those in ``torchvision.transforms``), it will +# still work with the V2 transforms without any change! +# +# We will illustrate this more completely below with a typical detection case, +# where our samples are just images, bounding boxes and labels: + +class MyCustomTransform(torch.nn.Module): + def forward(self, img, bboxes, label): # we assume inputs are always structured like this + print( + f"I'm transforming an image of shape {img.shape} " + f"with bboxes = {bboxes}\n{label = }" + ) + # Do some transformations. Here, we're just passing though the input + return img, bboxes, label + +transforms = v2.Compose([ + MyCustomTransform(), + v2.RandomResizedCrop((224, 224), antialias=True), + v2.RandomHorizontalFlip(p=1), + v2.Normalize(mean=[0, 0, 0], std=[1, 1, 1]) +]) + +H, W = 256, 256 +img = torch.rand(3, H, W) +bboxes = datapoints.BoundingBoxes( + torch.tensor([[0, 10, 10, 20], [50, 50, 70, 70]]), + format="XYXY", + canvas_size=(H, W) +) +label = 3 + +out_img, out_bboxes, out_label = transforms(img, bboxes, label) +# %% +print(f"Output image shape: {out_img.shape}\nout_bboxes = {out_bboxes}\n{out_label = }") +# %% +# .. 
note:: +# As you're maniupulate datapoint classes in your code, make sure to +# familiarize yourself with this section: +# :ref:`datapoint_unwrapping_behaviour` +# +# Supporting arbitrary input structures +# ===================================== +# +# In the section above, we have assumed that you already know the structure of +# your inputs and that you're OK with hard-coding this expected structure in +# your code. If you want your custom transforms to be as flexible as possible, +# this can be a bit limitting. +# +# A key feature of the builtin Torchvision V2 transforms is that they can accept +# arbitrary input structure and return the same structure as output (with +# transformed entries). For example, transforms can accept a single image, or a +# tuple of ``(img, label)``, or an arbitrary nested dictionary as input: + +structured_input = { + "img": img, + "annotations": (bboxes, label), + "something_that_will_be_ignored": (1, "hello") +} +structured_output = v2.RandomHorizontalFlip(p=1)(structured_input) + +assert isinstance(structured_output, dict) +assert structured_output["something_that_will_be_ignored"] == (1, "hello") +print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}") + +# %% +# If you want to reproduce this behavior in your own transform, we invite you to +# look at our `code +# `_ +# and adapt it to your needs. +# +# In brief, the core logic is to unpack the input into a flat list using `pytree +# `_, and +# then transform only the entries that can be transformed (the decision is made +# based on the **class** of the entries, as all datapoints are +# tensor-subclasses) + some custom logic that is out of score here - check the +# code for details. The (potentially transformed) entries are then repacked and +# returned, in the same structure as the input. +# +# We do not provide public dev-facing tools to achieve that at this time, but if +# this is something that would be valuable to you, please let us know by opening +# an issue on our `GitHub repo `_. diff --git a/gallery/plot_datapoints.py b/gallery/plot_datapoints.py index 57e29bd86eb..c74cb312163 100644 --- a/gallery/plot_datapoints.py +++ b/gallery/plot_datapoints.py @@ -3,13 +3,22 @@ Datapoints FAQ ============== -The :mod:`torchvision.datapoints` namespace was introduced together with ``torchvision.transforms.v2``. This example -showcases what these datapoints are and how they behave. This is a fairly low-level topic that most users will not need -to worry about: you do not need to understand the internals of datapoints to efficiently rely on -``torchvision.transforms.v2``. It may however be useful for advanced users trying to implement their own datasets, -transforms, or work directly with the datapoints. +Datapoints are Tensor subclasses introduced together with +``torchvision.transforms.v2``. This example showcases what these datapoints are +and how they behave. + +.. warning:: + + **Intended Audience** Unless you're writing your own transforms or your own datapoints, you + probably do not need to read this guide. This is a fairly low-level topic + that most users will not need to worry about: you do not need to understand + the internals of datapoints to efficiently rely on + ``torchvision.transforms.v2``. It may however be useful for advanced users + trying to implement their own datasets, transforms, or work directly with + the datapoints. 
""" +# %% import PIL.Image import torch @@ -35,11 +44,20 @@ assert isinstance(image, torch.Tensor) assert image.data_ptr() == tensor.data_ptr() - # %% # Under the hood, they are needed in :mod:`torchvision.transforms.v2` to correctly dispatch to the appropriate function # for the input data. # +# What can I do with a datapoint? +# ------------------------------- +# +# Datapoints look and feel just like regular tensors - they **are** tensors. +# Everything that is supported on a plain :class:`torch.Tensor` like `.sum()` or +# any torch.*` operator will also works on datapoints. See +# :ref:`datapoint_unwrapping_behaviour` for more details. + +# %% +# # What datapoints are supported? # ------------------------------ # @@ -79,10 +97,10 @@ # :class:`~torchvision.datapoints.BoundingBoxes` stores the coordinate format as well as the spatial size of the # corresponding image alongside the actual values: -bounding_box = datapoints.BoundingBoxes( - [17, 16, 344, 495], format=datapoints.BoundingBoxFormat.XYXY, canvas_size=image.shape[-2:] +bboxes = datapoints.BoundingBoxes( + [[17, 16, 344, 495], [0, 10, 0, 10]], format=datapoints.BoundingBoxFormat.XYXY, canvas_size=image.shape[-2:] ) -print(bounding_box) +print(bboxes) # %% @@ -105,8 +123,8 @@ class PennFudanDataset(torch.utils.data.Dataset): def __getitem__(self, item): ... - target["boxes"] = datapoints.BoundingBoxes( - boxes, + target["bboxes"] = datapoints.BoundingBoxes( + bboxes, format=datapoints.BoundingBoxFormat.XYXY, canvas_size=F.get_size(img), ) @@ -147,7 +165,7 @@ def get_transform(train): # %% # .. note:: # -# If both :class:`~torchvision.datapoints.BoundingBoxes`'es and :class:`~torchvision.datapoints.Mask`'s are included in +# If both :class:`~torchvision.datapoints.BoundingBoxes` and :class:`~torchvision.datapoints.Mask`'s are included in # the sample, ``torchvision.transforms.v2`` will transform them both. Meaning, if you don't need both, dropping or # at least not wrapping the obsolete parts, can lead to a significant performance boost. # @@ -156,41 +174,63 @@ def get_transform(train): # even better to not load the masks at all, but this is not possible in this example, since the bounding boxes are # generated from the masks. # -# How do the datapoints behave inside a computation? -# -------------------------------------------------- +# .. _datapoint_unwrapping_behaviour: # -# Datapoints look and feel just like regular tensors. Everything that is supported on a plain :class:`torch.Tensor` -# also works on datapoints. -# Since for most operations involving datapoints, it cannot be safely inferred whether the result should retain the -# datapoint type, we choose to return a plain tensor instead of a datapoint (this might change, see note below): +# I had a Datapoint but now I have a Tensor. Help! 
+# ------------------------------------------------ +# +# For a lot of operations involving datapoints, we cannot safely infer whether +# the result should retain the datapoint type, so we choose to return a plain +# tensor instead of a datapoint (this might change, see note below): -assert isinstance(image, datapoints.Image) +assert isinstance(bboxes, datapoints.BoundingBoxes) -new_image = image + 0 +# Shift bboxes by 3 pixels in both H and W +new_bboxes = bboxes + 3 -assert isinstance(new_image, torch.Tensor) and not isinstance(new_image, datapoints.Image) +assert isinstance(new_bboxes, torch.Tensor) and not isinstance(new_bboxes, datapoints.BoundingBoxes) + +# %% +# If you're writing your own custom transforms or code involving datapoints, you +# can re-wrap the output into a datapoint by just calling their constructor, or +# by using the ``.wrap_like()`` class method: + +new_bboxes = bboxes + 3 +new_bboxes = datapoints.BoundingBoxes.wrap_like(bboxes, new_bboxes) +assert isinstance(new_bboxes, datapoints.BoundingBoxes) # %% # .. note:: # +# You never need to re-wrap manually if you're using the built-in transforms +# or their functional equivalents, because this logic is taken care of for +# you. +# +# .. note:: +# # This "unwrapping" behaviour is something we're actively seeking feedback on. If you find this surprising or if you # have any suggestions on how to better support your use-cases, please reach out to us via this issue: # https://github.com/pytorch/vision/issues/7319 # -# There are two exceptions to this rule: +# There are two exceptions to this "unwrapping" rule: # # 1. The operations :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`, and :meth:`~torch.Tensor.requires_grad_` # retain the datapoint type. -# 2. Inplace operations on datapoints cannot change the type of the datapoint they are called on. However, if you use -# the flow style, the returned value will be unwrapped: +# 2. Inplace operations on datapoints like ``.add_()`` preserve they type. However, +# the **returned** value of inplace operations will be unwrapped into a pure +# tensor: image = datapoints.Image([[[0, 1], [1, 0]]]) new_image = image.add_(1).mul_(2) -assert isinstance(image, torch.Tensor) +# image got transformed in-place and is still an Image datapoint, but new_image +# is a Tensor. They share the same underlying data and they're equal, just +# different classes. +assert isinstance(image, datapoints.Image) print(image) assert isinstance(new_image, torch.Tensor) and not isinstance(new_image, datapoints.Image) assert (new_image == image).all() +assert new_image.data_ptr() == image.data_ptr() From 937069d53f992d79d6a8e2824ddeb94e7eb3b0ab Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 2 Aug 2023 18:12:20 +0100 Subject: [PATCH 03/11] More stuff --- gallery/plot_datapoints.py | 44 ++++++++++++++++++++++++++++++++------ 1 file changed, 37 insertions(+), 7 deletions(-) diff --git a/gallery/plot_datapoints.py b/gallery/plot_datapoints.py index c74cb312163..5b0de58e8e3 100644 --- a/gallery/plot_datapoints.py +++ b/gallery/plot_datapoints.py @@ -52,9 +52,9 @@ # ------------------------------- # # Datapoints look and feel just like regular tensors - they **are** tensors. -# Everything that is supported on a plain :class:`torch.Tensor` like `.sum()` or -# any torch.*` operator will also works on datapoints. See -# :ref:`datapoint_unwrapping_behaviour` for more details. 
+# Everything that is supported on a plain :class:`torch.Tensor` like ``.sum()`` or +# any ``torch.*`` operator will also works on datapoints. See +# :ref:`datapoint_unwrapping_behaviour` for a few gotchas. # %% # @@ -68,9 +68,14 @@ # * :class:`~torchvision.datapoints.BoundingBoxes` # * :class:`~torchvision.datapoints.Mask` # +# .. _datapoint_creation: +# # How do I construct a datapoint? # ------------------------------- # +# Using the constructor +# ^^^^^^^^^^^^^^^^^^^^^ +# # Each datapoint class takes any tensor-like data that can be turned into a :class:`~torch.Tensor` image = datapoints.Image([[[[0, 1], [1, 0]]]]) @@ -86,27 +91,50 @@ # %% -# In addition, :class:`~torchvision.datapoints.Image` and :class:`~torchvision.datapoints.Mask` also take a +# In addition, :class:`~torchvision.datapoints.Image` and :class:`~torchvision.datapoints.Mask` can also take a # :class:`PIL.Image.Image` directly: image = datapoints.Image(PIL.Image.open("assets/astronaut.jpg")) print(image.shape, image.dtype) # %% -# In general, the datapoints can also store additional metadata that complements the underlying tensor. For example, -# :class:`~torchvision.datapoints.BoundingBoxes` stores the coordinate format as well as the spatial size of the -# corresponding image alongside the actual values: +# Some datapoints require additional metadata to be passed in ordered to be constructed. For example, +# :class:`~torchvision.datapoints.BoundingBoxes` requires the coordinate format as well as the size of the +# corresponding image (``canvas_size``) alongside the actual values. These +# metadata are required to properly transform the bounding boxes. bboxes = datapoints.BoundingBoxes( [[17, 16, 344, 495], [0, 10, 0, 10]], format=datapoints.BoundingBoxFormat.XYXY, canvas_size=image.shape[-2:] ) print(bboxes) +# %% +# Using the ``wrap_like()`` class method +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# +# You can also use the ``wrap_like()`` class method to wrap a tensor-like object +# into a datapoint. This is useful when you already have an object of the +# desired type, which typically happens when writing transforms: you just want +# to wrap the output like the input. This API is inspired by utils like +# :func:`torch.zeros_like`: + +new_bboxes = torch.tensor([0, 20, 30, 40]) +new_bboxes = datapoints.BoundingBoxes.wrap_like(bboxes, new_bboxes) +assert isinstance(new_bboxes, datapoints.BoundingBoxes) +assert new_bboxes.canvas_size == bboxes.canvas_size + # %% +# The metadata of ``new_bboxes`` is the same as ``bboxes``, but you could pass +# it as a parameter to override it. Check the +# :meth:`~torchvision.datapoints.BoundingBoxes.wrap_like` documentation for +# more details. +# # Do I have to wrap the output of the datasets myself? # ---------------------------------------------------- # +# TODO: Move this in another guide - this is user-facing, not dev-facing. +# # Only if you are using custom datasets. For the built-in ones, you can use # :func:`torchvision.datasets.wrap_dataset_for_transforms_v2`. Note that the function also supports subclasses of the # built-in datasets. Meaning, if your custom dataset subclasses from a built-in one and the output type is the same, you @@ -201,6 +229,8 @@ def get_transform(train): assert isinstance(new_bboxes, datapoints.BoundingBoxes) # %% +# See more details above in :ref:`datapoint_creation`. +# # .. 
note:: # # You never need to re-wrap manually if you're using the built-in transforms From 5780fd7820dfea267cd2908ef12eb167ead41e19 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 2 Aug 2023 18:15:21 +0100 Subject: [PATCH 04/11] not tensor-like --- gallery/plot_datapoints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gallery/plot_datapoints.py b/gallery/plot_datapoints.py index 5b0de58e8e3..33d5506d99f 100644 --- a/gallery/plot_datapoints.py +++ b/gallery/plot_datapoints.py @@ -112,7 +112,7 @@ # Using the ``wrap_like()`` class method # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ # -# You can also use the ``wrap_like()`` class method to wrap a tensor-like object +# You can also use the ``wrap_like()`` class method to wrap a tensor object # into a datapoint. This is useful when you already have an object of the # desired type, which typically happens when writing transforms: you just want # to wrap the output like the input. This API is inspired by utils like From cbff036585b73880a40dd8bcd0090a84dd39ddd9 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 2 Aug 2023 21:06:27 +0100 Subject: [PATCH 05/11] Some more --- docs/source/conf.py | 2 +- docs/source/datapoints.rst | 1 + gallery/plot_custom_datapoints.py | 50 +++++++++++++++++++++++++++- torchvision/datapoints/_datapoint.py | 6 ++++ 4 files changed, 57 insertions(+), 2 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 7b3e9e8a7f3..fed3884ea27 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -320,7 +320,7 @@ def inject_weight_metadata(app, what, name, obj, options, lines): used within the autoclass directive. """ - if obj.__name__.endswith(("_Weights", "_QuantizedWeights")): + if getattr(obj, ".__name__", "").endswith(("_Weights", "_QuantizedWeights")): if len(obj) == 0: lines[:] = ["There are no available pre-trained weights."] diff --git a/docs/source/datapoints.rst b/docs/source/datapoints.rst index 55d3cda4a8c..ea23a7ff7a6 100644 --- a/docs/source/datapoints.rst +++ b/docs/source/datapoints.rst @@ -17,3 +17,4 @@ see e.g. :ref:`sphx_glr_auto_examples_plot_transforms_v2_e2e.py`. BoundingBoxFormat BoundingBoxes Mask + Datapoint diff --git a/gallery/plot_custom_datapoints.py b/gallery/plot_custom_datapoints.py index 1642936150f..fb4e11f4e55 100644 --- a/gallery/plot_custom_datapoints.py +++ b/gallery/plot_custom_datapoints.py @@ -3,5 +3,53 @@ How to write your own Datapoint class ===================================== -TODO +This guide is intended for downstream library maintainers. We explain how to +write your own datapoint class, and how to make it compatible with the built-in +Torchvision V2 transforms. Before continuing, make sure you have read +:ref:`sphx_glr_auto_examples_plot_datapoints.py`. """ + +# %% +import torch +import torchvision + +# We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that +# some APIs may slightly change in the future +torchvision.disable_beta_transforms_warning() + +from torchvision import datapoints +from torchvision.transforms import v2 + +# %% +# +# We will just create a very simple class that just inherits from the base +# :class:`~torchvision.datapoints.Datapoint` class. It will be enough to cover +# what you need to know to implement your own custom uses-cases. If you need to +# create a class that carries meta-data, take a look at how the +# :class:`~torchvision.datapoints.BoundingBoxes` class is implemented. 
+ +class MyDatapoint(datapoints.Datapoint): + pass + +my_dp = MyDatapoint([1, 2, 3]) +my_dp + +#%% +from torchvision.transforms.v2.functional import register_kernel, resize + +# TODO: THIS didn't raise an error: +# @register_kernel(MyDatapoint, resize) + +# TODO Let devs pass strings + +@register_kernel(resize, MyDatapoint) +def resize_my_datapoint(my_dp, size, *args, **kwargs): + print(f"Resizing {my_dp} to {size}") + return torch.rand(3, *size) + + + +my_dp = MyDatapoint(torch.rand(3, 256, 256)) +out = v2.Resize((224, 224))(my_dp) +print(type(out), out.shape) +# %% diff --git a/torchvision/datapoints/_datapoint.py b/torchvision/datapoints/_datapoint.py index 384273301de..cc428ab5996 100644 --- a/torchvision/datapoints/_datapoint.py +++ b/torchvision/datapoints/_datapoint.py @@ -14,6 +14,12 @@ class Datapoint(torch.Tensor): + """[Beta] Base class for all datapoints. + + You probably don't want to use this class unless you're defining your own + custom Datapoints. See + :ref:`sphx_glr_auto_examples_plot_custom_datapoints.py` for details. + """ @staticmethod def _to_tensor( data: Any, From 43a71e74e2136d0d57442d64199507c8512bb9b9 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 2 Aug 2023 21:43:07 +0100 Subject: [PATCH 06/11] Allow register_kernel to take name as input --- test/test_transforms_v2_refactored.py | 33 +++++++++++++++++++ .../transforms/v2/functional/_utils.py | 15 +++++++++ 2 files changed, 48 insertions(+) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 45668fda1ca..8a858bf58c2 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -2181,3 +2181,36 @@ def test_unsupported_types(self, dispatcher, make_input): with pytest.raises(TypeError, match=re.escape(str(type(input)))): dispatcher(input) + + +class TestRegisterKernel: + @pytest.mark.parametrize("dispatcher", (F.resize, "resize")) + def test_register_kernel(self, dispatcher): + class CustomDatapoint(datapoints.Datapoint): + pass + + kernel_was_called = False + + @F.register_kernel(dispatcher, CustomDatapoint) + def new_resize(dp, *args, **kwargs): + nonlocal kernel_was_called + kernel_was_called = True + return dp + + t = transforms.Resize(size=(224, 224), antialias=True) + + my_dp = CustomDatapoint(torch.rand(3, 10, 10)) + out = t(my_dp) + assert out is my_dp + assert kernel_was_called + + # Sanity check to make sure we didn't override the kernel of other types + t(torch.rand(3, 10, 10)).shape == (3, 224, 224) + t(datapoints.Image(torch.rand(3, 10, 10))).shape == (3, 224, 224) + + def test_bad_disaptcher_name(self): + class CustomDatapoint(datapoints.Datapoint): + pass + + with pytest.raises(ValueError, match="Could not find dispatcher with name"): + F.register_kernel("bad_name", CustomDatapoint) diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 63e029d6c77..b4798bfa5d6 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -37,7 +37,22 @@ def decorator(kernel): return decorator +def _name_to_dispatcher(name): + import torchvision.transforms.v2.functional # noqa + + try: + return next( + obj + for obj in torchvision.transforms.v2.functional.__dict__.values() + if getattr(obj, "__name__", "") == name + ) + except StopIteration: + raise ValueError(f"Could not find dispatcher with name '{name}'.") + + def register_kernel(dispatcher, datapoint_cls): + if isinstance(dispatcher, str): + dispatcher = 
_name_to_dispatcher(name=dispatcher) return _register_kernel_internal(dispatcher, datapoint_cls, datapoint_wrapper=False) From 99ea401bba6ae0b5d1d631aad0f587dadea2d3e2 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 3 Aug 2023 09:54:33 +0100 Subject: [PATCH 07/11] Better _name_to_dispatcher --- torchvision/transforms/v2/functional/_utils.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index b4798bfa5d6..1eaa54102a4 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -41,13 +41,9 @@ def _name_to_dispatcher(name): import torchvision.transforms.v2.functional # noqa try: - return next( - obj - for obj in torchvision.transforms.v2.functional.__dict__.values() - if getattr(obj, "__name__", "") == name - ) - except StopIteration: - raise ValueError(f"Could not find dispatcher with name '{name}'.") + return getattr(torchvision.transforms.v2.functional, name) + except AttributeError: + raise ValueError(f"Could not find dispatcher with name '{name}'.") from None def register_kernel(dispatcher, datapoint_cls): From 88bf09b3302ad697e9a81493f7783e4ea5825ec2 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 3 Aug 2023 11:57:28 +0100 Subject: [PATCH 08/11] More docs --- docs/source/transforms.rst | 11 +++ gallery/plot_custom_datapoints.py | 98 ++++++++++++++++--- gallery/plot_custom_transforms.py | 1 + torchvision/datapoints/_bounding_box.py | 2 +- torchvision/datapoints/_datapoint.py | 8 +- torchvision/datapoints/_image.py | 9 -- torchvision/datapoints/_mask.py | 12 --- torchvision/datapoints/_video.py | 9 -- torchvision/prototype/datapoints/_label.py | 2 +- .../transforms/v2/functional/_utils.py | 5 + 10 files changed, 109 insertions(+), 48 deletions(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 73adb3cf3b5..a1858c6b514 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -375,3 +375,14 @@ you can use a functional transform to build transform classes with custom behavi to_pil_image to_tensor vflip + +Developer tools +--------------- + +.. currentmodule:: torchvision.transforms.v2.functional + +.. autosummary:: + :toctree: generated/ + :template: function.rst + + register_kernel diff --git a/gallery/plot_custom_datapoints.py b/gallery/plot_custom_datapoints.py index fb4e11f4e55..c8d594045c9 100644 --- a/gallery/plot_custom_datapoints.py +++ b/gallery/plot_custom_datapoints.py @@ -21,35 +21,103 @@ from torchvision.transforms import v2 # %% -# -# We will just create a very simple class that just inherits from the base +# We will create a very simple class that just inherits from the base # :class:`~torchvision.datapoints.Datapoint` class. It will be enough to cover -# what you need to know to implement your own custom uses-cases. If you need to -# create a class that carries meta-data, take a look at how the +# what you need to know to implement your more elaborate uses-cases. If you need +# to create a class that carries meta-data, take a look at how the # :class:`~torchvision.datapoints.BoundingBoxes` class is implemented. + class MyDatapoint(datapoints.Datapoint): pass + my_dp = MyDatapoint([1, 2, 3]) my_dp -#%% -from torchvision.transforms.v2.functional import register_kernel, resize +# %% +# Now that we have defined our custom Datapoint class, we want it to be +# compatible with the built-in torchvision transforms, and the functional API. 
+
+# For that, we need to implement a kernel which performs the core of the
+# transformation, and then "hook" it to the functional that we want to support
+# via :func:`~torchvision.transforms.v2.functional.register_kernel`.
+#
+# We illustrate this process below: we create a kernel for the "horizontal flip"
+# operation of our MyDatapoint class, and register it to the functional API.
 
-# TODO: THIS didn't raise an error:
-# @register_kernel(MyDatapoint, resize)
+from torchvision.transforms.v2 import functional as F
 
-# TODO Let devs pass strings
 
-@register_kernel(resize, MyDatapoint)
-def resize_my_datapoint(my_dp, size, *args, **kwargs):
-    print(f"Resizing {my_dp} to {size}")
-    return torch.rand(3, *size)
+@F.register_kernel(dispatcher="hflip", datapoint_cls=MyDatapoint)
+def hflip_my_datapoint(my_dp, *args, **kwargs):
+    print("Flipping!")
+    out = my_dp.flip(-1)
+    return MyDatapoint.wrap_like(my_dp, out)
 
 
+# %%
+# To understand why ``wrap_like`` is used, see
+# :ref:`datapoint_unwrapping_behaviour`. Ignore the ``*args, **kwargs`` for now;
+# we will explain them below in :ref:`param_forwarding`.
+#
+# .. note::
+#
+#    In our call to ``register_kernel`` above we used a string
+#    ``dispatcher="hflip"`` to refer to the functional we want to hook into. We
+#    could also have used the functional *itself*, i.e.
+#    ``@register_kernel(dispatcher=F.hflip, ...)``.
+#
+#    The functionals that you can hook into are the ones in
+#    ``torchvision.transforms.v2.functional`` and they are documented in
+#    :ref:`functional_transforms`.
+#
+# Now that we have registered our kernel, we can call the functional API on a
+# ``MyDatapoint`` instance:
 
 my_dp = MyDatapoint(torch.rand(3, 256, 256))
-out = v2.Resize((224, 224))(my_dp)
-print(type(out), out.shape)
+_ = F.hflip(my_dp)
+
 # %%
+# And we can also use the
+# :class:`~torchvision.transforms.v2.RandomHorizontalFlip` transform, since it relies on :func:`~torchvision.transforms.v2.functional.hflip` internally:
+
+t = v2.RandomHorizontalFlip(p=1)
+_ = t(my_dp)
+
+# %%
+# .. note::
+#
+#    We cannot register a kernel for a transform class; we can only register a
+#    kernel for a **functional**. The reason we can't register a transform
+#    class is because one transform may internally rely on more than one
+#    functional, so in general we can't register a single kernel for a given
+#    class.
+#
+# .. _param_forwarding:
+#
+# Parameter forwarding, and ensuring future compatibility of your kernels
+# -----------------------------------------------------------------------
+#
+# The functional API that you're hooking into is public and therefore
+# **backward** compatible: we guarantee that the parameters of these functionals
+# won't be removed or renamed without a proper deprecation cycle. However, we
+# don't guarantee **forward** compatibility, and we may add new parameters in
+# the future.
+#
+# Imagine that in a future version, Torchvision adds a new ``inplace`` parameter
+# to its :func:`~torchvision.transforms.v2.functional.hflip` functional. If you
+# already defined and registered your own kernel as
+
+def hflip_my_datapoint(my_dp):  # noqa
+    print("Flipping!")
+    out = my_dp.flip(-1)
+    return MyDatapoint.wrap_like(my_dp, out)
+
+
+# %%
+# then calling ``F.hflip(my_dp)`` will **fail**, because ``hflip`` will try to
+# pass the new ``inplace`` parameter to your kernel, but your kernel doesn't
+# accept it.
+#
+# For this reason, we recommend to always define your kernels with
+# ``*args, **kwargs`` in their signature, as done above. 
This way, your kernel +# will be able to accept any new parameter that we may add in the future. diff --git a/gallery/plot_custom_transforms.py b/gallery/plot_custom_transforms.py index 9d7e2508a6f..0308d0154a9 100644 --- a/gallery/plot_custom_transforms.py +++ b/gallery/plot_custom_transforms.py @@ -54,6 +54,7 @@ def forward(self, img, bboxes, label): # we assume inputs are always structured # Do some transformations. Here, we're just passing though the input return img, bboxes, label + transforms = v2.Compose([ MyCustomTransform(), v2.RandomResizedCrop((224, 224), antialias=True), diff --git a/torchvision/datapoints/_bounding_box.py b/torchvision/datapoints/_bounding_box.py index 912cc3bca08..7477b3652dc 100644 --- a/torchvision/datapoints/_bounding_box.py +++ b/torchvision/datapoints/_bounding_box.py @@ -42,7 +42,7 @@ class BoundingBoxes(Datapoint): canvas_size: Tuple[int, int] @classmethod - def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, canvas_size: Tuple[int, int]) -> BoundingBoxes: + def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, canvas_size: Tuple[int, int]) -> BoundingBoxes: # type: ignore[override] bounding_boxes = tensor.as_subclass(cls) bounding_boxes.format = format bounding_boxes.canvas_size = canvas_size diff --git a/torchvision/datapoints/_datapoint.py b/torchvision/datapoints/_datapoint.py index cc428ab5996..6e6dd932cdc 100644 --- a/torchvision/datapoints/_datapoint.py +++ b/torchvision/datapoints/_datapoint.py @@ -20,6 +20,7 @@ class Datapoint(torch.Tensor): custom Datapoints. See :ref:`sphx_glr_auto_examples_plot_custom_datapoints.py` for details. """ + @staticmethod def _to_tensor( data: Any, @@ -31,9 +32,14 @@ def _to_tensor( requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad) + @classmethod + def _wrap(cls: Type[D], tensor: torch.Tensor) -> D: + image = tensor.as_subclass(cls) + return image + @classmethod def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D: - raise NotImplementedError + return cls._wrap(tensor) _NO_WRAPPING_EXCEPTIONS = { torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output), diff --git a/torchvision/datapoints/_image.py b/torchvision/datapoints/_image.py index dccfc81a605..9b635e8e034 100644 --- a/torchvision/datapoints/_image.py +++ b/torchvision/datapoints/_image.py @@ -22,11 +22,6 @@ class Image(Datapoint): ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``. """ - @classmethod - def _wrap(cls, tensor: torch.Tensor) -> Image: - image = tensor.as_subclass(cls) - return image - def __new__( cls, data: Any, @@ -48,10 +43,6 @@ def __new__( return cls._wrap(tensor) - @classmethod - def wrap_like(cls, other: Image, tensor: torch.Tensor) -> Image: - return cls._wrap(tensor) - def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] return self._make_repr() diff --git a/torchvision/datapoints/_mask.py b/torchvision/datapoints/_mask.py index 2b95eca72e2..95eda077929 100644 --- a/torchvision/datapoints/_mask.py +++ b/torchvision/datapoints/_mask.py @@ -22,10 +22,6 @@ class Mask(Datapoint): ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``. 
""" - @classmethod - def _wrap(cls, tensor: torch.Tensor) -> Mask: - return tensor.as_subclass(cls) - def __new__( cls, data: Any, @@ -41,11 +37,3 @@ def __new__( tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) return cls._wrap(tensor) - - @classmethod - def wrap_like( - cls, - other: Mask, - tensor: torch.Tensor, - ) -> Mask: - return cls._wrap(tensor) diff --git a/torchvision/datapoints/_video.py b/torchvision/datapoints/_video.py index 11d6e2a854d..842c05bf7e9 100644 --- a/torchvision/datapoints/_video.py +++ b/torchvision/datapoints/_video.py @@ -20,11 +20,6 @@ class Video(Datapoint): ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``. """ - @classmethod - def _wrap(cls, tensor: torch.Tensor) -> Video: - video = tensor.as_subclass(cls) - return video - def __new__( cls, data: Any, @@ -38,10 +33,6 @@ def __new__( raise ValueError return cls._wrap(tensor) - @classmethod - def wrap_like(cls, other: Video, tensor: torch.Tensor) -> Video: - return cls._wrap(tensor) - def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] return self._make_repr() diff --git a/torchvision/prototype/datapoints/_label.py b/torchvision/prototype/datapoints/_label.py index 7ed2f7522b0..ac9b2d8912a 100644 --- a/torchvision/prototype/datapoints/_label.py +++ b/torchvision/prototype/datapoints/_label.py @@ -15,7 +15,7 @@ class _LabelBase(Datapoint): categories: Optional[Sequence[str]] @classmethod - def _wrap(cls: Type[L], tensor: torch.Tensor, *, categories: Optional[Sequence[str]]) -> L: + def _wrap(cls: Type[L], tensor: torch.Tensor, *, categories: Optional[Sequence[str]]) -> L: # type: ignore[override] label_base = tensor.as_subclass(cls) label_base.categories = categories return label_base diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 1eaa54102a4..2c30b78bebc 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -47,6 +47,11 @@ def _name_to_dispatcher(name): def register_kernel(dispatcher, datapoint_cls): + """Register a kernel for a dispatcher and a (custom) datapoint type. + + See :ref:`sphx_glr_auto_examples_plot_custom_datapoints.py` for usage + details. 
+ """ if isinstance(dispatcher, str): dispatcher = _name_to_dispatcher(name=dispatcher) return _register_kernel_internal(dispatcher, datapoint_cls, datapoint_wrapper=False) From bedb8580185ce386e4b92b1bb75c800db7b7d30f Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 3 Aug 2023 12:04:48 +0100 Subject: [PATCH 09/11] oops --- torchvision/datapoints/_datapoint.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchvision/datapoints/_datapoint.py b/torchvision/datapoints/_datapoint.py index 6e6dd932cdc..fae3c18656b 100644 --- a/torchvision/datapoints/_datapoint.py +++ b/torchvision/datapoints/_datapoint.py @@ -34,8 +34,7 @@ def _to_tensor( @classmethod def _wrap(cls: Type[D], tensor: torch.Tensor) -> D: - image = tensor.as_subclass(cls) - return image + return tensor.as_subclass(cls) @classmethod def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D: From 8178376de6efd335b8c8a11e3b39fb6df7c4ab8d Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 4 Aug 2023 10:13:20 +0100 Subject: [PATCH 10/11] Address comments --- gallery/plot_custom_datapoints.py | 6 ++++-- gallery/plot_custom_transforms.py | 10 +++++----- gallery/plot_datapoints.py | 13 ++++++++----- 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/gallery/plot_custom_datapoints.py b/gallery/plot_custom_datapoints.py index c8d594045c9..ea757283e86 100644 --- a/gallery/plot_custom_datapoints.py +++ b/gallery/plot_custom_datapoints.py @@ -5,7 +5,7 @@ This guide is intended for downstream library maintainers. We explain how to write your own datapoint class, and how to make it compatible with the built-in -Torchvision V2 transforms. Before continuing, make sure you have read +Torchvision v2 transforms. Before continuing, make sure you have read :ref:`sphx_glr_auto_examples_plot_datapoints.py`. """ @@ -25,7 +25,8 @@ # :class:`~torchvision.datapoints.Datapoint` class. It will be enough to cover # what you need to know to implement your more elaborate uses-cases. If you need # to create a class that carries meta-data, take a look at how the -# :class:`~torchvision.datapoints.BoundingBoxes` class is implemented. +# :class:`~torchvision.datapoints.BoundingBoxes` class is `implemented +# `_. class MyDatapoint(datapoints.Datapoint): @@ -121,3 +122,4 @@ def hflip_my_datapoint(my_dp): # noqa # For this reason, we recommend to always define your kernels with # ``*args, **kwargs`` in their signature, as done above. This way, your kernel # will be able to accept any new parameter that we may add in the future. +# (Technically, adding `**kwargs` only should be enough). diff --git a/gallery/plot_custom_transforms.py b/gallery/plot_custom_transforms.py index 0308d0154a9..eba8e91faf4 100644 --- a/gallery/plot_custom_transforms.py +++ b/gallery/plot_custom_transforms.py @@ -1,7 +1,7 @@ """ -================================ -How to write your own transforms -================================ +=================================== +How to write your own v2 transforms +=================================== This guide explains how to write transforms that are compatible with the torchvision transforms V2 API. @@ -76,7 +76,7 @@ def forward(self, img, bboxes, label): # we assume inputs are always structured print(f"Output image shape: {out_img.shape}\nout_bboxes = {out_bboxes}\n{out_label = }") # %% # .. 
note::
-#     As you're maniupulate datapoint classes in your code, make sure to
+#     While working with datapoint classes in your code, make sure to
 #     familiarize yourself with this section:
 #     :ref:`datapoint_unwrapping_behaviour`
 #
@@ -114,7 +114,7 @@ def forward(self, img, bboxes, label):  # we assume inputs are always structured
 # `_, and
 # then transform only the entries that can be transformed (the decision is made
 # based on the **class** of the entries, as all datapoints are
-# tensor-subclasses) + some custom logic that is out of score here - check the
+# tensor-subclasses) plus some custom logic that is out of scope here - check the
 # code for details. The (potentially transformed) entries are then repacked and
 # returned, in the same structure as the input.
 #
diff --git a/gallery/plot_datapoints.py b/gallery/plot_datapoints.py
index 33d5506d99f..d87575cdb8e 100644
--- a/gallery/plot_datapoints.py
+++ b/gallery/plot_datapoints.py
@@ -104,7 +104,9 @@
 # metadata are required to properly transform the bounding boxes.
 
 bboxes = datapoints.BoundingBoxes(
-    [[17, 16, 344, 495], [0, 10, 0, 10]], format=datapoints.BoundingBoxFormat.XYXY, canvas_size=image.shape[-2:]
+    [[17, 16, 344, 495], [0, 10, 0, 10]],
+    format=datapoints.BoundingBoxFormat.XYXY,
+    canvas_size=image.shape[-2:]
 )
 print(bboxes)
 
@@ -234,7 +236,7 @@ def get_transform(train):
 # .. note::
 #
 #    You never need to re-wrap manually if you're using the built-in transforms
-#    or their functional equivalents, because this logic is taken care of for
+#    or their functional equivalents: this is automatically taken care of for
 #    you.
 #
 # .. note::
@@ -243,10 +245,11 @@ def get_transform(train):
 # have any suggestions on how to better support your use-cases, please reach out to us via this issue:
 # https://github.com/pytorch/vision/issues/7319
 #
-# There are two exceptions to this "unwrapping" rule:
+# There are a few exceptions to this "unwrapping" rule:
 #
-# 1. The operations :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`, and :meth:`~torch.Tensor.requires_grad_`
-#    retain the datapoint type.
+# 1. Operations like :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`,
+#    :meth:`~torch.Tensor.detach` and :meth:`~torch.Tensor.requires_grad_` retain
+#    the datapoint type.
 # 2. Inplace operations on datapoints like ``.add_()`` preserve they type. However,
 #    the **returned** value of inplace operations will be unwrapped into a pure
 #    tensor:

From 5f58689af12acd7ba3c0b6fa7d6aed39413243f9 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Fri, 4 Aug 2023 14:02:29 +0100
Subject: [PATCH 11/11] Update torchvision/transforms/v2/functional/_utils.py

Co-authored-by: Philip Meier
---
 torchvision/transforms/v2/functional/_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py
index 2c30b78bebc..bb3d59b551a 100644
--- a/torchvision/transforms/v2/functional/_utils.py
+++ b/torchvision/transforms/v2/functional/_utils.py
@@ -47,7 +47,7 @@ def _name_to_dispatcher(name):
 
 
 def register_kernel(dispatcher, datapoint_cls):
-    """Register a kernel for a dispatcher and a (custom) datapoint type.
+    """Decorate a kernel to register it for a dispatcher and a (custom) datapoint type.
 
     See :ref:`sphx_glr_auto_examples_plot_custom_datapoints.py` for usage
     details.
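
Putting the series together, the intended end-to-end usage looks roughly like the sketch below. This is an illustrative recap rather than code taken from the patches: the interpolation-based kernel body is a stand-in for a real resize, and it assumes the beta APIs exactly as they appear above (``register_kernel`` accepting a dispatcher name as a string, resize kernels receiving ``(input, size, ...)``, and ``wrap_like`` being inherited from the ``Datapoint`` base class).

    import torch
    import torchvision

    # Acknowledge the beta status of the v2 transforms, as the gallery examples do.
    torchvision.disable_beta_transforms_warning()

    from torchvision import datapoints
    from torchvision.transforms import v2
    from torchvision.transforms.v2 import functional as F


    class MyDatapoint(datapoints.Datapoint):
        pass


    # The dispatcher can be passed by name thanks to _name_to_dispatcher();
    # passing F.resize itself would work just as well.
    @F.register_kernel("resize", MyDatapoint)
    def resize_my_datapoint(dp, size, *args, **kwargs):
        # Keep *args/**kwargs so future parameters of F.resize don't break this kernel.
        out = torch.nn.functional.interpolate(dp.unsqueeze(0), size=list(size)).squeeze(0)
        return MyDatapoint.wrap_like(dp, out)  # wrap_like is inherited from the Datapoint base class


    my_dp = MyDatapoint(torch.rand(3, 64, 64))
    out = v2.Resize((32, 32), antialias=True)(my_dp)
    print(type(out), out.shape)  # MyDatapoint, torch.Size([3, 32, 32])

Because the kernel forwards ``*args, **kwargs``, the same registration keeps working if the resize functional later grows new parameters, which is exactly the forward-compatibility concern the gallery text above calls out.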