|
3 | 3 | import tempfile |
4 | 4 | import contextlib |
5 | 5 | import unittest |
| 6 | +import pytest |
6 | 7 | import argparse |
7 | 8 | import sys |
8 | 9 | import torch |
|
20 | 21 |
|
21 | 22 | IS_PY39 = sys.version_info.major == 3 and sys.version_info.minor == 9 |
22 | 23 | PY39_SEGFAULT_SKIP_MSG = "Segmentation fault with Python 3.9, see https://github.com/pytorch/vision/issues/3367" |
23 | | -PY39_SKIP = unittest.skipIf(IS_PY39, PY39_SEGFAULT_SKIP_MSG) |
| 24 | +PY39_SKIP = pytest.mark.skipif(IS_PY39, reason=PY39_SEGFAULT_SKIP_MSG) |
24 | 25 | IN_CIRCLE_CI = os.getenv("CIRCLECI", False) == 'true' |
25 | 26 | IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None |
26 | 27 | IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1" |
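For reference, a minimal usage sketch of the new marker (test_foo and its body are hypothetical, not part of this commit): pytest.mark.skipif is applied as a decorator, just as unittest.skipIf was, so existing call sites keep the same shape:

    @PY39_SKIP
    def test_foo():
        assert torch.as_tensor([1]).sum().item() == 1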
@@ -83,129 +84,6 @@ def is_iterable(obj): |
83 | 84 | return False |
84 | 85 |
|
85 | 86 |
|
86 | | -# adapted from TestCase in torch/test/common_utils to accept non-string |
87 | | -# inputs and set maximum binary size |
88 | | -class TestCase(unittest.TestCase): |
89 | | - precision = 1e-5 |
90 | | - |
91 | | - def assertEqual(self, x, y, prec=None, message='', allow_inf=False): |
92 | | - """ |
93 | | - This is copied from pytorch/test/common_utils.py's TestCase.assertEqual |
94 | | - """ |
95 | | - if isinstance(prec, str) and message == '': |
96 | | - message = prec |
97 | | - prec = None |
98 | | - if prec is None: |
99 | | - prec = self.precision |
100 | | - |
101 | | - if isinstance(x, torch.Tensor) and isinstance(y, Number): |
102 | | - self.assertEqual(x.item(), y, prec=prec, message=message, |
103 | | - allow_inf=allow_inf) |
104 | | - elif isinstance(y, torch.Tensor) and isinstance(x, Number): |
105 | | - self.assertEqual(x, y.item(), prec=prec, message=message, |
106 | | - allow_inf=allow_inf) |
107 | | - elif isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor): |
108 | | - def assertTensorsEqual(a, b): |
109 | | - super(TestCase, self).assertEqual(a.size(), b.size(), message) |
110 | | - if a.numel() > 0: |
111 | | - if (a.device.type == 'cpu' and (a.dtype == torch.float16 or a.dtype == torch.bfloat16)): |
112 | | - # CPU half and bfloat16 tensors don't have the methods we need below |
113 | | - a = a.to(torch.float32) |
114 | | - b = b.to(a) |
115 | | - |
116 | | - if (a.dtype == torch.bool) != (b.dtype == torch.bool): |
117 | | - raise TypeError("Was expecting both tensors to be bool type.") |
118 | | - else: |
119 | | - if a.dtype == torch.bool and b.dtype == torch.bool: |
120 | | - # we want to respect precision but as bool doesn't support subtraction, |
121 | | - # boolean tensor has to be converted to int |
122 | | - a = a.to(torch.int) |
123 | | - b = b.to(torch.int) |
124 | | - |
125 | | - diff = a - b |
126 | | - if a.is_floating_point(): |
127 | | - # check that NaNs are in the same locations |
128 | | - nan_mask = torch.isnan(a) |
129 | | - self.assertTrue(torch.equal(nan_mask, torch.isnan(b)), message) |
130 | | - diff[nan_mask] = 0 |
131 | | - # inf check if allow_inf=True |
132 | | - if allow_inf: |
133 | | - inf_mask = torch.isinf(a) |
134 | | - inf_sign = inf_mask.sign() |
135 | | - self.assertTrue(torch.equal(inf_sign, torch.isinf(b).sign()), message) |
136 | | - diff[inf_mask] = 0 |
137 | | - # TODO: implement abs on CharTensor (int8) |
138 | | - if diff.is_signed() and diff.dtype != torch.int8: |
139 | | - diff = diff.abs() |
140 | | - max_err = diff.max() |
141 | | - tolerance = prec + prec * abs(a.max()) |
142 | | - self.assertLessEqual(max_err, tolerance, message) |
143 | | - super(TestCase, self).assertEqual(x.is_sparse, y.is_sparse, message) |
144 | | - super(TestCase, self).assertEqual(x.is_quantized, y.is_quantized, message) |
145 | | - if x.is_sparse: |
146 | | - x = self.safeCoalesce(x) |
147 | | - y = self.safeCoalesce(y) |
148 | | - assertTensorsEqual(x._indices(), y._indices()) |
149 | | - assertTensorsEqual(x._values(), y._values()) |
150 | | - elif x.is_quantized and y.is_quantized: |
151 | | - self.assertEqual(x.qscheme(), y.qscheme(), prec=prec, |
152 | | - message=message, allow_inf=allow_inf) |
153 | | - if x.qscheme() == torch.per_tensor_affine: |
154 | | - self.assertEqual(x.q_scale(), y.q_scale(), prec=prec, |
155 | | - message=message, allow_inf=allow_inf) |
156 | | - self.assertEqual(x.q_zero_point(), y.q_zero_point(), |
157 | | - prec=prec, message=message, |
158 | | - allow_inf=allow_inf) |
159 | | - elif x.qscheme() == torch.per_channel_affine: |
160 | | - self.assertEqual(x.q_per_channel_scales(), y.q_per_channel_scales(), prec=prec, |
161 | | - message=message, allow_inf=allow_inf) |
162 | | - self.assertEqual(x.q_per_channel_zero_points(), y.q_per_channel_zero_points(), |
163 | | - prec=prec, message=message, |
164 | | - allow_inf=allow_inf) |
165 | | - self.assertEqual(x.q_per_channel_axis(), y.q_per_channel_axis(), |
166 | | - prec=prec, message=message) |
167 | | - self.assertEqual(x.dtype, y.dtype) |
168 | | - self.assertEqual(x.int_repr().to(torch.int32), |
169 | | - y.int_repr().to(torch.int32), prec=prec, |
170 | | - message=message, allow_inf=allow_inf) |
171 | | - else: |
172 | | - assertTensorsEqual(x, y) |
173 | | - elif isinstance(x, string_classes) and isinstance(y, string_classes): |
174 | | - super(TestCase, self).assertEqual(x, y, message) |
175 | | - elif type(x) == set and type(y) == set: |
176 | | - super(TestCase, self).assertEqual(x, y, message) |
177 | | - elif isinstance(x, dict) and isinstance(y, dict): |
178 | | - if isinstance(x, OrderedDict) and isinstance(y, OrderedDict): |
179 | | - self.assertEqual(x.items(), y.items(), prec=prec, |
180 | | - message=message, allow_inf=allow_inf) |
181 | | - else: |
182 | | - self.assertEqual(set(x.keys()), set(y.keys()), prec=prec, |
183 | | - message=message, allow_inf=allow_inf) |
184 | | - key_list = list(x.keys()) |
185 | | - self.assertEqual([x[k] for k in key_list], |
186 | | - [y[k] for k in key_list], |
187 | | - prec=prec, message=message, |
188 | | - allow_inf=allow_inf) |
189 | | - elif is_iterable(x) and is_iterable(y): |
190 | | - super(TestCase, self).assertEqual(len(x), len(y), message) |
191 | | - for x_, y_ in zip(x, y): |
192 | | - self.assertEqual(x_, y_, prec=prec, message=message, |
193 | | - allow_inf=allow_inf) |
194 | | - elif isinstance(x, bool) and isinstance(y, bool): |
195 | | - super(TestCase, self).assertEqual(x, y, message) |
196 | | - elif isinstance(x, Number) and isinstance(y, Number): |
197 | | - inf = float("inf") |
198 | | - if abs(x) == inf or abs(y) == inf: |
199 | | - if allow_inf: |
200 | | - super(TestCase, self).assertEqual(x, y, message) |
201 | | - else: |
202 | | - self.fail("Expected finite numeric values - x={}, y={}".format(x, y)) |
203 | | - return |
204 | | - super(TestCase, self).assertLessEqual(abs(x - y), prec, message) |
205 | | - else: |
206 | | - super(TestCase, self).assertEqual(x, y, message) |
207 | | - |
208 | | - |
209 | 87 | @contextlib.contextmanager |
210 | 88 | def freeze_rng_state(): |
211 | 89 | rng_state = torch.get_rng_state() |
|
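The TestCase.assertEqual removed above compared tensors up to a precision threshold; in a pytest-based suite an equivalent check can be expressed with torch.testing.assert_close. A minimal sketch, assuming the installed torch provides torch.testing.assert_close (torch >= 1.9); the tolerances are illustrative, not values taken from this file:

    import torch

    # Checks shape, dtype and values; raises AssertionError if the elementwise
    # difference exceeds atol + rtol * abs(expected).
    torch.testing.assert_close(
        torch.tensor([1.0, 2.0 + 1e-7]),
        torch.tensor([1.0, 2.0]),
        rtol=0,
        atol=1e-5,
    )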