 import itertools
-import pathlib
-import pickle
 import random
 
 import numpy as np
 import torchvision.transforms.v2 as transforms
 
 from common_utils import assert_equal, cpu_and_cuda
-from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision import tv_tensors
 from torchvision.ops.boxes import box_iou
 from torchvision.transforms.functional import to_pil_image
-from torchvision.transforms.v2 import functional as F
-from torchvision.transforms.v2._utils import check_type, is_pure_tensor, query_chw
-from transforms_v2_legacy_utils import (
-    make_bounding_boxes,
-    make_detection_mask,
-    make_image,
-    make_images,
-    make_multiple_bounding_boxes,
-    make_segmentation_mask,
-    make_video,
-    make_videos,
-)
+from torchvision.transforms.v2._utils import is_pure_tensor
+from transforms_v2_legacy_utils import make_bounding_boxes, make_detection_mask, make_image, make_images, make_videos
 
 
 def make_vanilla_tensor_images(*args, **kwargs):
@@ -41,11 +28,6 @@ def make_pil_images(*args, **kwargs):
         yield to_pil_image(image)
 
 
-def make_vanilla_tensor_bounding_boxes(*args, **kwargs):
-    for bounding_boxes in make_multiple_bounding_boxes(*args, **kwargs):
-        yield bounding_boxes.data
-
-
 def parametrize(transforms_with_inputs):
     return pytest.mark.parametrize(
         ("transform", "input"),
@@ -61,218 +43,6 @@ def parametrize(transforms_with_inputs):
     )
 
 
-def auto_augment_adapter(transform, input, device):
-    adapted_input = {}
-    image_or_video_found = False
-    for key, value in input.items():
-        if isinstance(value, (tv_tensors.BoundingBoxes, tv_tensors.Mask)):
-            # AA transforms don't support bounding boxes or masks
-            continue
-        elif check_type(value, (tv_tensors.Image, tv_tensors.Video, is_pure_tensor, PIL.Image.Image)):
-            if image_or_video_found:
-                # AA transforms only support a single image or video
-                continue
-            image_or_video_found = True
-        adapted_input[key] = value
-    return adapted_input
-
-
-def linear_transformation_adapter(transform, input, device):
-    flat_inputs = list(input.values())
-    c, h, w = query_chw(
-        [
-            item
-            for item, needs_transform in zip(flat_inputs, transforms.Transform()._needs_transform_list(flat_inputs))
-            if needs_transform
-        ]
-    )
-    num_elements = c * h * w
-    transform.transformation_matrix = torch.randn((num_elements, num_elements), device=device)
-    transform.mean_vector = torch.randn((num_elements,), device=device)
-    return {key: value for key, value in input.items() if not isinstance(value, PIL.Image.Image)}
-
-
-def normalize_adapter(transform, input, device):
-    adapted_input = {}
-    for key, value in input.items():
-        if isinstance(value, PIL.Image.Image):
-            # normalize doesn't support PIL images
-            continue
-        elif check_type(value, (tv_tensors.Image, tv_tensors.Video, is_pure_tensor)):
-            # normalize doesn't support integer images
-            value = F.to_dtype(value, torch.float32, scale=True)
-        adapted_input[key] = value
-    return adapted_input
-
-
-class TestSmoke:
-    @pytest.mark.parametrize(
-        ("transform", "adapter"),
-        [
-            (transforms.RandomErasing(p=1.0), None),
-            (transforms.AugMix(), auto_augment_adapter),
-            (transforms.AutoAugment(), auto_augment_adapter),
-            (transforms.RandAugment(), auto_augment_adapter),
-            (transforms.TrivialAugmentWide(), auto_augment_adapter),
-            (transforms.ColorJitter(brightness=0.1, contrast=0.2, saturation=0.3, hue=0.15), None),
-            (transforms.RandomAdjustSharpness(sharpness_factor=0.5, p=1.0), None),
-            (transforms.RandomAutocontrast(p=1.0), None),
-            (transforms.RandomEqualize(p=1.0), None),
-            (transforms.RandomInvert(p=1.0), None),
-            (transforms.RandomChannelPermutation(), None),
-            (transforms.RandomPosterize(bits=4, p=1.0), None),
-            (transforms.RandomSolarize(threshold=0.5, p=1.0), None),
-            (transforms.CenterCrop([16, 16]), None),
-            (transforms.ElasticTransform(sigma=1.0), None),
-            (transforms.Pad(4), None),
-            (transforms.RandomAffine(degrees=30.0), None),
-            (transforms.RandomCrop([16, 16], pad_if_needed=True), None),
-            (transforms.RandomHorizontalFlip(p=1.0), None),
-            (transforms.RandomPerspective(p=1.0), None),
-            (transforms.RandomResize(min_size=10, max_size=20, antialias=True), None),
-            (transforms.RandomResizedCrop([16, 16], antialias=True), None),
-            (transforms.RandomRotation(degrees=30), None),
-            (transforms.RandomShortestSize(min_size=10, antialias=True), None),
-            (transforms.RandomVerticalFlip(p=1.0), None),
-            (transforms.Resize([16, 16], antialias=True), None),
-            (transforms.ScaleJitter((16, 16), scale_range=(0.8, 1.2), antialias=True), None),
-            (transforms.ClampBoundingBoxes(), None),
-            (transforms.ConvertBoundingBoxFormat(tv_tensors.BoundingBoxFormat.CXCYWH), None),
-            (transforms.ConvertImageDtype(), None),
-            (transforms.GaussianBlur(kernel_size=3), None),
-            (
-                transforms.LinearTransformation(
-                    # These are just dummy values that will be filled by the adapter. We can't define them upfront,
-                    # because we know neither the spatial size nor the device at this point.
-                    transformation_matrix=torch.empty((1, 1)),
-                    mean_vector=torch.empty((1,)),
-                ),
-                linear_transformation_adapter,
-            ),
-            (transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), normalize_adapter),
-            (transforms.ToDtype(torch.float64), None),
-            (transforms.UniformTemporalSubsample(num_samples=2), None),
-        ],
-        ids=lambda transform: type(transform).__name__,
-    )
-    @pytest.mark.parametrize("container_type", [dict, list, tuple])
-    @pytest.mark.parametrize(
-        "image_or_video",
-        [
-            make_image(),
-            make_video(),
-            next(make_pil_images(color_spaces=["RGB"])),
-            next(make_vanilla_tensor_images()),
-        ],
-    )
-    @pytest.mark.parametrize("de_serialize", [lambda t: t, lambda t: pickle.loads(pickle.dumps(t))])
-    @pytest.mark.parametrize("device", cpu_and_cuda())
-    def test_common(self, transform, adapter, container_type, image_or_video, de_serialize, device):
-        transform = de_serialize(transform)
-
-        canvas_size = F.get_size(image_or_video)
-        input = dict(
-            image_or_video=image_or_video,
-            image_tv_tensor=make_image(size=canvas_size),
-            video_tv_tensor=make_video(size=canvas_size),
-            image_pil=next(make_pil_images(sizes=[canvas_size], color_spaces=["RGB"])),
-            bounding_boxes_xyxy=make_bounding_boxes(
-                format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(3,)
-            ),
-            bounding_boxes_xywh=make_bounding_boxes(
-                format=tv_tensors.BoundingBoxFormat.XYWH, canvas_size=canvas_size, batch_dims=(4,)
-            ),
-            bounding_boxes_cxcywh=make_bounding_boxes(
-                format=tv_tensors.BoundingBoxFormat.CXCYWH, canvas_size=canvas_size, batch_dims=(5,)
-            ),
-            bounding_boxes_degenerate_xyxy=tv_tensors.BoundingBoxes(
-                [
-                    [0, 0, 0, 0],  # no height or width
-                    [0, 0, 0, 1],  # no height
-                    [0, 0, 1, 0],  # no width
-                    [2, 0, 1, 1],  # x1 > x2, y1 < y2
-                    [0, 2, 1, 1],  # x1 < x2, y1 > y2
-                    [2, 2, 1, 1],  # x1 > x2, y1 > y2
-                ],
-                format=tv_tensors.BoundingBoxFormat.XYXY,
-                canvas_size=canvas_size,
-            ),
-            bounding_boxes_degenerate_xywh=tv_tensors.BoundingBoxes(
-                [
-                    [0, 0, 0, 0],  # no height or width
-                    [0, 0, 0, 1],  # no height
-                    [0, 0, 1, 0],  # no width
-                    [0, 0, 1, -1],  # negative height
-                    [0, 0, -1, 1],  # negative width
-                    [0, 0, -1, -1],  # negative height and width
-                ],
-                format=tv_tensors.BoundingBoxFormat.XYWH,
-                canvas_size=canvas_size,
-            ),
-            bounding_boxes_degenerate_cxcywh=tv_tensors.BoundingBoxes(
-                [
-                    [0, 0, 0, 0],  # no height or width
-                    [0, 0, 0, 1],  # no height
-                    [0, 0, 1, 0],  # no width
-                    [0, 0, 1, -1],  # negative height
-                    [0, 0, -1, 1],  # negative width
-                    [0, 0, -1, -1],  # negative height and width
-                ],
-                format=tv_tensors.BoundingBoxFormat.CXCYWH,
-                canvas_size=canvas_size,
-            ),
-            detection_mask=make_detection_mask(size=canvas_size),
-            segmentation_mask=make_segmentation_mask(size=canvas_size),
-            int=0,
-            float=0.0,
-            bool=True,
-            none=None,
-            str="str",
-            path=pathlib.Path.cwd(),
-            object=object(),
-            tensor=torch.empty(5),
-            array=np.empty(5),
-        )
-        if adapter is not None:
-            input = adapter(transform, input, device)
-
-        if container_type in {tuple, list}:
-            input = container_type(input.values())
-
-        input_flat, input_spec = tree_flatten(input)
-        input_flat = [item.to(device) if isinstance(item, torch.Tensor) else item for item in input_flat]
-        input = tree_unflatten(input_flat, input_spec)
-
-        torch.manual_seed(0)
-        output = transform(input)
-        output_flat, output_spec = tree_flatten(output)
-
-        assert output_spec == input_spec
-
-        for output_item, input_item, should_be_transformed in zip(
-            output_flat, input_flat, transforms.Transform()._needs_transform_list(input_flat)
-        ):
-            if should_be_transformed:
-                assert type(output_item) is type(input_item)
-            else:
-                assert output_item is input_item
-
-            if isinstance(input_item, tv_tensors.BoundingBoxes) and not isinstance(
-                transform, transforms.ConvertBoundingBoxFormat
-            ):
-                assert output_item.format == input_item.format
-
-        # Enforce that the transform does not turn a degenerate box marked by RandomIoUCrop (or any other future
-        # transform that does this) back into a valid one.
-        # TODO: we should test that against all degenerate boxes above
-        for format in list(tv_tensors.BoundingBoxFormat):
-            sample = dict(
-                boxes=tv_tensors.BoundingBoxes([[0, 0, 0, 0]], format=format, canvas_size=(224, 244)),
-                labels=torch.tensor([3]),
-            )
-            assert transforms.SanitizeBoundingBoxes()(sample)["boxes"].shape == (0, 4)
-
-
 @pytest.mark.parametrize(
     "flat_inputs",
     itertools.permutations(
@@ -543,39 +313,6 @@ def test__get_params(self, min_size, max_size):
             assert shorter in min_size
 
 
-class TestLinearTransformation:
-    def test_assertions(self):
-        with pytest.raises(ValueError, match="transformation_matrix should be square"):
-            transforms.LinearTransformation(torch.rand(2, 3), torch.rand(5))
-
-        with pytest.raises(ValueError, match="mean_vector should have the same length"):
-            transforms.LinearTransformation(torch.rand(3, 3), torch.rand(5))
-
-    @pytest.mark.parametrize(
-        "inpt",
-        [
-            122 * torch.ones(1, 3, 8, 8),
-            122.0 * torch.ones(1, 3, 8, 8),
-            tv_tensors.Image(122 * torch.ones(1, 3, 8, 8)),
-            PIL.Image.new("RGB", (8, 8), (122, 122, 122)),
-        ],
-    )
-    def test__transform(self, inpt):
-
-        v = 121 * torch.ones(3 * 8 * 8)
-        m = torch.ones(3 * 8 * 8, 3 * 8 * 8)
-        transform = transforms.LinearTransformation(m, v)
-
-        if isinstance(inpt, PIL.Image.Image):
-            with pytest.raises(TypeError, match="does not support PIL images"):
-                transform(inpt)
-        else:
-            output = transform(inpt)
-            assert isinstance(output, torch.Tensor)
-            assert output.unique() == 3 * 8 * 8
-            assert output.dtype == inpt.dtype
-
-
 class TestRandomResize:
     def test__get_params(self):
         min_size = 3