-
Notifications
You must be signed in to change notification settings - Fork 7.2k
Add gallery example for drawing keypoints #4892
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
e6a7275
ca28bbf
94a59b5
b3baf73
67395c0
d965ac9
0370108
47c3d9f
71584c0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -4,10 +4,10 @@ | |||||||||||
| ======================= | ||||||||||||
|
|
||||||||||||
| This example illustrates some of the utilities that torchvision offers for | ||||||||||||
| visualizing images, bounding boxes, and segmentation masks. | ||||||||||||
| visualizing images, bounding boxes, segmentation masks and keypoints. | ||||||||||||
| """ | ||||||||||||
|
|
||||||||||||
| # sphinx_gallery_thumbnail_path = "../../gallery/assets/visualization_utils_thumbnail.png" | ||||||||||||
| # sphinx_gallery_thumbnail_path = "../../gallery/assets/visualization_utils_thumbnail2.png" | ||||||||||||
|
|
||||||||||||
| import torch | ||||||||||||
| import numpy as np | ||||||||||||
|
|
@@ -366,3 +366,110 @@ def show(imgs): | |||||||||||
| # The two 'people' masks in the first image were not selected because they have | ||||||||||||
| # a lower score than the score threshold. Similarly in the second image, the | ||||||||||||
| # instance with class 15 (which corresponds to 'bench') was not selected. | ||||||||||||
|
|
||||||||||||
| ##################################### | ||||||||||||
| # Visualizing keypoints | ||||||||||||
| # ------------------------------ | ||||||||||||
| # The :func:`~torchvision.utils.draw_keypoints` function can be used to | ||||||||||||
| # draw keypoints on images. We will see how to use it with | ||||||||||||
| # torchvision's KeypointRCNN loaded with :func:`~torchvision.models.detection.keypointrcnn_resnet50_fpn`. | ||||||||||||
| # We will first have a look at the output of the model. | ||||||||||||
| # | ||||||||||||
| # Note that the keypoint detection model does not need normalized images. | ||||||||||||
| # | ||||||||||||
|
|
||||||||||||
| from torchvision.models.detection import keypointrcnn_resnet50_fpn | ||||||||||||
| from torchvision.io import read_image | ||||||||||||
|
|
||||||||||||
| person_int = read_image(str(Path("assets") / "person1.jpg")) | ||||||||||||
| person_float = convert_image_dtype(person_int, dtype=torch.float) | ||||||||||||
|
|
||||||||||||
| model = keypointrcnn_resnet50_fpn(pretrained=True, progress=False) | ||||||||||||
| model = model.eval() | ||||||||||||
|
|
||||||||||||
| outputs = model([person_float]) | ||||||||||||
| print(outputs) | ||||||||||||
|
|
||||||||||||
| ##################################### | ||||||||||||
| # As we see, the output is a list of dictionaries. | ||||||||||||
| # The output list is of length batch_size. | ||||||||||||
| # We currently have just a single image, so the length of the list is 1. | ||||||||||||
| # Each entry in the list corresponds to an input image, | ||||||||||||
| # and it is a dict with keys `boxes`, `labels`, `scores`, `keypoints` and `keypoint_scores`. | ||||||||||||
| # Each value associated with those keys has `num_instances` elements in it. | ||||||||||||
| # In our case above there are 2 instances detected in the image. | ||||||||||||
|
|
||||||||||||
| kpts = outputs[0]['keypoints'] | ||||||||||||
| scores = outputs[0]['scores'] | ||||||||||||
|
|
||||||||||||
| print(kpts) | ||||||||||||
| print(scores) | ||||||||||||
|
|
||||||||||||
| ##################################### | ||||||||||||
| # The KeypointRCNN model detects there are two instances in the image. | ||||||||||||
| # If you plot the boxes by using :func:`~torchvision.utils.draw_bounding_boxes` | ||||||||||||
| # you would recognize they are the person and the surfboard. | ||||||||||||
| # If we look at the scores, we will realize that the model is much more confident about the person than the surfboard. | ||||||||||||
| # We could now set a confidence threshold and plot only the instances we are confident enough about. | ||||||||||||
| # Let us set a threshold of 0.75 and keep only the keypoints corresponding to the person. | ||||||||||||
|
|
||||||||||||
| detect_threshold = 0.75 | ||||||||||||
| idx = torch.where(scores > detect_threshold) | ||||||||||||
| keypoints = kpts[idx] | ||||||||||||
|
|
||||||||||||
| print(keypoints) | ||||||||||||
|
|
||||||||||||
| ##################################### | ||||||||||||
| # Great, now we have the keypoints corresponding to the person. | ||||||||||||
| # Each keypoint is represented by x, y coordinates and the visibility. | ||||||||||||
| # We can now use the :func:`~torchvision.utils.draw_keypoints` function to draw keypoints. | ||||||||||||
| # Note that the utility expects uint8 images. | ||||||||||||
|
|
||||||||||||
| from torchvision.utils import draw_keypoints | ||||||||||||
|
|
||||||||||||
| res = draw_keypoints(person_int, keypoints, colors="blue", radius=3) | ||||||||||||
| show(res) | ||||||||||||
|
|
||||||||||||
| ##################################### | ||||||||||||
| # As we see the keypoints appear as colored circles over the image. | ||||||||||||
| # The COCO keypoints for a person are ordered and represent the following list. | ||||||||||||
|
|
||||||||||||
| coco_keypoints = [ | ||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this will be great when people can use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Although it isn't used in the code. I made this as a code block so that it is copy-pastable. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes agreed. We've already added it on the new API:
vision/torchvision/prototype/models/_meta.py Lines 1106 to 1109 in 9a7dc1a
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess once the new prototype models and transforms are moved to stable, we would need to revisit the gallery examples. |
||||||||||||
| "nose", "left_eye", "right_eye", "left_ear", "right_ear", | ||||||||||||
| "left_shoulder", "right_shoulder", "left_elbow", "right_elbow", | ||||||||||||
| "left_wrist", "right_wrist", "left_hip", "right_hip", | ||||||||||||
| "left_knee", "right_knee", "left_ankle", "right_ankle", | ||||||||||||
| ] | ||||||||||||
|
|
||||||||||||
| ##################################### | ||||||||||||
| # What if we are interested in joining the keypoints? | ||||||||||||
| # This is especially useful for tasks such as pose detection or action recognition. | ||||||||||||
| # We can join the keypoints easily using the `connectivity` parameter. | ||||||||||||
| # A close observation would reveal that we would need to join the points in the | ||||||||||||
| # order below to construct a human skeleton. | ||||||||||||
| # | ||||||||||||
| # nose -> left_eye -> left_ear. (0, 1), (1, 3) | ||||||||||||
| # | ||||||||||||
| # nose -> right_eye -> right_ear. (0, 2), (2, 4) | ||||||||||||
| # | ||||||||||||
| # nose -> left_shoulder -> left_elbow -> left_wrist. (0, 5), (5, 7), (7, 9) | ||||||||||||
| # | ||||||||||||
| # nose -> right_shoulder -> right_elbow -> right_wrist. (0, 6), (6, 8), (8, 10) | ||||||||||||
| # | ||||||||||||
| # left_shoulder -> left_hip -> left_knee -> left_ankle. (5, 11), (11, 13), (13, 15) | ||||||||||||
| # | ||||||||||||
| # right_shoulder -> right_hip -> right_knee -> right_ankle. (6, 12), (12, 14), (14, 16) | ||||||||||||
| # | ||||||||||||
| # We will create a list containing these keypoint ids to be connected. | ||||||||||||
|
|
||||||||||||
| connect_skeleton = [ | ||||||||||||
| (0, 1), (0, 2), (1, 3), (2, 4), (0, 5), (0, 6), (5, 7), (6, 8), | ||||||||||||
| (7, 9), (8, 10), (5, 11), (6, 12), (11, 13), (12, 14), (13, 15), (14, 16) | ||||||||||||
| ] | ||||||||||||
|
|
||||||||||||
| ##################################### | ||||||||||||
| # We pass the above list to the connectivity parameter to connect the keypoints. | ||||||||||||
| # | ||||||||||||
|
|
||||||||||||
| res = draw_keypoints(person_int, keypoints, connectivity=connect_skeleton, colors="blue", radius=4, width=3) | ||||||||||||
| show(res) | ||||||||||||
Uh oh!
There was an error while loading. Please reload this page.