Preserve Datapoint subclasses instead of returning tensors #7807
@@ -2270,3 +2270,86 @@ def test_image_correctness(self, permutation, batch_dims):
         expected = self.reference_image_correctness(image, permutation=permutation)

         torch.testing.assert_close(actual, expected)
+
+
+def test_operations():
Member (Author): This test is mostly for illustrating the new behaviour. If we're OK with it, I'll refactor this test into something a little more polished.
+    img = datapoints.Image(torch.rand(3, 10, 10))
+    t = torch.rand(3, 10, 10)
+    mask = datapoints.Mask(torch.rand(1, 10, 10))
+
+    for out in (
+        [
+            img + t,
+            t + img,
+            img * t,
+            t * img,
+            img + 3,
+            3 + img,
+            img * 3,
+            3 * img,
+            img + img,
+            img.sum(),
+            img.reshape(-1),
+            img.float(),
+            torch.stack([img, img]),
+        ]
+        + list(torch.chunk(img, 2))
+        + list(torch.unbind(img))
+    ):
+        assert isinstance(out, datapoints.Image)
+
+    for out in (
+        [
+            mask + t,
+            t + mask,
+            mask * t,
+            t * mask,
+            mask + 3,
+            3 + mask,
+            mask * 3,
+            3 * mask,
+            mask + mask,
+            mask.sum(),
+            mask.reshape(-1),
+            mask.float(),
+            torch.stack([mask, mask]),
+        ]
+        + list(torch.chunk(mask, 2))
+        + list(torch.unbind(mask))
+    ):
+        assert isinstance(out, datapoints.Mask)
+
+    with pytest.raises(TypeError, match="unsupported operand type"):
+        img + mask
|
Comment on lines
+2323
to
+2324
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Users want to do that? Perfect, they'll need explicitly say what type they want as output by converting one of those operands to a tensor. We don't have to assume anything on their behalf and (surprisingly) return a pure tensor. EDIT: as @pmeier pointed out offline, this is in fact the same behaviour as on |
+
+    with pytest.raises(TypeError, match="unsupported operand type"):
+        img * mask
+
+    bboxes = datapoints.BoundingBoxes(
+        [[17, 16, 344, 495], [0, 10, 0, 10]], format=datapoints.BoundingBoxFormat.XYXY, canvas_size=(1000, 1000)
+    )
+    t = torch.rand(2, 4)
+
+    for out in (
+        [
+            bboxes + t,
+            t + bboxes,
+            bboxes * t,
+            t * bboxes,
+            bboxes + 3,
+            3 + bboxes,
+            bboxes * 3,
+            3 * bboxes,
+            bboxes + bboxes,
+            bboxes.sum(),
+            bboxes.reshape(-1),
+            bboxes.float(),
+            torch.stack([bboxes, bboxes]),
+        ]
+        + list(torch.chunk(bboxes, 2))
+        + list(torch.unbind(bboxes))
+    ):
+        assert isinstance(out, datapoints.BoundingBoxes)
+        assert hasattr(out, "format")
+        assert hasattr(out, "canvas_size")
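For readers wondering what makes the types stick, this is the standard tensor-subclassing behaviour of PyTorch's __torch_function__ protocol. A minimal, self-contained sketch follows (my illustration only, not torchvision's actual Datapoint implementation; the class name is made up):

    import torch

    class MyDatapoint(torch.Tensor):
        # torch.Tensor's default __torch_function__ re-wraps operation outputs
        # into the subclass of the inputs. A subclass only needs to override the
        # hook if it wants to opt out for some ops or propagate extra metadata.
        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            return super().__torch_function__(func, types, args, kwargs or {})

    x = torch.rand(3, 10, 10).as_subclass(MyDatapoint)
    assert isinstance(x + torch.rand(3, 10, 10), MyDatapoint)
    assert isinstance(x.sum(), MyDatapoint)
    assert isinstance(torch.stack([x, x]), MyDatapoint)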
@@ -66,7 +66,7 @@ def _copy_paste(
         # Copy-paste masks:
         masks = masks * inverse_paste_alpha_mask
-        non_all_zero_masks = masks.sum((-1, -2)) > 0
+        non_all_zero_masks = (masks.sum((-1, -2)) > 0).as_subclass(torch.Tensor)
Member (Author): There were 2 other similar failures (below). The reason for the error is that

This is the only kind of instance that I identified as potentially weird / confusing. But the error message is good enough to figure out the fix.
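For context, my reading of why the .as_subclass(torch.Tensor) call is needed here (a sketch under the same assumptions as above, not part of the diff): with this PR the boolean reduction on a Mask comes back as a Mask as well, and the explicit conversion is the opt-out for code that wants a plain tensor.

    import torch
    from torchvision import datapoints

    masks = datapoints.Mask(torch.rand(4, 10, 10))

    keep = masks.sum((-1, -2)) > 0
    assert isinstance(keep, datapoints.Mask)  # still a Mask after sum() and comparison

    # Drop the wrapper where downstream code expects a plain boolean tensor:
    keep = keep.as_subclass(torch.Tensor)
    assert type(keep) is torch.Tensor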
         masks = masks[non_all_zero_masks]

         # Do a shallow copy of the target dict
@@ -92,7 +92,9 @@ def _copy_paste(
         # Check for degenerated boxes and remove them
         boxes = F.convert_format_bounding_boxes(
-            out_target["boxes"], old_format=bbox_format, new_format=datapoints.BoundingBoxFormat.XYXY
+            out_target["boxes"].as_subclass(torch.Tensor),
+            old_format=bbox_format,
+            new_format=datapoints.BoundingBoxFormat.XYXY,
         )
         degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
         if degenerate_boxes.any():
This (and similar changes below) was needed because in_boxes is now still a BBox instance, and resized_crop_bounding_boxes expects a tensor (there is an error saying something like "if you pass a bbox, don't pass the format").
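To make that constraint concrete, a hedged sketch of the two call styles (not from the diff; the import path is an assumption, and only F.convert_format_bounding_boxes and the datapoints names appear in the code above):

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    bboxes = datapoints.BoundingBoxes(
        [[17, 16, 344, 495]], format=datapoints.BoundingBoxFormat.XYXY, canvas_size=(1000, 1000)
    )

    # Kernel style: plain tensor in, so the format must be spelled out explicitly.
    plain = F.convert_format_bounding_boxes(
        bboxes.as_subclass(torch.Tensor),
        old_format=datapoints.BoundingBoxFormat.XYXY,
        new_format=datapoints.BoundingBoxFormat.XYWH,
    )

    # Datapoint style: the BoundingBoxes instance already carries its format, so
    # also passing old_format is rejected with an error along the lines of
    # "if you pass a bbox, don't pass the format".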