6 | 6 | import numpy as np |
7 | 7 |
8 | 8 | import torch |
9 | | -from functools import lru_cache |
| 9 | +from functools import lru_cache, partial |
10 | 10 | from torch import Tensor |
11 | 11 | from torch.autograd import gradcheck |
12 | 12 | from torch.nn.modules.utils import _pair |
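The new parametrization relies on a `cpu_and_gpu()` helper from the test utilities. A minimal, hypothetical sketch of such a helper is shown below; the actual torchvision implementation may differ (for example by attaching a `needs_cuda`-style mark to the CUDA entry instead of a skipif):

```python
# Hypothetical sketch of a cpu_and_gpu() device-parametrization helper;
# the real helper lives in torchvision's test utilities and may differ.
import pytest
import torch

def cpu_and_gpu():
    # Always test on CPU; run the CUDA variant only when a GPU is present,
    # otherwise mark it as skipped so the parametrized test still collects.
    cuda = pytest.param(
        "cuda",
        marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable"),
    )
    return ("cpu", cuda)
```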
@@ -64,12 +64,13 @@ def func(z): |
64 | 64 | gradcheck(func, (x,)) |
65 | 65 | gradcheck(script_func, (x,)) |
66 | 66 |
67 | | - @needs_cuda |
| 67 | + @pytest.mark.parametrize('device', cpu_and_gpu()) |
68 | 68 | @pytest.mark.parametrize('x_dtype', (torch.float, torch.half)) |
69 | 69 | @pytest.mark.parametrize('rois_dtype', (torch.float, torch.half)) |
70 | | - def test_autocast(self, x_dtype, rois_dtype): |
71 | | - with torch.cuda.amp.autocast(): |
72 | | - self.test_forward(torch.device("cuda"), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype) |
| 70 | + def test_autocast(self, device, x_dtype, rois_dtype): |
| 71 | + cm = torch.cpu.amp.autocast if device == 'cpu' else torch.cuda.amp.autocast |
| 72 | + with cm(): |
| 73 | + self.test_forward(torch.device(device), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype) |
73 | 74 |
74 | 75 | def _helper_boxes_shape(self, func): |
75 | 76 | # test boxes as Tensor[N, 5] |
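The pattern above selects the autocast context manager by device, since `torch.cuda.amp.autocast` only affects CUDA ops. A minimal sketch of that selection outside the test class, assuming PyTorch >= 1.10 where `torch.cpu.amp.autocast` is available:

```python
# Minimal sketch of the device-dependent autocast selection used in the
# diff, assuming PyTorch >= 1.10 where torch.cpu.amp.autocast exists.
import torch

def run_under_autocast(device: str):
    cm = torch.cpu.amp.autocast if device == "cpu" else torch.cuda.amp.autocast
    with cm():
        a = torch.ones(4, 4, device=device)
        b = torch.ones(4, 4, device=device)
        # Under autocast, matmuls run in bfloat16 on CPU and float16 on CUDA.
        return (a @ b).dtype

print(run_under_autocast("cpu"))  # torch.bfloat16 when CPU autocast applies
```

In newer PyTorch versions the device-agnostic `torch.autocast(device_type=device)` would express the same thing without the conditional.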
@@ -269,13 +270,14 @@ def test_forward(self, device, contiguous, aligned, x_dtype=None, rois_dtype=Non |
269 | 270 | super().test_forward(device=device, contiguous=contiguous, x_dtype=x_dtype, rois_dtype=rois_dtype, |
270 | 271 | aligned=aligned) |
271 | 272 |
272 | | - @needs_cuda |
| 273 | + @pytest.mark.parametrize('device', cpu_and_gpu()) |
273 | 274 | @pytest.mark.parametrize('aligned', (True, False)) |
274 | 275 | @pytest.mark.parametrize('x_dtype', (torch.float, torch.half)) |
275 | 276 | @pytest.mark.parametrize('rois_dtype', (torch.float, torch.half)) |
276 | | - def test_autocast(self, aligned, x_dtype, rois_dtype): |
277 | | - with torch.cuda.amp.autocast(): |
278 | | - self.test_forward(torch.device("cuda"), contiguous=False, aligned=aligned, x_dtype=x_dtype, |
| 277 | + def test_autocast(self, device, aligned, x_dtype, rois_dtype): |
| 278 | + cm = torch.cpu.amp.autocast if device == 'cpu' else torch.cuda.amp.autocast |
| 279 | + with cm(): |
| 280 | + self.test_forward(torch.device(device), contiguous=False, aligned=aligned, x_dtype=x_dtype, |
279 | 281 | rois_dtype=rois_dtype) |
280 | 282 |
281 | 283 | def _make_rois(self, img_size, num_imgs, dtype, num_rois=1000): |
@@ -514,12 +516,14 @@ def test_nms_cuda(self, iou, dtype=torch.float64): |
514 | 516 | is_eq = torch.allclose(scores[r_cpu], scores[r_cuda.cpu()], rtol=tol, atol=tol) |
515 | 517 | assert is_eq, err_msg.format(iou) |
516 | 518 |
517 | | - @needs_cuda |
| 519 | + @pytest.mark.parametrize('device', cpu_and_gpu()) |
518 | 520 | @pytest.mark.parametrize("iou", (.2, .5, .8)) |
519 | 521 | @pytest.mark.parametrize("dtype", (torch.float, torch.half)) |
520 | | - def test_autocast(self, iou, dtype): |
521 | | - with torch.cuda.amp.autocast(): |
522 | | - self.test_nms_cuda(iou=iou, dtype=dtype) |
| 522 | + def test_autocast(self, device, iou, dtype): |
| 523 | + test_fn = self.test_nms_ref if device == 'cpu' else partial(self.test_nms_cuda, dtype=dtype) |
| 524 | + cm = torch.cpu.amp.autocast if device == 'cpu' else torch.cuda.amp.autocast |
| 525 | + with cm(): |
| 526 | + test_fn(iou=iou) |
523 | 527 |
524 | 528 | @needs_cuda |
525 | 529 | def test_nms_cuda_float16(self): |
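The NMS variant additionally switches which test it delegates to: the CPU path reuses the reference NMS test, while the CUDA path pre-binds `dtype` with `functools.partial`. A small self-contained illustration of that dispatch pattern, using placeholder functions rather than the actual test methods:

```python
# Illustration of the functools.partial dispatch used in the NMS autocast
# test; the functions below are placeholders, not torchvision's tests.
from functools import partial

import torch

def nms_reference(iou):
    print(f"reference NMS check, iou={iou}")

def nms_cuda(iou, dtype):
    print(f"CUDA NMS check, iou={iou}, dtype={dtype}")

def dispatch(device, iou, dtype):
    # The CPU path ignores dtype (the reference test fixes its own); the CUDA
    # path pre-binds dtype so both callables share the same iou=... signature.
    test_fn = nms_reference if device == "cpu" else partial(nms_cuda, dtype=dtype)
    test_fn(iou=iou)

dispatch("cpu", iou=0.5, dtype=torch.half)
dispatch("cuda", iou=0.5, dtype=torch.half)
```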
@@ -767,12 +771,13 @@ def test_compare_cpu_cuda_grads(self, contiguous): |
767 | 771 | res_grads = init_weight.grad.to("cpu") |
768 | 772 | torch.testing.assert_close(true_cpu_grads, res_grads) |
769 | 773 |
770 | | - @needs_cuda |
| 774 | + @pytest.mark.parametrize('device', cpu_and_gpu()) |
771 | 775 | @pytest.mark.parametrize('batch_sz', (0, 33)) |
772 | 776 | @pytest.mark.parametrize('dtype', (torch.float, torch.half)) |
773 | | - def test_autocast(self, batch_sz, dtype): |
774 | | - with torch.cuda.amp.autocast(): |
775 | | - self.test_forward(torch.device("cuda"), contiguous=False, batch_sz=batch_sz, dtype=dtype) |
| 777 | + def test_autocast(self, device, batch_sz, dtype): |
| 778 | + cm = torch.cpu.amp.autocast if device == 'cpu' else torch.cuda.amp.autocast |
| 779 | + with cm(): |
| 780 | + self.test_forward(torch.device(device), contiguous=False, batch_sz=batch_sz, dtype=dtype) |
776 | 781 |
777 | 782 | def test_forward_scriptability(self): |
778 | 783 | # Non-regression test for https://github.com/pytorch/vision/issues/4078 |