Preserve Datapoint subclasses instead of returning tensors #7807
| with pytest.raises(TypeError, match="unsupported operand type"): | ||
| img + mask |
Users want to do that? Perfect, they'll need to explicitly say what type they want as output by converting one of those operands to a tensor. We don't have to assume anything on their behalf and (surprisingly) return a pure tensor.
EDIT: as @pmeier pointed out offline, this is in fact the same behaviour as on main - nothing new
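A minimal sketch of the point above, using two hypothetical plain `torch.Tensor` subclasses (`Image` and `Mask` here are stand-ins, not the torchvision classes). It assumes only the default `torch.Tensor.__torch_function__`, which declines to handle operands of unrelated subclasses:

```python
import torch


class Image(torch.Tensor):
    """Hypothetical stand-in for an image datapoint."""


class Mask(torch.Tensor):
    """Hypothetical stand-in for a mask datapoint."""


img = torch.rand(3, 4, 4).as_subclass(Image)
mask = torch.ones(1, 4, 4).as_subclass(Mask)

# Mixing two unrelated subclasses is rejected: neither class accepts the other.
try:
    img + mask
    mixed_raises = False
except TypeError:
    mixed_raises = True

# Converting one operand to a plain tensor states the intended output type.
out = img + mask.as_subclass(torch.Tensor)
```

With the mask unwrapped, the only subclass left among the operands is `Image`, so the result keeps that type.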
| output_boxes, output_canvas_size = F.resized_crop_bounding_boxes(in_boxes, format, top, left, height, width, size) | ||
| output_boxes, output_canvas_size = F.resized_crop_bounding_boxes( | ||
| in_boxes.as_subclass(torch.Tensor), format, top, left, height, width, size |
This (and similar changes below) was needed because in_boxes is now still a BBox instance, and resized_crop_bounding_boxes expects a tensor (there is an error saying something like "if you pass a bbox, don't pass the format").
| # Copy-paste masks: | ||
| masks = masks * inverse_paste_alpha_mask | ||
| non_all_zero_masks = masks.sum((-1, -2)) > 0 | ||
| non_all_zero_masks = (masks.sum((-1, -2)) > 0).as_subclass(torch.Tensor) |
There were 2 other similar failures (below). The reason for the error is that (masks.sum((-1, -2)) > 0) is still a Mask object, and we can't use Masks as indices (line below).
This is the only kind of instance that I identified as potentially weird / confusing. But the error message is good enough to figure out the fix.
(In contrast, unwrapping all the time is likely to cause a lot more surprises and forces users to re-wrap all the time).
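A rough illustration of that fix, again with a hypothetical `Mask` subclass of `torch.Tensor` rather than the torchvision class. It only shows that the reduction and comparison keep the subclass type, and that unwrapping gives a plain boolean tensor that can be used as an index. The `labels` tensor is made up for the example:

```python
import torch


class Mask(torch.Tensor):
    """Hypothetical stand-in for a mask datapoint."""


# Three 4x4 masks; the middle one is empty.
data = torch.zeros(3, 4, 4)
data[0, 1, 1] = 1.0
data[2, 0, 3] = 1.0
masks = data.as_subclass(Mask)
labels = torch.tensor([10, 20, 30])  # illustrative per-mask labels

# sum() and ">" both go through the default __torch_function__,
# so the boolean result is still a Mask.
non_empty = masks.sum((-1, -2)) > 0

# Unwrap to a plain tensor before using it as an index.
keep = non_empty.as_subclass(torch.Tensor)
kept_labels = labels[keep]
```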
| assert _get_kernel(F.resize, MyDatapoint) is resize_my_datapoint | ||
| def test_operations(): |
This test is mostly for illustrating the new behaviour. If we're OK with it, I'll refactor this test into something a little more polished
Got superseded by #7825
This PR addresses the "subclass unwrapping" issue from #7319.
We now always preserve the Datapoint type when doing native operations like `img + 3` or `img + some_tensor`. This largely simplifies the `Datapoint` class implementation and avoids the potentially surprising "unwrapping" behaviour.

`BoundingBoxes` is the only class that needs special treatment as it requires metadata, so it's the only class for which we override `__torch_function__`. Overall, the Datapoint logic is greatly simplified as it largely relies on the default logic from `torch.Tensor`.

Take a look at the newly-added `test_operations()` for an illustration of what is now possible.

Note: following #7807 (comment), the unwrapping / rewrapping mechanism in our functionals is preserved for perf reasons only.
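To make the metadata point concrete, here is a sketch of why a metadata-carrying class needs its own `__torch_function__` while the others can rely on the default. `Boxes`, `make_boxes` and the `fmt` attribute are hypothetical names for illustration; this is not the `BoundingBoxes` implementation from the PR, and it only looks at top-level positional arguments:

```python
import torch


class Boxes(torch.Tensor):
    """Hypothetical subclass carrying one piece of metadata (`fmt`)."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # The default implementation already re-wraps the output as `cls`,
        # but Python attributes such as `fmt` are not carried over.
        out = super().__torch_function__(func, types, args, kwargs or {})
        src = next((a for a in args if isinstance(a, cls)), None)
        if isinstance(out, cls) and src is not None:
            out.fmt = src.fmt  # re-attach the metadata to the new object
        return out


def make_boxes(data, fmt):
    """Hypothetical helper: wrap `data` as Boxes and attach `fmt`."""
    boxes = torch.as_tensor(data).as_subclass(Boxes)
    boxes.fmt = fmt
    return boxes


boxes = make_boxes([[0.0, 0.0, 2.0, 2.0]], fmt="XYXY")
shifted = boxes + 1
```

Without the override, `shifted` would still be a `Boxes` instance but would have no `fmt` attribute, which is the gap the special case closes.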