diff --git a/gallery/assets/person1.jpg b/gallery/assets/person1.jpg
new file mode 100644
index 00000000000..83251c84a79
Binary files /dev/null and b/gallery/assets/person1.jpg differ
diff --git a/gallery/assets/visualization_utils_thumbnail.png b/gallery/assets/visualization_utils_thumbnail.png
deleted file mode 100644
index 63860a57ab5..00000000000
Binary files a/gallery/assets/visualization_utils_thumbnail.png and /dev/null differ
diff --git a/gallery/assets/visualization_utils_thumbnail2.png b/gallery/assets/visualization_utils_thumbnail2.png
new file mode 100644
index 00000000000..cf057e04207
Binary files /dev/null and b/gallery/assets/visualization_utils_thumbnail2.png differ
diff --git a/gallery/plot_visualization_utils.py b/gallery/plot_visualization_utils.py
index 628319e52d5..59aeaa1ed37 100644
--- a/gallery/plot_visualization_utils.py
+++ b/gallery/plot_visualization_utils.py
@@ -4,10 +4,10 @@
 =======================
 
 This example illustrates some of the utilities that torchvision offers for
-visualizing images, bounding boxes, and segmentation masks.
+visualizing images, bounding boxes, segmentation masks, and keypoints.
 """
 
-# sphinx_gallery_thumbnail_path = "../../gallery/assets/visualization_utils_thumbnail.png"
+# sphinx_gallery_thumbnail_path = "../../gallery/assets/visualization_utils_thumbnail2.png"
 
 import torch
 import numpy as np
@@ -366,3 +366,110 @@ def show(imgs):
 # The two 'people' masks in the first image where not selected because they have
 # a lower score than the score threshold. Similarly in the second image, the
 # instance with class 15 (which corresponds to 'bench') was not selected.
+
+#####################################
+# Visualizing keypoints
+# ------------------------------
+# The :func:`~torchvision.utils.draw_keypoints` function can be used to
+# draw keypoints on images. We will see how to use it with
+# torchvision's KeypointRCNN loaded with :func:`~torchvision.models.detection.keypointrcnn_resnet50_fpn`.
+# We will first have a look at the output of the model.
+#
+# Note that the keypoint detection model does not need normalized images.
+#

+from torchvision.models.detection import keypointrcnn_resnet50_fpn
+from torchvision.io import read_image
+
+person_int = read_image(str(Path("assets") / "person1.jpg"))
+person_float = convert_image_dtype(person_int, dtype=torch.float)
+
+model = keypointrcnn_resnet50_fpn(pretrained=True, progress=False)
+model = model.eval()
+
+outputs = model([person_float])
+print(outputs)
+
+#####################################
+# As we can see, the output contains a list of dictionaries.
+# The output list is of length batch_size.
+# We currently have just a single image, so the length of the list is 1.
+# Each entry in the list corresponds to an input image,
+# and it is a dict with keys `boxes`, `labels`, `scores`, `keypoints` and `keypoint_scores`.
+# Each value associated with those keys has `num_instances` elements in it.
+# In our case above, there are 2 instances detected in the image.
+
+kpts = outputs[0]['keypoints']
+scores = outputs[0]['scores']
+
+print(kpts)
+print(scores)
+
+#####################################
+# The KeypointRCNN model detects that there are two instances in the image.
+# If you plot the boxes using :func:`~torchvision.utils.draw_bounding_boxes`
+# you will recognize that they are the person and the surfboard.
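+
+#####################################
+# As a quick sketch to check this (it simply reuses
+# :func:`~torchvision.utils.draw_bounding_boxes` from the bounding boxes
+# section above, with the boxes taken straight from the model output):
+
+from torchvision.utils import draw_bounding_boxes
+
+res = draw_bounding_boxes(person_int, outputs[0]['boxes'], colors="red", width=3)
+show(res)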
+
+#####################################
+# If we look at the scores, we will see that the model is much more confident
+# about the person than about the surfboard.
+# We can now set a confidence threshold and plot only the instances which we
+# are confident enough about.
+# Let us set a threshold of 0.75 and keep the keypoints corresponding to the person.
+
+detect_threshold = 0.75
+idx = torch.where(scores > detect_threshold)
+keypoints = kpts[idx]
+
+print(keypoints)
+
+#####################################
+# Great, now we have the keypoints corresponding to the person.
+# Each keypoint is represented by x, y coordinates and the visibility.
+# We can now use the :func:`~torchvision.utils.draw_keypoints` function to draw keypoints.
+# Note that the utility expects uint8 images.
+
+from torchvision.utils import draw_keypoints
+
+res = draw_keypoints(person_int, keypoints, colors="blue", radius=3)
+show(res)
+
+#####################################
+# As we can see, the keypoints appear as colored circles over the image.
+# The COCO keypoints for a person are ordered and represent the following list.
+
+coco_keypoints = [
+    "nose", "left_eye", "right_eye", "left_ear", "right_ear",
+    "left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
+    "left_wrist", "right_wrist", "left_hip", "right_hip",
+    "left_knee", "right_knee", "left_ankle", "right_ankle",
+]
+
+#####################################
+# What if we are interested in joining the keypoints?
+# This is especially useful for tasks like pose estimation or action recognition.
+# We can join the keypoints easily using the `connectivity` parameter.
+# A closer look reveals that we need to join the points in the following
+# order to construct a human skeleton.
+#
+# nose -> left_eye -> left_ear: (0, 1), (1, 3)
+#
+# nose -> right_eye -> right_ear: (0, 2), (2, 4)
+#
+# nose -> left_shoulder -> left_elbow -> left_wrist: (0, 5), (5, 7), (7, 9)
+#
+# nose -> right_shoulder -> right_elbow -> right_wrist: (0, 6), (6, 8), (8, 10)
+#
+# left_shoulder -> left_hip -> left_knee -> left_ankle: (5, 11), (11, 13), (13, 15)
+#
+# right_shoulder -> right_hip -> right_knee -> right_ankle: (6, 12), (12, 14), (14, 16)
+#
+# We will create a list containing these keypoint ids to be connected.
+
+connect_skeleton = [
+    (0, 1), (0, 2), (1, 3), (2, 4), (0, 5), (0, 6), (5, 7), (6, 8),
+    (7, 9), (8, 10), (5, 11), (6, 12), (11, 13), (12, 14), (13, 15), (14, 16)
+]
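+
+#####################################
+# As a small sanity-check sketch, we can print each planned connection using
+# the human-readable names from the `coco_keypoints` list defined above:
+
+for start_idx, end_idx in connect_skeleton:
+    print(coco_keypoints[start_idx], "->", coco_keypoints[end_idx])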
+
+#####################################
+# We pass the above list to the `connectivity` parameter to connect the keypoints.
+#
+
+res = draw_keypoints(person_int, keypoints, connectivity=connect_skeleton, colors="blue", radius=4, width=3)
+show(res)
diff --git a/test/test_utils.py b/test/test_utils.py
index 64f45c697c6..30f144d8206 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -256,7 +256,14 @@ def test_draw_keypoints_vanilla():
     img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
     img_cp = img.clone()
 
-    result = utils.draw_keypoints(img, keypoints, colors="red", connectivity=((0, 1),))
+    result = utils.draw_keypoints(
+        img,
+        keypoints,
+        colors="red",
+        connectivity=[
+            (0, 1),
+        ],
+    )
     path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_keypoint_vanilla.png")
     if not os.path.exists(path):
         res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
@@ -277,7 +284,14 @@ def test_draw_keypoints_colored(colors):
     img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
     img_cp = img.clone()
 
-    result = utils.draw_keypoints(img, keypoints, colors=colors, connectivity=((0, 1),))
+    result = utils.draw_keypoints(
+        img,
+        keypoints,
+        colors=colors,
+        connectivity=[
+            (0, 1),
+        ],
+    )
     assert result.size(0) == 3
     assert_equal(keypoints, keypoints_cp)
     assert_equal(img, img_cp)
diff --git a/torchvision/utils.py b/torchvision/utils.py
index 6c29767a7ce..b11f4ebeecf 100644
--- a/torchvision/utils.py
+++ b/torchvision/utils.py
@@ -304,7 +304,7 @@ def draw_segmentation_masks(
 def draw_keypoints(
     image: torch.Tensor,
     keypoints: torch.Tensor,
-    connectivity: Optional[Tuple[Tuple[int, int]]] = None,
+    connectivity: Optional[List[Tuple[int, int]]] = None,
     colors: Optional[Union[str, Tuple[int, int, int]]] = None,
     radius: int = 2,
     width: int = 3,
@@ -318,7 +318,7 @@
         image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
         keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoints location for each of the N
             instances, in the format [x, y].
-        connectivity (Tuple[Tuple[int, int]]]): A Tuple of tuple where,
+        connectivity (List[Tuple[int, int]]): A List of tuples where
             each tuple contains pair of keypoints to be connected.
         colors (str, Tuple): The color can be represented as PIL strings e.g. "red" or
             "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.