Skip to content

Commit 21426dd

Browse files
Port some tests in test_transforms.py to pytest (#3964)
1 parent a0b44d7 commit 21426dd

File tree

1 file changed

+91
-88
lines changed

1 file changed

+91
-88
lines changed

test/test_transforms.py

Lines changed: 91 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -403,94 +403,6 @@ def test_random_crop(self):
403403
with self.assertRaisesRegex(ValueError, r"Required crop size .+ is larger then input image size .+"):
404404
t(img)
405405

406-
@unittest.skipIf(stats is None, 'scipy.stats not available')
407-
def test_random_apply(self):
408-
random_state = random.getstate()
409-
random.seed(42)
410-
random_apply_transform = transforms.RandomApply(
411-
[
412-
transforms.RandomRotation((-45, 45)),
413-
transforms.RandomHorizontalFlip(),
414-
transforms.RandomVerticalFlip(),
415-
], p=0.75
416-
)
417-
img = transforms.ToPILImage()(torch.rand(3, 10, 10))
418-
num_samples = 250
419-
num_applies = 0
420-
for _ in range(num_samples):
421-
out = random_apply_transform(img)
422-
if out != img:
423-
num_applies += 1
424-
425-
p_value = stats.binom_test(num_applies, num_samples, p=0.75)
426-
random.setstate(random_state)
427-
self.assertGreater(p_value, 0.0001)
428-
429-
# Checking if RandomApply can be printed as string
430-
random_apply_transform.__repr__()
431-
432-
@unittest.skipIf(stats is None, 'scipy.stats not available')
433-
def test_random_choice(self):
434-
random_state = random.getstate()
435-
random.seed(42)
436-
random_choice_transform = transforms.RandomChoice(
437-
[
438-
transforms.Resize(15),
439-
transforms.Resize(20),
440-
transforms.CenterCrop(10)
441-
]
442-
)
443-
img = transforms.ToPILImage()(torch.rand(3, 25, 25))
444-
num_samples = 250
445-
num_resize_15 = 0
446-
num_resize_20 = 0
447-
num_crop_10 = 0
448-
for _ in range(num_samples):
449-
out = random_choice_transform(img)
450-
if out.size == (15, 15):
451-
num_resize_15 += 1
452-
elif out.size == (20, 20):
453-
num_resize_20 += 1
454-
elif out.size == (10, 10):
455-
num_crop_10 += 1
456-
457-
p_value = stats.binom_test(num_resize_15, num_samples, p=0.33333)
458-
self.assertGreater(p_value, 0.0001)
459-
p_value = stats.binom_test(num_resize_20, num_samples, p=0.33333)
460-
self.assertGreater(p_value, 0.0001)
461-
p_value = stats.binom_test(num_crop_10, num_samples, p=0.33333)
462-
self.assertGreater(p_value, 0.0001)
463-
464-
random.setstate(random_state)
465-
# Checking if RandomChoice can be printed as string
466-
random_choice_transform.__repr__()
467-
468-
@unittest.skipIf(stats is None, 'scipy.stats not available')
469-
def test_random_order(self):
470-
random_state = random.getstate()
471-
random.seed(42)
472-
random_order_transform = transforms.RandomOrder(
473-
[
474-
transforms.Resize(20),
475-
transforms.CenterCrop(10)
476-
]
477-
)
478-
img = transforms.ToPILImage()(torch.rand(3, 25, 25))
479-
num_samples = 250
480-
num_normal_order = 0
481-
resize_crop_out = transforms.CenterCrop(10)(transforms.Resize(20)(img))
482-
for _ in range(num_samples):
483-
out = random_order_transform(img)
484-
if out == resize_crop_out:
485-
num_normal_order += 1
486-
487-
p_value = stats.binom_test(num_normal_order, num_samples, p=0.5)
488-
random.setstate(random_state)
489-
self.assertGreater(p_value, 0.0001)
490-
491-
# Checking if RandomOrder can be printed as string
492-
random_order_transform.__repr__()
493-
494406
def test_to_tensor(self):
495407
test_channels = [1, 3, 4]
496408
height, width = 4, 4
@@ -1994,5 +1906,96 @@ def test_random_grayscale():
19941906
trans3.__repr__()
19951907

19961908

1909+
@pytest.mark.skipif(stats is None, reason='scipy.stats not available')
1910+
def test_random_apply():
1911+
random_state = random.getstate()
1912+
random.seed(42)
1913+
random_apply_transform = transforms.RandomApply(
1914+
[
1915+
transforms.RandomRotation((-45, 45)),
1916+
transforms.RandomHorizontalFlip(),
1917+
transforms.RandomVerticalFlip(),
1918+
], p=0.75
1919+
)
1920+
img = transforms.ToPILImage()(torch.rand(3, 10, 10))
1921+
num_samples = 250
1922+
num_applies = 0
1923+
for _ in range(num_samples):
1924+
out = random_apply_transform(img)
1925+
if out != img:
1926+
num_applies += 1
1927+
1928+
p_value = stats.binom_test(num_applies, num_samples, p=0.75)
1929+
random.setstate(random_state)
1930+
assert p_value > 0.0001
1931+
1932+
# Checking if RandomApply can be printed as string
1933+
random_apply_transform.__repr__()
1934+
1935+
1936+
@pytest.mark.skipif(stats is None, reason='scipy.stats not available')
1937+
def test_random_choice():
1938+
random_state = random.getstate()
1939+
random.seed(42)
1940+
random_choice_transform = transforms.RandomChoice(
1941+
[
1942+
transforms.Resize(15),
1943+
transforms.Resize(20),
1944+
transforms.CenterCrop(10)
1945+
]
1946+
)
1947+
img = transforms.ToPILImage()(torch.rand(3, 25, 25))
1948+
num_samples = 250
1949+
num_resize_15 = 0
1950+
num_resize_20 = 0
1951+
num_crop_10 = 0
1952+
for _ in range(num_samples):
1953+
out = random_choice_transform(img)
1954+
if out.size == (15, 15):
1955+
num_resize_15 += 1
1956+
elif out.size == (20, 20):
1957+
num_resize_20 += 1
1958+
elif out.size == (10, 10):
1959+
num_crop_10 += 1
1960+
1961+
p_value = stats.binom_test(num_resize_15, num_samples, p=0.33333)
1962+
assert p_value > 0.0001
1963+
p_value = stats.binom_test(num_resize_20, num_samples, p=0.33333)
1964+
assert p_value > 0.0001
1965+
p_value = stats.binom_test(num_crop_10, num_samples, p=0.33333)
1966+
assert p_value > 0.0001
1967+
1968+
random.setstate(random_state)
1969+
# Checking if RandomChoice can be printed as string
1970+
random_choice_transform.__repr__()
1971+
1972+
1973+
@pytest.mark.skipif(stats is None, reason='scipy.stats not available')
1974+
def test_random_order():
1975+
random_state = random.getstate()
1976+
random.seed(42)
1977+
random_order_transform = transforms.RandomOrder(
1978+
[
1979+
transforms.Resize(20),
1980+
transforms.CenterCrop(10)
1981+
]
1982+
)
1983+
img = transforms.ToPILImage()(torch.rand(3, 25, 25))
1984+
num_samples = 250
1985+
num_normal_order = 0
1986+
resize_crop_out = transforms.CenterCrop(10)(transforms.Resize(20)(img))
1987+
for _ in range(num_samples):
1988+
out = random_order_transform(img)
1989+
if out == resize_crop_out:
1990+
num_normal_order += 1
1991+
1992+
p_value = stats.binom_test(num_normal_order, num_samples, p=0.5)
1993+
random.setstate(random_state)
1994+
assert p_value > 0.0001
1995+
1996+
# Checking if RandomOrder can be printed as string
1997+
random_order_transform.__repr__()
1998+
1999+
19972000
if __name__ == '__main__':
19982001
unittest.main()

0 commit comments

Comments
 (0)