@@ -705,281 +705,6 @@ def test_to_tensor(self):
705705 assert_equal (prototype_transform (image_numpy ), legacy_transform (image_numpy ))
706706
707707
708- class TestAATransforms :
709- @pytest .mark .parametrize (
710- "inpt" ,
711- [
712- torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 ),
713- PIL .Image .new ("RGB" , (256 , 256 ), 123 ),
714- tv_tensors .Image (torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 )),
715- ],
716- )
717- @pytest .mark .parametrize (
718- "interpolation" ,
719- [
720- v2_transforms .InterpolationMode .NEAREST ,
721- v2_transforms .InterpolationMode .BILINEAR ,
722- PIL .Image .NEAREST ,
723- ],
724- )
725- def test_randaug (self , inpt , interpolation , mocker ):
726- t_ref = legacy_transforms .RandAugment (interpolation = interpolation , num_ops = 1 )
727- t = v2_transforms .RandAugment (interpolation = interpolation , num_ops = 1 )
728-
729- le = len (t ._AUGMENTATION_SPACE )
730- keys = list (t ._AUGMENTATION_SPACE .keys ())
731- randint_values = []
732- for i in range (le ):
733- # Stable API, op_index random call
734- randint_values .append (i )
735- # Stable API, if signed there is another random call
736- if t ._AUGMENTATION_SPACE [keys [i ]][1 ]:
737- randint_values .append (0 )
738- # New API, _get_random_item
739- randint_values .append (i )
740- randint_values = iter (randint_values )
741-
742- mocker .patch ("torch.randint" , side_effect = lambda * arg , ** kwargs : torch .tensor (next (randint_values )))
743- mocker .patch ("torch.rand" , return_value = 1.0 )
744-
745- for i in range (le ):
746- expected_output = t_ref (inpt )
747- output = t (inpt )
748-
749- assert_close (expected_output , output , atol = 1 , rtol = 0.1 )
750-
751- @pytest .mark .parametrize (
752- "interpolation" ,
753- [
754- v2_transforms .InterpolationMode .NEAREST ,
755- v2_transforms .InterpolationMode .BILINEAR ,
756- ],
757- )
758- @pytest .mark .parametrize ("fill" , [None , 85 , (10 , - 10 , 10 ), 0.7 , [0.0 , 0.0 , 0.0 ], [1 ], 1 ])
759- def test_randaug_jit (self , interpolation , fill ):
760- inpt = torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 )
761- t_ref = legacy_transforms .RandAugment (interpolation = interpolation , num_ops = 1 , fill = fill )
762- t = v2_transforms .RandAugment (interpolation = interpolation , num_ops = 1 , fill = fill )
763-
764- tt_ref = torch .jit .script (t_ref )
765- tt = torch .jit .script (t )
766-
767- torch .manual_seed (12 )
768- expected_output = tt_ref (inpt )
769-
770- torch .manual_seed (12 )
771- scripted_output = tt (inpt )
772-
773- assert_equal (scripted_output , expected_output )
774-
775- @pytest .mark .parametrize (
776- "inpt" ,
777- [
778- torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 ),
779- PIL .Image .new ("RGB" , (256 , 256 ), 123 ),
780- tv_tensors .Image (torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 )),
781- ],
782- )
783- @pytest .mark .parametrize (
784- "interpolation" ,
785- [
786- v2_transforms .InterpolationMode .NEAREST ,
787- v2_transforms .InterpolationMode .BILINEAR ,
788- PIL .Image .NEAREST ,
789- ],
790- )
791- def test_trivial_aug (self , inpt , interpolation , mocker ):
792- t_ref = legacy_transforms .TrivialAugmentWide (interpolation = interpolation )
793- t = v2_transforms .TrivialAugmentWide (interpolation = interpolation )
794-
795- le = len (t ._AUGMENTATION_SPACE )
796- keys = list (t ._AUGMENTATION_SPACE .keys ())
797- randint_values = []
798- for i in range (le ):
799- # Stable API, op_index random call
800- randint_values .append (i )
801- key = keys [i ]
802- # Stable API, random magnitude
803- aug_op = t ._AUGMENTATION_SPACE [key ]
804- magnitudes = aug_op [0 ](2 , 0 , 0 )
805- if magnitudes is not None :
806- randint_values .append (5 )
807- # Stable API, if signed there is another random call
808- if aug_op [1 ]:
809- randint_values .append (0 )
810- # New API, _get_random_item
811- randint_values .append (i )
812- # New API, random magnitude
813- if magnitudes is not None :
814- randint_values .append (5 )
815-
816- randint_values = iter (randint_values )
817-
818- mocker .patch ("torch.randint" , side_effect = lambda * arg , ** kwargs : torch .tensor (next (randint_values )))
819- mocker .patch ("torch.rand" , return_value = 1.0 )
820-
821- for _ in range (le ):
822- expected_output = t_ref (inpt )
823- output = t (inpt )
824-
825- assert_close (expected_output , output , atol = 1 , rtol = 0.1 )
826-
827- @pytest .mark .parametrize (
828- "interpolation" ,
829- [
830- v2_transforms .InterpolationMode .NEAREST ,
831- v2_transforms .InterpolationMode .BILINEAR ,
832- ],
833- )
834- @pytest .mark .parametrize ("fill" , [None , 85 , (10 , - 10 , 10 ), 0.7 , [0.0 , 0.0 , 0.0 ], [1 ], 1 ])
835- def test_trivial_aug_jit (self , interpolation , fill ):
836- inpt = torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 )
837- t_ref = legacy_transforms .TrivialAugmentWide (interpolation = interpolation , fill = fill )
838- t = v2_transforms .TrivialAugmentWide (interpolation = interpolation , fill = fill )
839-
840- tt_ref = torch .jit .script (t_ref )
841- tt = torch .jit .script (t )
842-
843- torch .manual_seed (12 )
844- expected_output = tt_ref (inpt )
845-
846- torch .manual_seed (12 )
847- scripted_output = tt (inpt )
848-
849- assert_equal (scripted_output , expected_output )
850-
851- @pytest .mark .parametrize (
852- "inpt" ,
853- [
854- torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 ),
855- PIL .Image .new ("RGB" , (256 , 256 ), 123 ),
856- tv_tensors .Image (torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 )),
857- ],
858- )
859- @pytest .mark .parametrize (
860- "interpolation" ,
861- [
862- v2_transforms .InterpolationMode .NEAREST ,
863- v2_transforms .InterpolationMode .BILINEAR ,
864- PIL .Image .NEAREST ,
865- ],
866- )
867- def test_augmix (self , inpt , interpolation , mocker ):
868- t_ref = legacy_transforms .AugMix (interpolation = interpolation , mixture_width = 1 , chain_depth = 1 )
869- t_ref ._sample_dirichlet = lambda t : t .softmax (dim = - 1 )
870- t = v2_transforms .AugMix (interpolation = interpolation , mixture_width = 1 , chain_depth = 1 )
871- t ._sample_dirichlet = lambda t : t .softmax (dim = - 1 )
872-
873- le = len (t ._AUGMENTATION_SPACE )
874- keys = list (t ._AUGMENTATION_SPACE .keys ())
875- randint_values = []
876- for i in range (le ):
877- # Stable API, op_index random call
878- randint_values .append (i )
879- key = keys [i ]
880- # Stable API, random magnitude
881- aug_op = t ._AUGMENTATION_SPACE [key ]
882- magnitudes = aug_op [0 ](2 , 0 , 0 )
883- if magnitudes is not None :
884- randint_values .append (5 )
885- # Stable API, if signed there is another random call
886- if aug_op [1 ]:
887- randint_values .append (0 )
888- # New API, _get_random_item
889- randint_values .append (i )
890- # New API, random magnitude
891- if magnitudes is not None :
892- randint_values .append (5 )
893-
894- randint_values = iter (randint_values )
895-
896- mocker .patch ("torch.randint" , side_effect = lambda * arg , ** kwargs : torch .tensor (next (randint_values )))
897- mocker .patch ("torch.rand" , return_value = 1.0 )
898-
899- expected_output = t_ref (inpt )
900- output = t (inpt )
901-
902- assert_equal (expected_output , output )
903-
904- @pytest .mark .parametrize (
905- "interpolation" ,
906- [
907- v2_transforms .InterpolationMode .NEAREST ,
908- v2_transforms .InterpolationMode .BILINEAR ,
909- ],
910- )
911- @pytest .mark .parametrize ("fill" , [None , 85 , (10 , - 10 , 10 ), 0.7 , [0.0 , 0.0 , 0.0 ], [1 ], 1 ])
912- def test_augmix_jit (self , interpolation , fill ):
913- inpt = torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 )
914-
915- t_ref = legacy_transforms .AugMix (interpolation = interpolation , mixture_width = 1 , chain_depth = 1 , fill = fill )
916- t = v2_transforms .AugMix (interpolation = interpolation , mixture_width = 1 , chain_depth = 1 , fill = fill )
917-
918- tt_ref = torch .jit .script (t_ref )
919- tt = torch .jit .script (t )
920-
921- torch .manual_seed (12 )
922- expected_output = tt_ref (inpt )
923-
924- torch .manual_seed (12 )
925- scripted_output = tt (inpt )
926-
927- assert_equal (scripted_output , expected_output )
928-
929- @pytest .mark .parametrize (
930- "inpt" ,
931- [
932- torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 ),
933- PIL .Image .new ("RGB" , (256 , 256 ), 123 ),
934- tv_tensors .Image (torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 )),
935- ],
936- )
937- @pytest .mark .parametrize (
938- "interpolation" ,
939- [
940- v2_transforms .InterpolationMode .NEAREST ,
941- v2_transforms .InterpolationMode .BILINEAR ,
942- PIL .Image .NEAREST ,
943- ],
944- )
945- def test_aa (self , inpt , interpolation ):
946- aa_policy = legacy_transforms .AutoAugmentPolicy ("imagenet" )
947- t_ref = legacy_transforms .AutoAugment (aa_policy , interpolation = interpolation )
948- t = v2_transforms .AutoAugment (aa_policy , interpolation = interpolation )
949-
950- torch .manual_seed (12 )
951- expected_output = t_ref (inpt )
952-
953- torch .manual_seed (12 )
954- output = t (inpt )
955-
956- assert_equal (expected_output , output )
957-
958- @pytest .mark .parametrize (
959- "interpolation" ,
960- [
961- v2_transforms .InterpolationMode .NEAREST ,
962- v2_transforms .InterpolationMode .BILINEAR ,
963- ],
964- )
965- def test_aa_jit (self , interpolation ):
966- inpt = torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 )
967- aa_policy = legacy_transforms .AutoAugmentPolicy ("imagenet" )
968- t_ref = legacy_transforms .AutoAugment (aa_policy , interpolation = interpolation )
969- t = v2_transforms .AutoAugment (aa_policy , interpolation = interpolation )
970-
971- tt_ref = torch .jit .script (t_ref )
972- tt = torch .jit .script (t )
973-
974- torch .manual_seed (12 )
975- expected_output = tt_ref (inpt )
976-
977- torch .manual_seed (12 )
978- scripted_output = tt (inpt )
979-
980- assert_equal (scripted_output , expected_output )
981-
982-
983708def import_transforms_from_references (reference ):
984709 HERE = Path (__file__ ).parent
985710 PROJECT_ROOT = HERE .parent
0 commit comments