import enum
import importlib.util
import inspect
import random
from collections import defaultdict
from importlib.machinery import SourceFileLoader
from pathlib import Path
57
1618 make_image ,
1719 make_images ,
1820 make_label ,
21+ make_segmentation_mask ,
1922)
2023from torchvision import transforms as legacy_transforms
2124from torchvision ._utils import sequence_to_str
2225from torchvision .prototype import features , transforms as prototype_transforms
26+ from torchvision .prototype .transforms import functional as F
27+ from torchvision .prototype .transforms ._utils import query_chw
2328from torchvision .prototype .transforms .functional import to_image_pil
2429
25-
2630DEFAULT_MAKE_IMAGES_KWARGS = dict (color_spaces = [features .ColorSpace .RGB ], extra_dims = [(4 ,)])
2731
2832
@@ -852,10 +856,12 @@ def test_aa(self, inpt, interpolation):
852856 assert_equal (expected_output , output )
853857
854858
855- # Import reference detection transforms here for consistency checks
856- # torchvision/references/detection/transforms.py
857- ref_det_filepath = Path (__file__ ).parent .parent / "references" / "detection" / "transforms.py"
858- det_transforms = SourceFileLoader (ref_det_filepath .stem , ref_det_filepath .as_posix ()).load_module ()
def import_transforms_from_references(reference):
    """Import the ``transforms`` module of a training reference.

    Args:
        reference: name of the folder under ``references/`` (e.g. ``"detection"``
            or ``"segmentation"``) whose ``transforms.py`` should be loaded.

    Returns:
        The freshly executed ``transforms`` module of that reference.
    """
    ref_path = Path(__file__).parent.parent / "references" / reference / "transforms.py"
    # ``SourceFileLoader.load_module()`` is deprecated. Going through the documented
    # ``importlib.util`` machinery also avoids registering the module in ``sys.modules``,
    # where the detection and segmentation references (both named "transforms") would
    # otherwise clobber each other.
    loader = SourceFileLoader(ref_path.stem, ref_path.as_posix())
    spec = importlib.util.spec_from_loader(ref_path.stem, loader)
    module = importlib.util.module_from_spec(spec)
    loader.exec_module(module)
    return module
862+

# Reference detection transforms from torchvision/references/detection/transforms.py,
# used as the ground truth in the consistency checks below.
det_transforms = import_transforms_from_references("detection")
859865
860866
861867class TestRefDetTransforms :
@@ -873,7 +879,7 @@ def make_datapoints(self, with_mask=True):
873879
874880 yield (pil_image , target )
875881
876- tensor_image = torch .randint ( 0 , 256 , size = ( 3 , * size ), dtype = torch . uint8 )
882+ tensor_image = torch .Tensor ( make_image ( size = size , color_space = features . ColorSpace . RGB ) )
877883 target = {
878884 "boxes" : make_bounding_box (image_size = size , format = "XYXY" , extra_dims = (num_objects ,), dtype = torch .float ),
879885 "labels" : make_label (extra_dims = (num_objects ,), categories = 80 ),
@@ -883,7 +889,7 @@ def make_datapoints(self, with_mask=True):
883889
884890 yield (tensor_image , target )
885891
886- feature_image = features . Image ( torch . randint ( 0 , 256 , size = ( 3 , * size ), dtype = torch . uint8 ) )
892+ feature_image = make_image ( size = size , color_space = features . ColorSpace . RGB )
887893 target = {
888894 "boxes" : make_bounding_box (image_size = size , format = "XYXY" , extra_dims = (num_objects ,), dtype = torch .float ),
889895 "labels" : make_label (extra_dims = (num_objects ,), categories = 80 ),
@@ -927,3 +933,165 @@ def test_transform(self, t_ref, t, data_kwargs):
927933 expected_output = t_ref (* dp )
928934
929935 assert_equal (expected_output , output )
936+
937+
# Reference segmentation transforms from torchvision/references/segmentation/transforms.py,
# used as the ground truth in the consistency checks below.
seg_transforms = import_transforms_from_references("segmentation")
939+
940+
941+ # We need this transform for two reasons:
942+ # 1. transforms.RandomCrop uses a different scheme to pad images and masks of insufficient size than its name
943+ # counterpart in the detection references. Thus, we cannot use it with `pad_if_needed=True`
944+ # 2. transforms.Pad only supports a fixed padding, but the segmentation datasets don't have a fixed image size.
class PadIfSmaller(prototype_transforms.Transform):
    """Pad the right/bottom of the input so that both spatial sides reach ``size``.

    Inputs that are already at least ``size`` in both dimensions pass through
    unchanged. ``fill`` may be a single value or a mapping from feature type to
    fill value (normalized through ``_setup_fill_arg``).
    """

    def __init__(self, size, fill=0):
        super().__init__()
        self.size = size
        self.fill = prototype_transforms._geometry._setup_fill_arg(fill)

    def _get_params(self, sample):
        _, h, w = query_chw(sample)
        pad_right = max(self.size - w, 0)
        pad_bottom = max(self.size - h, 0)
        # padding layout is [left, top, right, bottom]; only right/bottom grow
        return dict(
            padding=[0, 0, pad_right, pad_bottom],
            needs_padding=pad_right > 0 or pad_bottom > 0,
        )

    def _transform(self, inpt, params):
        if params["needs_padding"]:
            fill = F._geometry._convert_fill_arg(self.fill[type(inpt)])
            return F.pad(inpt, padding=params["padding"], fill=fill)
        return inpt
965+
966+
class TestRefSegTransforms:
    """Checks that the prototype transforms match the segmentation reference transforms."""

    def make_datapoints(self, supports_pil=True, image_dtype=torch.uint8):
        # Yields (prototype input, reference input) pairs. The prototype side is
        # exercised with a PIL image (optional), a plain tensor, and a feature image;
        # the reference side always receives (PIL image, PIL mask) — or a tensor image
        # when PIL is unsupported (e.g. Normalize).
        size = (256, 640)
        num_categories = 21

        conv_fns = []
        if supports_pil:
            conv_fns.append(to_image_pil)
        conv_fns.extend([torch.Tensor, lambda x: x])

        for conv_fn in conv_fns:
            feature_image = make_image(size=size, color_space=features.ColorSpace.RGB, dtype=image_dtype)
            feature_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)

            dp = (conv_fn(feature_image), feature_mask)
            dp_ref = (
                to_image_pil(feature_image) if supports_pil else torch.Tensor(feature_image),
                to_image_pil(feature_mask),
            )

            yield dp, dp_ref

    def set_seed(self, seed=12):
        # Seed both RNGs: the prototype transforms draw randomness from torch,
        # while the reference transforms use the random module.
        torch.manual_seed(seed)
        random.seed(seed)

    def check(self, t, t_ref, data_kwargs=None):
        """Run both pipelines from the same seed and require identical outputs."""
        for dp, dp_ref in self.make_datapoints(**data_kwargs or dict()):

            self.set_seed()
            output = t(dp)

            self.set_seed()
            expected_output = t_ref(*dp_ref)

            assert_equal(output, expected_output)

    @pytest.mark.parametrize(
        ("t_ref", "t", "data_kwargs"),
        [
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=1.0),
                prototype_transforms.RandomHorizontalFlip(p=1.0),
                dict(),
            ),
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=0.0),
                prototype_transforms.RandomHorizontalFlip(p=0.0),
                dict(),
            ),
            (
                seg_transforms.RandomCrop(size=480),
                # The reference RandomCrop pads undersized inputs itself (255 for masks),
                # so the prototype side needs an explicit PadIfSmaller first.
                prototype_transforms.Compose(
                    [
                        PadIfSmaller(size=480, fill=defaultdict(lambda: 0, {features.Mask: 255})),
                        prototype_transforms.RandomCrop(size=480),
                    ]
                ),
                dict(),
            ),
            (
                seg_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                prototype_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                # Normalize works on float tensors only, so no PIL inputs here.
                dict(supports_pil=False, image_dtype=torch.float),
            ),
        ],
    )
    def test_common(self, t_ref, t, data_kwargs):
        self.check(t, t_ref, data_kwargs)

    def check_resize(self, mocker, t_ref, t):
        # Resized outputs can't be compared pixel-wise (different interpolation code
        # paths), so instead mock out the functional resize on both sides and check
        # that each is called on the right inputs with equivalent target sizes.
        mock = mocker.patch("torchvision.prototype.transforms._geometry.F.resize")
        mock_ref = mocker.patch("torchvision.transforms.functional.resize")

        for dp, dp_ref in self.make_datapoints():
            mock.reset_mock()
            mock_ref.reset_mock()

            self.set_seed()
            t(dp)
            # one call per element of the datapoint: image and mask
            assert mock.call_count == 2
            assert all(
                actual is expected
                for actual, expected in zip([call_args[0][0] for call_args in mock.call_args_list], dp)
            )

            self.set_seed()
            t_ref(*dp_ref)
            assert mock_ref.call_count == 2
            assert all(
                actual is expected
                for actual, expected in zip([call_args[0][0] for call_args in mock_ref.call_args_list], dp_ref)
            )

            for args_kwargs, args_kwargs_ref in zip(mock.call_args_list, mock_ref.call_args_list):
                # prototype resize takes a size list, the reference a bare int
                assert args_kwargs[0][1] == [args_kwargs_ref[0][1]]

    def test_random_resize_train(self, mocker):
        base_size = 520
        min_size = base_size // 2
        max_size = base_size * 2

        randint = torch.randint

        def patched_randint(a, b, *other_args, **kwargs):
            # Only reroute the scalar draw torch.randint(low, high, ()) used by
            # RandomResize; every other call goes to the real torch.randint.
            # NOTE(review): assumes at least one positional size argument when no
            # kwargs are given — other_args[0] would raise IndexError otherwise.
            if kwargs or len(other_args) > 1 or other_args[0] != ():
                return randint(a, b, *other_args, **kwargs)

            return random.randint(a, b)

        # We are patching torch.randint -> random.randint here, because we can't patch the modules that are not imported
        # normally
        t = prototype_transforms.RandomResize(min_size=min_size, max_size=max_size, antialias=True)
        mocker.patch(
            "torchvision.prototype.transforms._geometry.torch.randint",
            new=patched_randint,
        )

        t_ref = seg_transforms.RandomResize(min_size=min_size, max_size=max_size)

        self.check_resize(mocker, t_ref, t)

    def test_random_resize_eval(self, mocker):
        torch.manual_seed(0)
        base_size = 520

        # With min_size == max_size the reference RandomResize is deterministic and
        # equivalent to a plain Resize on the prototype side.
        t = prototype_transforms.Resize(size=base_size, antialias=True)

        t_ref = seg_transforms.RandomResize(min_size=base_size, max_size=base_size)

        self.check_resize(mocker, t_ref, t)
0 commit comments