@@ -200,6 +200,30 @@ def horizontal_flip_bounding_box():
200200 yield SampleInput (bounding_box , format = bounding_box .format , image_size = bounding_box .image_size )
201201
202202
203+ @register_kernel_info_from_sample_inputs_fn
204+ def horizontal_flip_segmentation_mask ():
205+ for mask in make_segmentation_masks ():
206+ yield SampleInput (mask )
207+
208+
209+ @register_kernel_info_from_sample_inputs_fn
210+ def vertical_flip_image_tensor ():
211+ for image in make_images ():
212+ yield SampleInput (image )
213+
214+
215+ @register_kernel_info_from_sample_inputs_fn
216+ def vertical_flip_bounding_box ():
217+ for bounding_box in make_bounding_boxes (formats = [features .BoundingBoxFormat .XYXY ]):
218+ yield SampleInput (bounding_box , format = bounding_box .format , image_size = bounding_box .image_size )
219+
220+
221+ @register_kernel_info_from_sample_inputs_fn
222+ def vertical_flip_segmentation_mask ():
223+ for mask in make_segmentation_masks ():
224+ yield SampleInput (mask )
225+
226+
203227@register_kernel_info_from_sample_inputs_fn
204228def resize_image_tensor ():
205229 for image , interpolation , max_size , antialias in itertools .product (
@@ -404,9 +428,17 @@ def crop_segmentation_mask():
404428
405429
406430@register_kernel_info_from_sample_inputs_fn
407- def vertical_flip_segmentation_mask ():
408- for mask in make_segmentation_masks ():
409- yield SampleInput (mask )
431+ def resized_crop_image_tensor ():
432+ for mask , top , left , height , width , size , antialias in itertools .product (
433+ make_images (),
434+ [- 8 , 9 ],
435+ [- 8 , 9 ],
436+ [12 ],
437+ [12 ],
438+ [(16 , 18 )],
439+ [True , False ],
440+ ):
441+ yield SampleInput (mask , top = top , left = left , height = height , width = width , size = size , antialias = antialias )
410442
411443
412444@register_kernel_info_from_sample_inputs_fn
@@ -457,6 +489,19 @@ def pad_bounding_box():
457489 yield SampleInput (bounding_box , padding = padding , format = bounding_box .format )
458490
459491
492+ @register_kernel_info_from_sample_inputs_fn
493+ def perspective_image_tensor ():
494+ for image , perspective_coeffs , fill in itertools .product (
495+ make_images (extra_dims = ((), (4 ,))),
496+ [
497+ [1.2405 , 0.1772 , - 6.9113 , 0.0463 , 1.251 , - 5.235 , 0.00013 , 0.0018 ],
498+ [0.7366 , - 0.11724 , 1.45775 , - 0.15012 , 0.73406 , 2.6019 , - 0.0072 , - 0.0063 ],
499+ ],
500+ [None , [128 ], [12.0 ]], # fill
501+ ):
502+ yield SampleInput (image , perspective_coeffs = perspective_coeffs , fill = fill )
503+
504+
460505@register_kernel_info_from_sample_inputs_fn
461506def perspective_bounding_box ():
462507 for bounding_box , perspective_coeffs in itertools .product (
@@ -488,6 +533,15 @@ def perspective_segmentation_mask():
488533 )
489534
490535
536+ @register_kernel_info_from_sample_inputs_fn
537+ def center_crop_image_tensor ():
538+ for mask , output_size in itertools .product (
539+ make_images (sizes = ((16 , 16 ), (7 , 33 ), (31 , 9 ))),
540+ [[4 , 3 ], [42 , 70 ], [4 ]], # crop sizes < image sizes, crop_sizes > image sizes, single crop size
541+ ):
542+ yield SampleInput (mask , output_size )
543+
544+
491545@register_kernel_info_from_sample_inputs_fn
492546def center_crop_bounding_box ():
493547 for bounding_box , output_size in itertools .product (make_bounding_boxes (), [(24 , 12 ), [16 , 18 ], [46 , 48 ], [12 ]]):
@@ -1181,6 +1235,18 @@ def _compute_expected_mask(mask, top_, left_, height_, width_):
11811235 torch .testing .assert_close (output_mask , expected_mask )
11821236
11831237
1238+ @pytest .mark .parametrize ("device" , cpu_and_gpu ())
1239+ def test_correctness_horizontal_flip_segmentation_mask_on_fixed_input (device ):
1240+ mask = torch .zeros ((3 , 3 , 3 ), dtype = torch .long , device = device )
1241+ mask [:, :, 0 ] = 1
1242+
1243+ out_mask = F .horizontal_flip_segmentation_mask (mask )
1244+
1245+ expected_mask = torch .zeros ((3 , 3 , 3 ), dtype = torch .long , device = device )
1246+ expected_mask [:, :, - 1 ] = 1
1247+ torch .testing .assert_close (out_mask , expected_mask )
1248+
1249+
11841250@pytest .mark .parametrize ("device" , cpu_and_gpu ())
11851251def test_correctness_vertical_flip_segmentation_mask_on_fixed_input (device ):
11861252 mask = torch .zeros ((3 , 3 , 3 ), dtype = torch .long , device = device )
0 commit comments