@@ -7397,6 +7397,326 @@ def test_errors_functional(self):
73977397 F .sanitize_bounding_boxes (good_bbox .tolist ())
73987398
73997399
7400+ class TestSanitizeKeyPoints :
7401+ def _make_keypoints_with_validity (
7402+ self ,
7403+ canvas_size = (100 , 100 ),
7404+ shape = "2d" , # "2d", "3d", "4d" for different keypoint shapes
7405+ ):
7406+ """Create keypoints with known validity for testing."""
7407+ canvas_h , canvas_w = canvas_size
7408+
7409+ if shape == "2d" : # [N_points, 2]
7410+ keypoints_data = [
7411+ ([5 , 5 ], True ), # Valid point inside image
7412+ ([canvas_w - 6 , canvas_h - 6 ], True ), # Valid point near corner
7413+ ([canvas_w // 2 , canvas_h // 2 ], True ), # Valid point in center
7414+ ([- 1 , canvas_h // 2 ], False ), # Invalid: x < 0
7415+ ([canvas_w // 2 , - 1 ], False ), # Invalid: y < 0
7416+ ([canvas_w , canvas_h // 2 ], False ), # Invalid: x >= canvas_w
7417+ ([canvas_w // 2 , canvas_h ], False ), # Invalid: y >= canvas_h
7418+ ([0 , 0 ], True ), # Edge case: exactly on edge
7419+ ([canvas_w - 1 , canvas_h - 1 ], True ), # Edge case: exactly on edge
7420+ ]
7421+ points , validity = zip (* keypoints_data )
7422+ keypoints = torch .tensor (points , dtype = torch .float32 )
7423+
7424+ elif shape == "3d" : # [N_objects, N_points, 2]
7425+ # Create groups of keypoints with different validity patterns
7426+ keypoints_data = [
7427+ # Group 1: All points valid
7428+ ([[10 , 10 ], [20 , 20 ], [30 , 30 ]], True ),
7429+ # Group 2: One invalid point (should be removed if min_invalid_points=1)
7430+ ([[10 , 10 ], [20 , 20 ], [- 5 , 30 ]], False ),
7431+ # Group 3: All points invalid
7432+ ([[- 1 , - 1 ], [- 2 , - 2 ], [- 3 , - 3 ]], False ),
7433+ # Group 4: Mix of valid and invalid (depends on min_invalid_points)
7434+ ([[10 , 10 ], [- 1 , 20 ], [- 2 , 30 ]], False ),
7435+ ]
7436+ groups , validity = zip (* keypoints_data )
7437+ keypoints = torch .tensor (groups , dtype = torch .float32 )
7438+
7439+ elif shape == "4d" : # [N_objects, N_bones, 2, 2]
7440+ # Create bone-like structures (pairs of points)
7441+ keypoints_data = [
7442+ # Object 1: All bones valid
7443+ ([[[10 , 10 ], [15 , 15 ]], [[20 , 20 ], [25 , 25 ]]], True ),
7444+ # Object 2: One bone with invalid point
7445+ ([[[10 , 10 ], [15 , 15 ]], [[- 1 , 20 ], [25 , 25 ]]], False ),
7446+ # Object 3: All bones invalid
7447+ ([[[- 1 , - 1 ], [- 2 , - 2 ]], [[- 3 , - 3 ], [- 4 , - 4 ]]], False ),
7448+ ]
7449+ objects , validity = zip (* keypoints_data )
7450+ keypoints = torch .tensor (objects , dtype = torch .float32 )
7451+
7452+ else :
7453+ raise ValueError (f"Unsupported shape: { shape } " )
7454+
7455+ return keypoints , validity
7456+
7457+ @pytest .mark .parametrize ("shape" , ["2d" , "3d" , "4d" ])
7458+ @pytest .mark .parametrize ("input_type" , [torch .Tensor , tv_tensors .KeyPoints ])
7459+ def test_functional (self , shape , input_type ):
7460+ """Test the sanitize_keypoints functional interface."""
7461+
7462+ # Create inputs
7463+ canvas_size = (50 , 50 )
7464+ keypoints , expected_validity = self ._make_keypoints_with_validity (
7465+ canvas_size = canvas_size ,
7466+ shape = shape ,
7467+ )
7468+
7469+ if input_type is tv_tensors .KeyPoints :
7470+ keypoints = tv_tensors .KeyPoints (keypoints , canvas_size = canvas_size )
7471+ canvas_size_arg = None
7472+ else :
7473+ canvas_size_arg = canvas_size
7474+
7475+ # Apply function to be tested
7476+ result_keypoints , valid_mask = F .sanitize_keypoints (
7477+ keypoints ,
7478+ canvas_size = canvas_size_arg ,
7479+ )
7480+
7481+ # Check return types
7482+ assert isinstance (result_keypoints , input_type )
7483+ assert isinstance (valid_mask , torch .Tensor )
7484+ assert valid_mask .dtype == torch .bool
7485+
7486+ # Check that valid mask matches expected validity
7487+ assert_equal (valid_mask , torch .tensor (expected_validity ))
7488+
7489+ # Check that result has correct number of valid keypoints
7490+ assert result_keypoints .shape [0 ] == valid_mask .sum ().item ()
7491+
7492+ # Check that remaining keypoints shape is preserved
7493+ assert result_keypoints .shape [1 :] == keypoints .shape [1 :]
7494+
7495+ @pytest .mark .parametrize ("shape" , ["2d" , "3d" , "4d" ])
7496+ def test_kernel (self , shape ):
7497+ """Test kernel functionality."""
7498+ canvas_size = (30 , 30 )
7499+ keypoints , _ = self ._make_keypoints_with_validity (canvas_size = canvas_size , shape = shape )
7500+
7501+ check_kernel (
7502+ F .sanitize_keypoints ,
7503+ input = keypoints ,
7504+ canvas_size = canvas_size ,
7505+ check_batched_vs_unbatched = False , # This function doesn't support batching
7506+ )
7507+
7508+ @pytest .mark .parametrize ("shape" , ["2d" , "3d" , "4d" ])
7509+ @pytest .mark .parametrize (
7510+ "labels_getter" ,
7511+ (
7512+ "default" ,
7513+ lambda inputs : inputs ["labels" ],
7514+ lambda inputs : (inputs ["labels" ], inputs ["other_labels" ]),
7515+ lambda inputs : [inputs ["labels" ], inputs ["other_labels" ]],
7516+ None ,
7517+ lambda inputs : None ,
7518+ ),
7519+ )
7520+ @pytest .mark .parametrize ("sample_type" , (tuple , dict ))
7521+ def test_transform (self , shape , labels_getter , sample_type ):
7522+ """Test the SanitizeKeyPoints transform class."""
7523+ if sample_type is tuple and not isinstance (labels_getter , str ):
7524+ # Lambda-based labels_getter doesn't work with tuple input
7525+ return
7526+
7527+ canvas_size = (40 , 40 )
7528+ keypoints , expected_validity = self ._make_keypoints_with_validity (
7529+ canvas_size = canvas_size ,
7530+ shape = shape ,
7531+ )
7532+
7533+ keypoints = tv_tensors .KeyPoints (keypoints , canvas_size = canvas_size )
7534+ num_keypoints = keypoints .shape [0 ]
7535+
7536+ # Create associated labels and other data
7537+ labels = torch .arange (num_keypoints )
7538+ other_labels = torch .arange (num_keypoints ) * 2
7539+ masks = tv_tensors .Mask (torch .randint (0 , 2 , size = (num_keypoints , * canvas_size )))
7540+ whatever = torch .rand (10 )
7541+ input_img = torch .randint (0 , 256 , size = (1 , 3 , * canvas_size ), dtype = torch .uint8 )
7542+
7543+ sample = {
7544+ "image" : input_img ,
7545+ "labels" : labels ,
7546+ "keypoints" : keypoints ,
7547+ "other_labels" : other_labels ,
7548+ "whatever" : whatever ,
7549+ "None" : None ,
7550+ "masks" : masks ,
7551+ }
7552+
7553+ if sample_type is tuple :
7554+ img = sample .pop ("image" )
7555+ sample = (img , sample )
7556+
7557+ # Apply transform
7558+ transform = transforms .SanitizeKeyPoints (
7559+ labels_getter = labels_getter ,
7560+ )
7561+ out = transform (sample )
7562+
7563+ # Extract outputs
7564+ if sample_type is tuple :
7565+ out_image = out [0 ]
7566+ out_labels = out [1 ]["labels" ]
7567+ out_other_labels = out [1 ]["other_labels" ]
7568+ out_keypoints = out [1 ]["keypoints" ]
7569+ out_masks = out [1 ]["masks" ]
7570+ out_whatever = out [1 ]["whatever" ]
7571+ else :
7572+ out_image = out ["image" ]
7573+ out_labels = out ["labels" ]
7574+ out_other_labels = out ["other_labels" ]
7575+ out_keypoints = out ["keypoints" ]
7576+ out_masks = out ["masks" ]
7577+ out_whatever = out ["whatever" ]
7578+
7579+ # Verify unchanged elements
7580+ assert_equal (out_image , input_img )
7581+ assert_equal (out_whatever , whatever )
7582+ assert_equal (out_masks , masks )
7583+
7584+ # Verify types
7585+ assert isinstance (out_keypoints , tv_tensors .KeyPoints )
7586+ assert isinstance (out_masks , tv_tensors .Mask )
7587+
7588+ # Calculate expected valid indices
7589+ valid_indices = [i for i , is_valid in enumerate (expected_validity ) if is_valid ]
7590+
7591+ # Test label handling
7592+ if labels_getter is None or (callable (labels_getter ) and labels_getter (sample ) is None ):
7593+ # Labels should be unchanged
7594+ assert out_labels is labels
7595+ assert out_other_labels is other_labels
7596+ else :
7597+ # Labels should be filtered
7598+ assert isinstance (out_labels , torch .Tensor )
7599+ assert out_keypoints .shape [0 ] == out_labels .shape [0 ]
7600+ assert out_labels .tolist () == valid_indices
7601+
7602+ if callable (labels_getter ) and isinstance (labels_getter (sample ), (tuple , list )):
7603+ # other_labels should also be filtered
7604+ assert_equal (out_other_labels , out_labels * 2 ) # Since other_labels = labels * 2
7605+ else :
7606+ # other_labels and masks should be unchanged
7607+ assert_equal (out_other_labels , other_labels )
7608+
7609+ def test_edge_cases (self ):
7610+ """Test edge cases and boundary conditions."""
7611+ canvas_size = (10 , 10 )
7612+
7613+ # Test empty keypoints
7614+ empty_keypoints = tv_tensors .KeyPoints (torch .empty (0 , 2 ), canvas_size = canvas_size )
7615+ result , valid_mask = F .sanitize_keypoints (empty_keypoints )
7616+ print (empty_keypoints , result , valid_mask )
7617+ assert tuple (result .shape ) == (0 , 2 )
7618+ assert valid_mask .shape [0 ] == 0
7619+
7620+ # Test single valid keypoint
7621+ single_valid = tv_tensors .KeyPoints ([[5 , 5 ]], canvas_size = canvas_size )
7622+ result , valid_mask = F .sanitize_keypoints (single_valid )
7623+ assert tuple (result .shape ) == (1 , 2 )
7624+ assert valid_mask .all ()
7625+
7626+ # Test single invalid keypoint
7627+ single_invalid = tv_tensors .KeyPoints ([[- 1 , - 1 ]], canvas_size = canvas_size )
7628+ result , valid_mask = F .sanitize_keypoints (single_invalid )
7629+ assert tuple (result .shape ) == (0 , 2 )
7630+ assert not valid_mask .any ()
7631+
7632+ def test_errors_functional (self ):
7633+ """Test error conditions for the functional interface."""
7634+ good_keypoints = tv_tensors .KeyPoints ([[5 , 5 ]], canvas_size = (10 , 10 ))
7635+
7636+ # Test missing canvas_size for pure tensor
7637+ with pytest .raises (ValueError , match = "canvas_size cannot be None" ):
7638+ F .sanitize_keypoints (good_keypoints .as_subclass (torch .Tensor ), canvas_size = None )
7639+
7640+ # Test canvas_size provided for tv_tensor
7641+ with pytest .raises (ValueError , match = "canvas_size must be None" ):
7642+ F .sanitize_keypoints (good_keypoints , canvas_size = (10 , 10 ))
7643+
7644+ def test_errors_transform (self ):
7645+ """Test error conditions for the transform class."""
7646+ good_keypoints = tv_tensors .KeyPoints ([[5 , 5 ]], canvas_size = (10 , 10 ))
7647+
7648+ # Test invalid labels_getter
7649+ with pytest .raises (ValueError , match = "labels_getter should either be" ):
7650+ transforms .SanitizeKeyPoints (labels_getter = "invalid_type" ) # type: ignore
7651+
7652+ # Test missing labels key
7653+ with pytest .raises (ValueError , match = "Could not infer where the labels are" ):
7654+ bad_sample = {"keypoints" : good_keypoints , "BAD_KEY" : torch .tensor ([0 ])}
7655+ transforms .SanitizeKeyPoints (labels_getter = "default" )(bad_sample )
7656+
7657+ # Test labels not a tensor
7658+ with pytest .raises (ValueError , match = "must be a tensor" ):
7659+ bad_sample = {"keypoints" : good_keypoints , "labels" : [0 ]}
7660+ transforms .SanitizeKeyPoints (labels_getter = "default" )(bad_sample )
7661+
7662+ # Test mismatched sizes
7663+ with pytest .raises (ValueError , match = "Number of" ):
7664+ bad_sample = {"keypoints" : good_keypoints , "labels" : torch .tensor ([0 , 1 , 2 ])}
7665+ transforms .SanitizeKeyPoints (labels_getter = "default" )(bad_sample )
7666+
7667+ def test_no_label (self ):
7668+ """Test transform without labels."""
7669+ img = make_image ()
7670+ keypoints = make_keypoints ()
7671+
7672+ # Should raise error without labels_getter=None
7673+ with pytest .raises (ValueError , match = "or a two-tuple whose second item is a dict" ):
7674+ transforms .SanitizeKeyPoints (labels_getter = "default" )(img , keypoints )
7675+
7676+ # Should work with labels_getter=None
7677+ out_img , out_keypoints = transforms .SanitizeKeyPoints (labels_getter = None )(img , keypoints )
7678+ assert isinstance (out_img , tv_tensors .Image )
7679+ assert isinstance (out_keypoints , tv_tensors .KeyPoints )
7680+
7681+ @pytest .mark .parametrize ("device" , cpu_and_cuda ())
7682+ def test_device_and_dtype_consistency (self , device ):
7683+ """Test that device and dtype are preserved."""
7684+ canvas_size = (20 , 20 )
7685+ keypoints = torch .tensor ([[5 , 5 ], [15 , 15 ], [- 1 , - 1 ]], dtype = torch .float32 , device = device )
7686+ keypoints = tv_tensors .KeyPoints (keypoints , canvas_size = canvas_size )
7687+
7688+ result , valid_mask = F .sanitize_keypoints (keypoints )
7689+
7690+ assert result .device == keypoints .device
7691+ assert result .dtype == keypoints .dtype
7692+ assert valid_mask .device == keypoints .device
7693+
7694+ def test_keypoint_shapes_consistency (self ):
7695+ """Test that different keypoint shapes are handled correctly."""
7696+ canvas_size = (50 , 50 )
7697+
7698+ # Test 2D shape [N_points, 2]
7699+ kp_2d = torch .tensor ([[10 , 10 ], [20 , 20 ], [- 1 , - 1 ]], dtype = torch .float32 )
7700+ kp_2d = tv_tensors .KeyPoints (kp_2d , canvas_size = canvas_size )
7701+ result_2d , valid_2d = F .sanitize_keypoints (kp_2d )
7702+ assert result_2d .ndim == 2
7703+ assert result_2d .shape [1 :] == kp_2d .shape [1 :]
7704+
7705+ # Test 3D shape [N_objects, N_points, 2]
7706+ kp_3d = torch .tensor ([[[10 , 10 ], [20 , 20 ]], [[- 1 , - 1 ], [30 , 30 ]]], dtype = torch .float32 )
7707+ kp_3d = tv_tensors .KeyPoints (kp_3d , canvas_size = canvas_size )
7708+ result_3d , valid_3d = F .sanitize_keypoints (kp_3d )
7709+ assert result_3d .ndim == 3
7710+ assert result_3d .shape [1 :] == kp_3d .shape [1 :]
7711+
7712+ # Test 4D shape [N_objects, N_bones, 2, 2]
7713+ kp_4d = torch .tensor ([[[[10 , 10 ], [20 , 20 ]]], [[[- 1 , - 1 ], [30 , 30 ]]]], dtype = torch .float32 )
7714+ kp_4d = tv_tensors .KeyPoints (kp_4d , canvas_size = canvas_size )
7715+ result_4d , valid_4d = F .sanitize_keypoints (kp_4d )
7716+ assert result_4d .ndim == 4
7717+ assert result_4d .shape [1 :] == kp_4d .shape [1 :]
7718+
7719+
74007720class TestJPEG :
74017721 @pytest .mark .parametrize ("quality" , [5 , 75 ])
74027722 @pytest .mark .parametrize ("color_space" , ["RGB" , "GRAY" ])
0 commit comments