Skip to content

Commit 508b289

Browse files
authored
Better logic for ignoring CPU tests on GPU CI machines (#4025)
1 parent cd18188 commit 508b289

File tree

6 files changed

+59
-111
lines changed

6 files changed

+59
-111
lines changed

test/common_utils.py

Lines changed: 2 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -259,58 +259,12 @@ def call_args_to_kwargs_only(call_args, *callable_or_arg_names):
259259

260260
def cpu_and_gpu():
261261
import pytest # noqa
262-
263-
# ignore CPU tests in RE as they're already covered by another contbuild
264-
# also ignore CPU tests in CircleCI machines that have a GPU: these tests
265-
# are run on CPU-only machines already.
266-
if IN_RE_WORKER:
267-
devices = []
268-
else:
269-
if IN_CIRCLE_CI and torch.cuda.is_available():
270-
mark = pytest.mark.skip(reason=CIRCLECI_GPU_NO_CUDA_MSG)
271-
else:
272-
mark = ()
273-
devices = [pytest.param('cpu', marks=mark)]
274-
275-
if torch.cuda.is_available():
276-
cuda_marks = ()
277-
elif IN_FBCODE:
278-
# Don't collect cuda tests on fbcode if the machine doesn't have a GPU
279-
# This avoids skipping the tests. More robust would be to detect if
280-
# we're in sandcastle instead of fbcode?
281-
cuda_marks = pytest.mark.dont_collect()
282-
else:
283-
cuda_marks = pytest.mark.skip(reason=CUDA_NOT_AVAILABLE_MSG)
284-
285-
devices.append(pytest.param('cuda', marks=cuda_marks))
286-
287-
return devices
262+
return ('cpu', pytest.param('cuda', marks=pytest.mark.needs_cuda))
288263

289264

290265
def needs_cuda(test_func):
291266
import pytest # noqa
292-
293-
if IN_FBCODE and not IN_RE_WORKER:
294-
# We don't want to skip in fbcode, so we just don't collect
295-
# TODO: slightly more robust way would be to detect if we're in a sandcastle instance
296-
# so that the test will still be collected (and skipped) in the devvms.
297-
return pytest.mark.dont_collect(test_func)
298-
elif torch.cuda.is_available():
299-
return test_func
300-
else:
301-
return pytest.mark.skip(reason=CUDA_NOT_AVAILABLE_MSG)(test_func)
302-
303-
304-
def cpu_only(test_func):
305-
import pytest # noqa
306-
307-
if IN_RE_WORKER:
308-
# The assumption is that all RE workers have GPUs.
309-
return pytest.mark.dont_collect(test_func)
310-
elif IN_CIRCLE_CI and torch.cuda.is_available():
311-
return pytest.mark.skip(reason=CIRCLECI_GPU_NO_CUDA_MSG)(test_func)
312-
else:
313-
return test_func
267+
return pytest.mark.needs_cuda(test_func)
314268

315269

316270
def _create_data(height=3, width=3, channels=3, device="cpu"):

test/conftest.py

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,52 @@
1+
from common_utils import IN_CIRCLE_CI, CIRCLECI_GPU_NO_CUDA_MSG, IN_FBCODE, IN_RE_WORKER, CUDA_NOT_AVAILABLE_MSG
2+
import torch
3+
import pytest
4+
5+
16
def pytest_configure(config):
27
# register an additional marker (see pytest_collection_modifyitems)
38
config.addinivalue_line(
4-
"markers", "dont_collect: marks a test that should not be collected (avoids skipping it)"
9+
"markers", "needs_cuda: mark for tests that rely on a CUDA device"
510
)
611

712

813
def pytest_collection_modifyitems(items):
914
# This hook is called by pytest after it has collected the tests (google its name!)
10-
# We can ignore some tests as we see fit here. In particular we ignore the tests that
11-
# we have marked with the custom 'dont_collect' mark. This avoids skipping the tests,
12-
# since the internal fb infra doesn't like skipping tests.
13-
to_keep = [item for item in items if item.get_closest_marker('dont_collect') is None]
14-
items[:] = to_keep
15+
# We can ignore some tests as we see fit here, or add marks, such as a skip mark.
16+
17+
out_items = []
18+
for item in items:
19+
# The needs_cuda mark will exist if the test was explicitely decorated with
20+
# the @needs_cuda decorator. It will also exist if it was parametrized with a
21+
# parameter that has the mark: for example if a test is parametrized with
22+
# @pytest.mark.parametrize('device', cpu_and_gpu())
23+
# the "instances" of the tests where device == 'cuda' will have the 'needs_cuda' mark,
24+
# and the ones with device == 'cpu' won't have the mark.
25+
needs_cuda = item.get_closest_marker('needs_cuda') is not None
26+
27+
if needs_cuda and not torch.cuda.is_available():
28+
# In general, we skip cuda tests on machines without a GPU
29+
# There are special cases though, see below
30+
item.add_marker(pytest.mark.skip(reason=CUDA_NOT_AVAILABLE_MSG))
31+
32+
if IN_FBCODE:
33+
# fbcode doesn't like skipping tests, so instead we just don't collect the test
34+
# so that they don't even "exist", hence the continue statements.
35+
if not needs_cuda and IN_RE_WORKER:
36+
# The RE workers are the machines with GPU, we don't want them to run CPU-only tests.
37+
continue
38+
if needs_cuda and not torch.cuda.is_available():
39+
# On the test machines without a GPU, we want to ignore the tests that need cuda.
40+
# TODO: something more robust would be to do that only in a sandcastle instance,
41+
# so that we can still see the test being skipped when testing locally from a devvm
42+
continue
43+
elif IN_CIRCLE_CI:
44+
# Here we're not in fbcode, so we can safely collect and skip tests.
45+
if not needs_cuda and torch.cuda.is_available():
46+
# Similar to what happens in RE workers: we don't need the CircleCI GPU machines
47+
# to run the CPU-only tests.
48+
item.add_marker(pytest.mark.skip(reason=CIRCLECI_GPU_NO_CUDA_MSG))
49+
50+
out_items.append(item)
51+
52+
items[:] = out_items

test/test_image.py

Lines changed: 11 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import torch
1010
from PIL import Image
1111
import torchvision.transforms.functional as F
12-
from common_utils import get_tmp_dir, needs_cuda, cpu_only
12+
from common_utils import get_tmp_dir, needs_cuda
1313
from _assert_utils import assert_equal
1414

1515
from torchvision.io.image import (
@@ -335,7 +335,6 @@ def test_decode_jpeg_cuda_errors():
335335
torch.ops.image.decode_jpeg_cuda(data, ImageReadMode.UNCHANGED.value, 'cpu')
336336

337337

338-
@cpu_only
339338
def test_encode_jpeg_errors():
340339

341340
with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"):
@@ -359,26 +358,13 @@ def test_encode_jpeg_errors():
359358
encode_jpeg(torch.empty((100, 100), dtype=torch.uint8))
360359

361360

362-
def _collect_if(cond):
363-
# TODO: remove this once test_encode_jpeg_windows and test_write_jpeg_windows
364-
# are removed
365-
def _inner(test_func):
366-
if cond:
367-
return test_func
368-
else:
369-
return pytest.mark.dont_collect(test_func)
370-
return _inner
371-
372-
373-
@cpu_only
374-
@_collect_if(cond=IS_WINDOWS)
375361
@pytest.mark.parametrize('img_path', [
376362
pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
377363
for jpeg_path in get_images(ENCODE_JPEG, ".jpg")
378364
])
379-
def test_encode_jpeg_windows(img_path):
365+
def test_encode_jpeg_reference(img_path):
380366
# This test is *wrong*.
381-
# It compares a torchvision-encoded jpeg with a PIL-encoded jpeg, but it
367+
# It compares a torchvision-encoded jpeg with a PIL-encoded jpeg (the reference), but it
382368
# starts encoding the torchvision version from an image that comes from
383369
# decode_jpeg, which can yield different results from pil.decode (see
384370
# test_decode... which uses a high tolerance).
@@ -403,14 +389,12 @@ def test_encode_jpeg_windows(img_path):
403389
assert_equal(jpeg_bytes, pil_bytes)
404390

405391

406-
@cpu_only
407-
@_collect_if(cond=IS_WINDOWS)
408392
@pytest.mark.parametrize('img_path', [
409393
pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
410394
for jpeg_path in get_images(ENCODE_JPEG, ".jpg")
411395
])
412-
def test_write_jpeg_windows(img_path):
413-
# FIXME: Remove this eventually, see test_encode_jpeg_windows
396+
def test_write_jpeg_reference(img_path):
397+
# FIXME: Remove this eventually, see test_encode_jpeg_reference
414398
with get_tmp_dir() as d:
415399
data = read_file(img_path)
416400
img = decode_jpeg(data)
@@ -433,8 +417,9 @@ def test_write_jpeg_windows(img_path):
433417
assert_equal(torch_bytes, pil_bytes)
434418

435419

436-
@cpu_only
437-
@_collect_if(cond=not IS_WINDOWS)
420+
@pytest.mark.skipif(IS_WINDOWS, reason=(
421+
'this test fails on windows because PIL uses libjpeg-turbo on windows'
422+
))
438423
@pytest.mark.parametrize('img_path', [
439424
pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
440425
for jpeg_path in get_images(ENCODE_JPEG, ".jpg")
@@ -455,8 +440,9 @@ def test_encode_jpeg(img_path):
455440
assert_equal(encoded_jpeg_torch, encoded_jpeg_pil)
456441

457442

458-
@cpu_only
459-
@_collect_if(cond=not IS_WINDOWS)
443+
@pytest.mark.skipif(IS_WINDOWS, reason=(
444+
'this test fails on windows because PIL uses libjpeg-turbo on windows'
445+
))
460446
@pytest.mark.parametrize('img_path', [
461447
pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
462448
for jpeg_path in get_images(ENCODE_JPEG, ".jpg")

test/test_models.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
import io
33
import sys
4-
from common_utils import map_nested_tensor_object, freeze_rng_state, set_rng_seed, cpu_and_gpu, needs_cuda, cpu_only
4+
from common_utils import map_nested_tensor_object, freeze_rng_state, set_rng_seed, cpu_and_gpu, needs_cuda
55
from _utils_internal import get_relative_path
66
from collections import OrderedDict
77
import functools
@@ -234,7 +234,6 @@ def _make_sliced_model(model, stop_layer):
234234
return new_model
235235

236236

237-
@cpu_only
238237
@pytest.mark.parametrize('model_name', ['densenet121', 'densenet169', 'densenet201', 'densenet161'])
239238
def test_memory_efficient_densenet(model_name):
240239
input_shape = (1, 3, 300, 300)
@@ -257,7 +256,6 @@ def test_memory_efficient_densenet(model_name):
257256
torch.testing.assert_close(out1, out2, rtol=0.0, atol=1e-5)
258257

259258

260-
@cpu_only
261259
@pytest.mark.parametrize('dilate_layer_2', (True, False))
262260
@pytest.mark.parametrize('dilate_layer_3', (True, False))
263261
@pytest.mark.parametrize('dilate_layer_4', (True, False))
@@ -272,7 +270,6 @@ def test_resnet_dilation(dilate_layer_2, dilate_layer_3, dilate_layer_4):
272270
assert out.shape == (1, 2048, 7 * f, 7 * f)
273271

274272

275-
@cpu_only
276273
def test_mobilenet_v2_residual_setting():
277274
model = models.__dict__["mobilenet_v2"](inverted_residual_setting=[[1, 16, 1, 1], [6, 24, 2, 2]])
278275
model.eval()
@@ -281,7 +278,6 @@ def test_mobilenet_v2_residual_setting():
281278
assert out.shape[-1] == 1000
282279

283280

284-
@cpu_only
285281
@pytest.mark.parametrize('model_name', ["mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small"])
286282
def test_mobilenet_norm_layer(model_name):
287283
model = models.__dict__[model_name]()
@@ -295,7 +291,6 @@ def get_gn(num_channels):
295291
assert any(isinstance(x, nn.GroupNorm) for x in model.modules())
296292

297293

298-
@cpu_only
299294
def test_inception_v3_eval():
300295
# replacement for models.inception_v3(pretrained=True) that does not download weights
301296
kwargs = {}
@@ -311,7 +306,6 @@ def test_inception_v3_eval():
311306
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
312307

313308

314-
@cpu_only
315309
def test_fasterrcnn_double():
316310
model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, pretrained_backbone=False)
317311
model.double()
@@ -327,7 +321,6 @@ def test_fasterrcnn_double():
327321
assert "labels" in out[0]
328322

329323

330-
@cpu_only
331324
def test_googlenet_eval():
332325
# replacement for models.googlenet(pretrained=True) that does not download weights
333326
kwargs = {}
@@ -376,7 +369,6 @@ def checkOut(out):
376369
checkOut(out_cpu)
377370

378371

379-
@cpu_only
380372
def test_generalizedrcnn_transform_repr():
381373

382374
min_size, max_size = 224, 299
@@ -573,7 +565,6 @@ def compute_mean_std(tensor):
573565
pytest.skip(msg)
574566

575567

576-
@cpu_only
577568
@pytest.mark.parametrize('model_name', get_available_detection_models())
578569
def test_detection_model_validation(model_name):
579570
set_rng_seed(0)

0 commit comments

Comments
 (0)