Skip to content

Commit 8e1134b

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] Add ToPureTensor transform (#7823)
Summary: Co-authored-by: Philip Meier <[email protected]> Reviewed By: matteobettini Differential Revision: D48642260 fbshipit-source-id: c8f287816cc22508274c492703b6938fade169ad
1 parent 28ee0f1 commit 8e1134b

File tree

7 files changed

+53
-1
lines changed

7 files changed

+53
-1
lines changed

docs/source/transforms.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ Conversion
237237
v2.ConvertImageDtype
238238
v2.ToDtype
239239
v2.ConvertBoundingBoxFormat
240+
v2.ToPureTensor
240241

241242
Auto-Augmentation
242243
-----------------

references/classification/presets.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ def __init__(
6868
if random_erase_prob > 0:
6969
transforms.append(T.RandomErasing(p=random_erase_prob))
7070

71+
if use_v2:
72+
transforms.append(T.ToPureTensor())
73+
7174
self.transforms = T.Compose(transforms)
7275

7376
def __call__(self, img):
@@ -107,6 +110,9 @@ def __init__(
107110
T.Normalize(mean=mean, std=std),
108111
]
109112

113+
if use_v2:
114+
transforms.append(T.ToPureTensor())
115+
110116
self.transforms = T.Compose(transforms)
111117

112118
def __call__(self, img):

references/detection/presets.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def __init__(
7979
transforms += [
8080
T.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.XYXY),
8181
T.SanitizeBoundingBoxes(),
82+
T.ToPureTensor(),
8283
]
8384

8485
self.transforms = T.Compose(transforms)
@@ -103,6 +104,10 @@ def __init__(self, backend="pil", use_v2=False):
103104
raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")
104105

105106
transforms += [T.ConvertImageDtype(torch.float)]
107+
108+
if use_v2:
109+
transforms += [T.ToPureTensor()]
110+
106111
self.transforms = T.Compose(transforms)
107112

108113
def __call__(self, img, target):

references/segmentation/presets.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ def __init__(
6363
transforms += [T.ConvertImageDtype(torch.float)]
6464

6565
transforms += [T.Normalize(mean=mean, std=std)]
66+
if use_v2:
67+
transforms += [T.ToPureTensor()]
6668

6769
self.transforms = T.Compose(transforms)
6870

@@ -98,6 +100,9 @@ def __init__(
98100
T.ConvertImageDtype(torch.float),
99101
T.Normalize(mean=mean, std=std),
100102
]
103+
if use_v2:
104+
transforms += [T.ToPureTensor()]
105+
101106
self.transforms = T.Compose(transforms)
102107

103108
def __call__(self, img, target):

test/test_transforms_v2_refactored.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2353,3 +2353,24 @@ def test_displacement_error(self, make_input):
23532353
@pytest.mark.parametrize("device", cpu_and_cuda())
23542354
def test_transform(self, make_input, size, device):
23552355
check_transform(transforms.ElasticTransform, make_input(size, device=device))
2356+
2357+
2358+
class TestToPureTensor:
2359+
def test_correctness(self):
2360+
input = {
2361+
"img": make_image(),
2362+
"img_tensor": make_image_tensor(),
2363+
"img_pil": make_image_pil(),
2364+
"mask": make_detection_mask(),
2365+
"video": make_video(),
2366+
"bbox": make_bounding_box(),
2367+
"str": "str",
2368+
}
2369+
2370+
out = transforms.ToPureTensor()(input)
2371+
2372+
for input_value, out_value in zip(input.values(), out.values()):
2373+
if isinstance(input_value, datapoints.Datapoint):
2374+
assert isinstance(out_value, torch.Tensor) and not isinstance(out_value, datapoints.Datapoint)
2375+
else:
2376+
assert isinstance(out_value, type(input_value))

torchvision/transforms/v2/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
ToDtype,
5353
)
5454
from ._temporal import UniformTemporalSubsample
55-
from ._type_conversion import PILToTensor, ToImage, ToPILImage
55+
from ._type_conversion import PILToTensor, ToImage, ToPILImage, ToPureTensor
5656

5757
from ._deprecated import ToTensor # usort: skip
5858

torchvision/transforms/v2/_type_conversion.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,17 @@ def _transform(
7575
self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
7676
) -> PIL.Image.Image:
7777
return F.to_pil_image(inpt, mode=self.mode)
78+
79+
80+
class ToPureTensor(Transform):
81+
"""[BETA] Convert all datapoints to pure tensors, removing associated metadata (if any).
82+
83+
.. v2betastatus:: ToPureTensor transform
84+
85+
This doesn't scale or change the values, only the type.
86+
"""
87+
88+
_transformed_types = (datapoints.Datapoint,)
89+
90+
def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
91+
return inpt.as_subclass(torch.Tensor)

0 commit comments

Comments
 (0)