Skip to content

Commit 16b9a40

Browse files
vmoensVincent Moens
andauthored
Pytest for test_videoapi.py and test_video_reader.py (#4233)
* test_video_reader pytest refactoring * pytest refactoring of test_videoapi.py * test_video_reader pytest refactoring * pytest refactoring of test_videoapi.py * using pytest.approx for test_video_reader.py * using pytest.approx for test_videoapi.py * Fixing minor comments * linting fixes * minor comments Co-authored-by: Vincent Moens <[email protected]>
1 parent a839796 commit 16b9a40

File tree

3 files changed

+80
-225
lines changed

3 files changed

+80
-225
lines changed

test/common_utils.py

Lines changed: 2 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import tempfile
44
import contextlib
55
import unittest
6+
import pytest
67
import argparse
78
import sys
89
import torch
@@ -20,7 +21,7 @@
2021

2122
IS_PY39 = sys.version_info.major == 3 and sys.version_info.minor == 9
2223
PY39_SEGFAULT_SKIP_MSG = "Segmentation fault with Python 3.9, see https://github.com/pytorch/vision/issues/3367"
23-
PY39_SKIP = unittest.skipIf(IS_PY39, PY39_SEGFAULT_SKIP_MSG)
24+
PY39_SKIP = pytest.mark.skipif(IS_PY39, reason=PY39_SEGFAULT_SKIP_MSG)
2425
IN_CIRCLE_CI = os.getenv("CIRCLECI", False) == 'true'
2526
IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None
2627
IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
@@ -83,129 +84,6 @@ def is_iterable(obj):
8384
return False
8485

8586

86-
# adapted from TestCase in torch/test/common_utils to accept non-string
87-
# inputs and set maximum binary size
88-
class TestCase(unittest.TestCase):
89-
precision = 1e-5
90-
91-
def assertEqual(self, x, y, prec=None, message='', allow_inf=False):
92-
"""
93-
This is copied from pytorch/test/common_utils.py's TestCase.assertEqual
94-
"""
95-
if isinstance(prec, str) and message == '':
96-
message = prec
97-
prec = None
98-
if prec is None:
99-
prec = self.precision
100-
101-
if isinstance(x, torch.Tensor) and isinstance(y, Number):
102-
self.assertEqual(x.item(), y, prec=prec, message=message,
103-
allow_inf=allow_inf)
104-
elif isinstance(y, torch.Tensor) and isinstance(x, Number):
105-
self.assertEqual(x, y.item(), prec=prec, message=message,
106-
allow_inf=allow_inf)
107-
elif isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
108-
def assertTensorsEqual(a, b):
109-
super(TestCase, self).assertEqual(a.size(), b.size(), message)
110-
if a.numel() > 0:
111-
if (a.device.type == 'cpu' and (a.dtype == torch.float16 or a.dtype == torch.bfloat16)):
112-
# CPU half and bfloat16 tensors don't have the methods we need below
113-
a = a.to(torch.float32)
114-
b = b.to(a)
115-
116-
if (a.dtype == torch.bool) != (b.dtype == torch.bool):
117-
raise TypeError("Was expecting both tensors to be bool type.")
118-
else:
119-
if a.dtype == torch.bool and b.dtype == torch.bool:
120-
# we want to respect precision but as bool doesn't support substraction,
121-
# boolean tensor has to be converted to int
122-
a = a.to(torch.int)
123-
b = b.to(torch.int)
124-
125-
diff = a - b
126-
if a.is_floating_point():
127-
# check that NaNs are in the same locations
128-
nan_mask = torch.isnan(a)
129-
self.assertTrue(torch.equal(nan_mask, torch.isnan(b)), message)
130-
diff[nan_mask] = 0
131-
# inf check if allow_inf=True
132-
if allow_inf:
133-
inf_mask = torch.isinf(a)
134-
inf_sign = inf_mask.sign()
135-
self.assertTrue(torch.equal(inf_sign, torch.isinf(b).sign()), message)
136-
diff[inf_mask] = 0
137-
# TODO: implement abs on CharTensor (int8)
138-
if diff.is_signed() and diff.dtype != torch.int8:
139-
diff = diff.abs()
140-
max_err = diff.max()
141-
tolerance = prec + prec * abs(a.max())
142-
self.assertLessEqual(max_err, tolerance, message)
143-
super(TestCase, self).assertEqual(x.is_sparse, y.is_sparse, message)
144-
super(TestCase, self).assertEqual(x.is_quantized, y.is_quantized, message)
145-
if x.is_sparse:
146-
x = self.safeCoalesce(x)
147-
y = self.safeCoalesce(y)
148-
assertTensorsEqual(x._indices(), y._indices())
149-
assertTensorsEqual(x._values(), y._values())
150-
elif x.is_quantized and y.is_quantized:
151-
self.assertEqual(x.qscheme(), y.qscheme(), prec=prec,
152-
message=message, allow_inf=allow_inf)
153-
if x.qscheme() == torch.per_tensor_affine:
154-
self.assertEqual(x.q_scale(), y.q_scale(), prec=prec,
155-
message=message, allow_inf=allow_inf)
156-
self.assertEqual(x.q_zero_point(), y.q_zero_point(),
157-
prec=prec, message=message,
158-
allow_inf=allow_inf)
159-
elif x.qscheme() == torch.per_channel_affine:
160-
self.assertEqual(x.q_per_channel_scales(), y.q_per_channel_scales(), prec=prec,
161-
message=message, allow_inf=allow_inf)
162-
self.assertEqual(x.q_per_channel_zero_points(), y.q_per_channel_zero_points(),
163-
prec=prec, message=message,
164-
allow_inf=allow_inf)
165-
self.assertEqual(x.q_per_channel_axis(), y.q_per_channel_axis(),
166-
prec=prec, message=message)
167-
self.assertEqual(x.dtype, y.dtype)
168-
self.assertEqual(x.int_repr().to(torch.int32),
169-
y.int_repr().to(torch.int32), prec=prec,
170-
message=message, allow_inf=allow_inf)
171-
else:
172-
assertTensorsEqual(x, y)
173-
elif isinstance(x, string_classes) and isinstance(y, string_classes):
174-
super(TestCase, self).assertEqual(x, y, message)
175-
elif type(x) == set and type(y) == set:
176-
super(TestCase, self).assertEqual(x, y, message)
177-
elif isinstance(x, dict) and isinstance(y, dict):
178-
if isinstance(x, OrderedDict) and isinstance(y, OrderedDict):
179-
self.assertEqual(x.items(), y.items(), prec=prec,
180-
message=message, allow_inf=allow_inf)
181-
else:
182-
self.assertEqual(set(x.keys()), set(y.keys()), prec=prec,
183-
message=message, allow_inf=allow_inf)
184-
key_list = list(x.keys())
185-
self.assertEqual([x[k] for k in key_list],
186-
[y[k] for k in key_list],
187-
prec=prec, message=message,
188-
allow_inf=allow_inf)
189-
elif is_iterable(x) and is_iterable(y):
190-
super(TestCase, self).assertEqual(len(x), len(y), message)
191-
for x_, y_ in zip(x, y):
192-
self.assertEqual(x_, y_, prec=prec, message=message,
193-
allow_inf=allow_inf)
194-
elif isinstance(x, bool) and isinstance(y, bool):
195-
super(TestCase, self).assertEqual(x, y, message)
196-
elif isinstance(x, Number) and isinstance(y, Number):
197-
inf = float("inf")
198-
if abs(x) == inf or abs(y) == inf:
199-
if allow_inf:
200-
super(TestCase, self).assertEqual(x, y, message)
201-
else:
202-
self.fail("Expected finite numeric values - x={}, y={}".format(x, y))
203-
return
204-
super(TestCase, self).assertLessEqual(abs(x - y), prec, message)
205-
else:
206-
super(TestCase, self).assertEqual(x, y, message)
207-
208-
20987
@contextlib.contextmanager
21088
def freeze_rng_state():
21189
rng_state = torch.get_rng_state()

0 commit comments

Comments
 (0)