
Conversation

@NicolasHug (Member) commented Aug 7, 2023

This PR addresses the "subclass unwrapping" issue from #7319.

We now always preserve the Datapoint type when doing native operations like img + 3 or img + some_tensor. This largely simplifies the Datapoint class implementation and avoids the potentially surprising "unwrapping" behaviour.

BoundingBoxes is the only class that needs special treatment, as it requires metadata, so it's the only class for which we override __torch_function__. Overall, the Datapoint logic is greatly simplified, as it largely relies on the defaults from torch.Tensor.

Take a look at the newly-added test_operations() for an illustration of what is now possible.

Note: following #7807 (comment), the unwrapping / rewrapping mechanism in our functionals is preserved for perf reasons only.
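To make the "no unwrapping" behaviour concrete, here is a minimal plain-Python sketch of the idea. Tensor and Image below are hypothetical stand-ins for illustration only, not the real torch.Tensor or torchvision Datapoint classes; the real mechanism goes through torch.Tensor's default __torch_function__.

```python
# Hedged sketch: operations on a subclass return the same subclass instead
# of "unwrapping" to the base class. Tensor/Image are illustrative stand-ins.
class Tensor:
    def __init__(self, data):
        self.data = data

    def __add__(self, other):
        other_data = other.data if isinstance(other, Tensor) else other
        # Preserve the subclass of the left operand instead of unwrapping.
        return type(self)(self.data + other_data)


class Image(Tensor):
    pass


img = Image(3)
print(type(img + 3).__name__)          # Image, not plain Tensor
print(type(img + Tensor(1)).__name__)  # Image again
```

With the pre-PR "unwrapping" behaviour, both results would have been plain tensors, forcing users to re-wrap after every arithmetic operation.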

pytorch-bot bot commented Aug 7, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/7807

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 1 New Failure, 1 Unrelated Failure

As of commit ec17580:

NEW FAILURE - The following job has failed:

BROKEN TRUNK - The following job failed but was present on the merge base bf6a8dc:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Comment on lines +2297 to +2298
with pytest.raises(TypeError, match="unsupported operand type"):
img + mask
@NicolasHug (Member, Author) commented Aug 7, 2023

Users want to do that? Perfect, but they'll need to explicitly say what type they want as output by converting one of the operands to a tensor. We don't have to assume anything on their behalf and (surprisingly) return a pure tensor.

EDIT: as @pmeier pointed out offline, this is in fact the same behaviour as on main - nothing new
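A plain-Python analogue can illustrate why mixing two unrelated Datapoint subclasses raises and why converting one operand resolves it. These classes are hypothetical stand-ins, not the real torchvision types; the real behaviour comes from __torch_function__ returning NotImplemented for ambiguous mixes.

```python
# Hedged sketch: mixing two unrelated subclasses returns NotImplemented, so
# Python raises the usual "unsupported operand type(s)" TypeError, unless one
# operand is explicitly converted back to the base class.
class Tensor:
    def __init__(self, data):
        self.data = data

    def as_subclass(self, cls):
        # Stand-in for torch.Tensor.as_subclass: reinterpret as another class.
        return cls(self.data)

    def __add__(self, other):
        if isinstance(other, Tensor):
            a, b = type(self), type(other)
            if not (issubclass(a, b) or issubclass(b, a)):
                return NotImplemented  # ambiguous: e.g. Image + Mask
            other = other.data
        return type(self)(self.data + other)


class Image(Tensor):
    pass


class Mask(Tensor):
    pass


img, mask = Image(1), Mask(2)
try:
    img + mask
except TypeError as e:
    print("TypeError:", e)

# Converting one operand makes the intent explicit; the result is an Image.
out = img + mask.as_subclass(Tensor)
print(type(out).__name__)  # Image
```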


-output_boxes, output_canvas_size = F.resized_crop_bounding_boxes(in_boxes, format, top, left, height, width, size)
+output_boxes, output_canvas_size = F.resized_crop_bounding_boxes(
+    in_boxes.as_subclass(torch.Tensor), format, top, left, height, width, size
+)
@NicolasHug (Member, Author)

This change (and similar ones below) was needed because in_boxes is now still a BBox instance, while resized_crop_bounding_boxes expects a plain tensor (there is an error saying something like "if you pass a bbox, don't pass the format").

@NicolasHug NicolasHug changed the title Get rid of __torchfunction__ and the whole wrapping/unwrapping logic Get rid of __torchfunction__ Aug 8, 2023
 # Copy-paste masks:
 masks = masks * inverse_paste_alpha_mask
-non_all_zero_masks = masks.sum((-1, -2)) > 0
+non_all_zero_masks = (masks.sum((-1, -2)) > 0).as_subclass(torch.Tensor)
@NicolasHug (Member, Author)

There were 2 other similar failures (below). The reason for the error is that (masks.sum((-1, -2)) > 0) is still a Mask object, and we can't use Masks as indices (line below).

This is the only kind of instance that I identified as potentially weird / confusing. But the error message is good enough to figure out the fix.
(In contrast, unwrapping all the time is likely to cause a lot more surprises and forces users to re-wrap all the time).
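The pattern behind this fix can be sketched with plain-Python stand-ins (not the real torch or torchvision classes): a comparison on a Mask still returns a Mask under the new behaviour, and the fix is an explicit conversion back to the base class, mirroring .as_subclass(torch.Tensor).

```python
# Hedged sketch: comparisons also preserve the subclass, which downstream
# code (e.g. tensor indexing) may not accept; as_subclass is the escape hatch.
class Tensor:
    def __init__(self, data):
        self.data = data

    def as_subclass(self, cls):
        # Stand-in for torch.Tensor.as_subclass.
        return cls(self.data)

    def __gt__(self, other):
        # Comparisons preserve the subclass of the left operand.
        return type(self)(self.data > other)


class Mask(Tensor):
    pass


keep = Mask(5) > 0
print(type(keep).__name__)   # still a Mask -- the potentially confusing part

plain = keep.as_subclass(Tensor)
print(type(plain).__name__)  # a plain Tensor, safe to use as an index
```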

assert _get_kernel(F.resize, MyDatapoint) is resize_my_datapoint


def test_operations():
@NicolasHug (Member, Author)

This test is mostly for illustrating the new behaviour. If we're OK with it, I'll refactor this test into something a little more polished.

@NicolasHug NicolasHug changed the title Get rid of __torchfunction__ Preserve Datapoint subclasses instead of returning tensors Aug 9, 2023
@NicolasHug NicolasHug marked this pull request as ready for review August 9, 2023 16:36
@NicolasHug (Member, Author)

Got superseded by #7825

@NicolasHug NicolasHug closed this Aug 17, 2023