Skip to content

Commit f4ab3e7

Browse files
authored
[FBcode->GH] Better logic for ignoring CPU tests on GPU CI machines (#4025) (#4062)
1 parent 686ff59 commit f4ab3e7

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

test/conftest.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ def pytest_configure(config):
88
config.addinivalue_line(
99
"markers", "needs_cuda: mark for tests that rely on a CUDA device"
1010
)
11+
config.addinivalue_line(
12+
"markers", "dont_collect: mark for tests that should not be collected"
13+
)
1114

1215

1316
def pytest_collection_modifyitems(items):
@@ -47,6 +50,10 @@ def pytest_collection_modifyitems(items):
4750
# to run the CPU-only tests.
4851
item.add_marker(pytest.mark.skip(reason=CIRCLECI_GPU_NO_CUDA_MSG))
4952

53+
if item.get_closest_marker('dont_collect') is not None:
54+
# currently, this is only used for some tests we're sure we dont want to run on fbcode
55+
continue
56+
5057
out_items.append(item)
5158

5259
items[:] = out_items

test/test_image.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,18 @@ def test_encode_jpeg_errors():
358358
encode_jpeg(torch.empty((100, 100), dtype=torch.uint8))
359359

360360

361+
def _collect_if(cond):
362+
# TODO: remove this once test_encode_jpeg_reference and test_write_jpeg_reference
363+
# are removed
364+
def _inner(test_func):
365+
if cond:
366+
return test_func
367+
else:
368+
return pytest.mark.dont_collect(test_func)
369+
return _inner
370+
371+
372+
@_collect_if(cond=IS_WINDOWS)
361373
@pytest.mark.parametrize('img_path', [
362374
pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
363375
for jpeg_path in get_images(ENCODE_JPEG, ".jpg")
@@ -389,6 +401,7 @@ def test_encode_jpeg_reference(img_path):
389401
assert_equal(jpeg_bytes, pil_bytes)
390402

391403

404+
@_collect_if(cond=IS_WINDOWS)
392405
@pytest.mark.parametrize('img_path', [
393406
pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
394407
for jpeg_path in get_images(ENCODE_JPEG, ".jpg")

0 commit comments

Comments
 (0)