@@ -403,94 +403,6 @@ def test_random_crop(self):
403403 with self .assertRaisesRegex (ValueError , r"Required crop size .+ is larger then input image size .+" ):
404404 t (img )
405405
406- @unittest .skipIf (stats is None , 'scipy.stats not available' )
407- def test_random_apply (self ):
408- random_state = random .getstate ()
409- random .seed (42 )
410- random_apply_transform = transforms .RandomApply (
411- [
412- transforms .RandomRotation ((- 45 , 45 )),
413- transforms .RandomHorizontalFlip (),
414- transforms .RandomVerticalFlip (),
415- ], p = 0.75
416- )
417- img = transforms .ToPILImage ()(torch .rand (3 , 10 , 10 ))
418- num_samples = 250
419- num_applies = 0
420- for _ in range (num_samples ):
421- out = random_apply_transform (img )
422- if out != img :
423- num_applies += 1
424-
425- p_value = stats .binom_test (num_applies , num_samples , p = 0.75 )
426- random .setstate (random_state )
427- self .assertGreater (p_value , 0.0001 )
428-
429- # Checking if RandomApply can be printed as string
430- random_apply_transform .__repr__ ()
431-
432- @unittest .skipIf (stats is None , 'scipy.stats not available' )
433- def test_random_choice (self ):
434- random_state = random .getstate ()
435- random .seed (42 )
436- random_choice_transform = transforms .RandomChoice (
437- [
438- transforms .Resize (15 ),
439- transforms .Resize (20 ),
440- transforms .CenterCrop (10 )
441- ]
442- )
443- img = transforms .ToPILImage ()(torch .rand (3 , 25 , 25 ))
444- num_samples = 250
445- num_resize_15 = 0
446- num_resize_20 = 0
447- num_crop_10 = 0
448- for _ in range (num_samples ):
449- out = random_choice_transform (img )
450- if out .size == (15 , 15 ):
451- num_resize_15 += 1
452- elif out .size == (20 , 20 ):
453- num_resize_20 += 1
454- elif out .size == (10 , 10 ):
455- num_crop_10 += 1
456-
457- p_value = stats .binom_test (num_resize_15 , num_samples , p = 0.33333 )
458- self .assertGreater (p_value , 0.0001 )
459- p_value = stats .binom_test (num_resize_20 , num_samples , p = 0.33333 )
460- self .assertGreater (p_value , 0.0001 )
461- p_value = stats .binom_test (num_crop_10 , num_samples , p = 0.33333 )
462- self .assertGreater (p_value , 0.0001 )
463-
464- random .setstate (random_state )
465- # Checking if RandomChoice can be printed as string
466- random_choice_transform .__repr__ ()
467-
468- @unittest .skipIf (stats is None , 'scipy.stats not available' )
469- def test_random_order (self ):
470- random_state = random .getstate ()
471- random .seed (42 )
472- random_order_transform = transforms .RandomOrder (
473- [
474- transforms .Resize (20 ),
475- transforms .CenterCrop (10 )
476- ]
477- )
478- img = transforms .ToPILImage ()(torch .rand (3 , 25 , 25 ))
479- num_samples = 250
480- num_normal_order = 0
481- resize_crop_out = transforms .CenterCrop (10 )(transforms .Resize (20 )(img ))
482- for _ in range (num_samples ):
483- out = random_order_transform (img )
484- if out == resize_crop_out :
485- num_normal_order += 1
486-
487- p_value = stats .binom_test (num_normal_order , num_samples , p = 0.5 )
488- random .setstate (random_state )
489- self .assertGreater (p_value , 0.0001 )
490-
491- # Checking if RandomOrder can be printed as string
492- random_order_transform .__repr__ ()
493-
494406 def test_to_tensor (self ):
495407 test_channels = [1 , 3 , 4 ]
496408 height , width = 4 , 4
@@ -1994,5 +1906,96 @@ def test_random_grayscale():
19941906 trans3 .__repr__ ()
19951907
19961908
1909+ @pytest .mark .skipif (stats is None , reason = 'scipy.stats not available' )
1910+ def test_random_apply ():
1911+ random_state = random .getstate ()
1912+ random .seed (42 )
1913+ random_apply_transform = transforms .RandomApply (
1914+ [
1915+ transforms .RandomRotation ((- 45 , 45 )),
1916+ transforms .RandomHorizontalFlip (),
1917+ transforms .RandomVerticalFlip (),
1918+ ], p = 0.75
1919+ )
1920+ img = transforms .ToPILImage ()(torch .rand (3 , 10 , 10 ))
1921+ num_samples = 250
1922+ num_applies = 0
1923+ for _ in range (num_samples ):
1924+ out = random_apply_transform (img )
1925+ if out != img :
1926+ num_applies += 1
1927+
1928+ p_value = stats .binom_test (num_applies , num_samples , p = 0.75 )
1929+ random .setstate (random_state )
1930+ assert p_value > 0.0001
1931+
1932+ # Checking if RandomApply can be printed as string
1933+ random_apply_transform .__repr__ ()
1934+
1935+
1936+ @pytest .mark .skipif (stats is None , reason = 'scipy.stats not available' )
1937+ def test_random_choice ():
1938+ random_state = random .getstate ()
1939+ random .seed (42 )
1940+ random_choice_transform = transforms .RandomChoice (
1941+ [
1942+ transforms .Resize (15 ),
1943+ transforms .Resize (20 ),
1944+ transforms .CenterCrop (10 )
1945+ ]
1946+ )
1947+ img = transforms .ToPILImage ()(torch .rand (3 , 25 , 25 ))
1948+ num_samples = 250
1949+ num_resize_15 = 0
1950+ num_resize_20 = 0
1951+ num_crop_10 = 0
1952+ for _ in range (num_samples ):
1953+ out = random_choice_transform (img )
1954+ if out .size == (15 , 15 ):
1955+ num_resize_15 += 1
1956+ elif out .size == (20 , 20 ):
1957+ num_resize_20 += 1
1958+ elif out .size == (10 , 10 ):
1959+ num_crop_10 += 1
1960+
1961+ p_value = stats .binom_test (num_resize_15 , num_samples , p = 0.33333 )
1962+ assert p_value > 0.0001
1963+ p_value = stats .binom_test (num_resize_20 , num_samples , p = 0.33333 )
1964+ assert p_value > 0.0001
1965+ p_value = stats .binom_test (num_crop_10 , num_samples , p = 0.33333 )
1966+ assert p_value > 0.0001
1967+
1968+ random .setstate (random_state )
1969+ # Checking if RandomChoice can be printed as string
1970+ random_choice_transform .__repr__ ()
1971+
1972+
1973+ @pytest .mark .skipif (stats is None , reason = 'scipy.stats not available' )
1974+ def test_random_order ():
1975+ random_state = random .getstate ()
1976+ random .seed (42 )
1977+ random_order_transform = transforms .RandomOrder (
1978+ [
1979+ transforms .Resize (20 ),
1980+ transforms .CenterCrop (10 )
1981+ ]
1982+ )
1983+ img = transforms .ToPILImage ()(torch .rand (3 , 25 , 25 ))
1984+ num_samples = 250
1985+ num_normal_order = 0
1986+ resize_crop_out = transforms .CenterCrop (10 )(transforms .Resize (20 )(img ))
1987+ for _ in range (num_samples ):
1988+ out = random_order_transform (img )
1989+ if out == resize_crop_out :
1990+ num_normal_order += 1
1991+
1992+ p_value = stats .binom_test (num_normal_order , num_samples , p = 0.5 )
1993+ random .setstate (random_state )
1994+ assert p_value > 0.0001
1995+
1996+ # Checking if RandomOrder can be printed as string
1997+ random_order_transform .__repr__ ()
1998+
1999+
19972000if __name__ == '__main__' :
19982001 unittest .main ()
0 commit comments