Skip to content

Commit 1202b33

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] Add warnings checks for v2 namespaces and deprecated files (#7288)
Summary: Co-authored-by: Philip Meier <[email protected]> Reviewed By: vmoens Differential Revision: D44416588 fbshipit-source-id: 745e518fcca11abee87de3ced91571206c2e13fb
1 parent 2fb8e37 commit 1202b33

File tree

4 files changed

+106
-2
lines changed

4 files changed

+106
-2
lines changed

test/common_utils.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
import pathlib
99
import random
1010
import shutil
11+
import sys
1112
import tempfile
1213
from collections import defaultdict
14+
from subprocess import CalledProcessError, check_output, STDOUT
1315
from typing import Callable, Sequence, Tuple, Union
1416

1517
import numpy as np
@@ -838,3 +840,22 @@ def get_closeness_kwargs(self, test_id, *, dtype, device):
838840
if isinstance(device, torch.device):
839841
device = device.type
840842
return self.closeness_kwargs.get((test_id, dtype, device), dict())
843+
844+
845+
def assert_run_python_script(source_code):
846+
"""Utility to check assertions in an independent Python subprocess.
847+
The script provided in the source code should return 0 and not print
848+
anything on stderr or stdout. Taken from scikit-learn test utils.
849+
source_code (str): The Python source code to execute.
850+
"""
851+
with tempfile.NamedTemporaryFile(mode="wb") as f:
852+
f.write(source_code.encode())
853+
f.flush()
854+
855+
cmd = [sys.executable, f.name]
856+
try:
857+
out = check_output(cmd, stderr=STDOUT)
858+
except CalledProcessError as e:
859+
raise RuntimeError(f"script errored with output:\n{e.output.decode()}")
860+
if out != b"":
861+
raise AssertionError(out.decode())

test/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
import pytest
55
import torch
66
import torchvision
7-
from common_utils import CUDA_NOT_AVAILABLE_MSG, IN_FBCODE, IN_OSS_CI, IN_RE_WORKER, OSS_CI_GPU_NO_CUDA_MSG
87

98

109
torchvision.disable_beta_transforms_warning()
1110

11+
from common_utils import CUDA_NOT_AVAILABLE_MSG, IN_FBCODE, IN_OSS_CI, IN_RE_WORKER, OSS_CI_GPU_NO_CUDA_MSG
12+
1213

1314
def pytest_configure(config):
1415
# register an additional marker (see pytest_collection_modifyitems)

test/test_transforms.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import random
44
import re
5+
import textwrap
56
import warnings
67
from functools import partial
78

@@ -24,7 +25,7 @@
2425
except ImportError:
2526
stats = None
2627

27-
from common_utils import assert_equal, cycle_over, float_dtypes, int_dtypes
28+
from common_utils import assert_equal, assert_run_python_script, cycle_over, float_dtypes, int_dtypes
2829

2930

3031
GRACE_HOPPER = get_file_path_2(
@@ -2266,5 +2267,35 @@ def test_random_grayscale_with_grayscale_input():
22662267
torch.testing.assert_close(F.pil_to_tensor(output_pil), image_tensor)
22672268

22682269

2270+
# TODO: remove in 0.17 when we can delete functional_pil.py and functional_tensor.py
2271+
@pytest.mark.parametrize(
2272+
"import_statement",
2273+
(
2274+
"from torchvision.transforms import functional_pil",
2275+
"from torchvision.transforms import functional_tensor",
2276+
"from torchvision.transforms.functional_tensor import resize",
2277+
"from torchvision.transforms.functional_pil import resize",
2278+
),
2279+
)
2280+
@pytest.mark.parametrize("from_private", (True, False))
2281+
def test_functional_deprecation_warning(import_statement, from_private):
2282+
if from_private:
2283+
import_statement = import_statement.replace("functional", "_functional")
2284+
source = f"""
2285+
import warnings
2286+
2287+
with warnings.catch_warnings():
2288+
warnings.simplefilter("error")
2289+
{import_statement}
2290+
"""
2291+
else:
2292+
source = f"""
2293+
import pytest
2294+
with pytest.warns(UserWarning, match="removed in 0.17"):
2295+
{import_statement}
2296+
"""
2297+
assert_run_python_script(textwrap.dedent(source))
2298+
2299+
22692300
if __name__ == "__main__":
22702301
pytest.main([__file__])

test/test_transforms_v2.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pathlib
33
import random
44
import re
5+
import textwrap
56
import warnings
67
from collections import defaultdict
78

@@ -14,6 +15,7 @@
1415

1516
from common_utils import (
1617
assert_equal,
18+
assert_run_python_script,
1719
cpu_and_gpu,
1820
make_bounding_box,
1921
make_bounding_boxes,
@@ -2045,3 +2047,52 @@ def test_sanitize_bounding_boxes_errors():
20452047
)
20462048
different_sizes = {"bbox": bad_bbox, "labels": torch.arange(bad_bbox.shape[0])}
20472049
transforms.SanitizeBoundingBoxes()(different_sizes)
2050+
2051+
2052+
@pytest.mark.parametrize(
2053+
"import_statement",
2054+
(
2055+
"from torchvision.transforms import v2",
2056+
"import torchvision.transforms.v2",
2057+
"from torchvision.transforms.v2 import Resize",
2058+
"import torchvision.transforms.v2.functional",
2059+
"from torchvision.transforms.v2.functional import resize",
2060+
"from torchvision import datapoints",
2061+
"from torchvision.datapoints import Image",
2062+
"from torchvision.datasets import wrap_dataset_for_transforms_v2",
2063+
),
2064+
)
2065+
@pytest.mark.parametrize("call_disable_warning", (True, False))
2066+
def test_warnings_v2_namespaces(import_statement, call_disable_warning):
2067+
if call_disable_warning:
2068+
source = f"""
2069+
import warnings
2070+
import torchvision
2071+
torchvision.disable_beta_transforms_warning()
2072+
with warnings.catch_warnings():
2073+
warnings.simplefilter("error")
2074+
{import_statement}
2075+
"""
2076+
else:
2077+
source = f"""
2078+
import pytest
2079+
with pytest.warns(UserWarning, match="v2 namespaces are still Beta"):
2080+
{import_statement}
2081+
"""
2082+
assert_run_python_script(textwrap.dedent(source))
2083+
2084+
2085+
def test_no_warnings_v1_namespace():
2086+
source = """
2087+
import warnings
2088+
with warnings.catch_warnings():
2089+
warnings.simplefilter("error")
2090+
import torchvision.transforms
2091+
from torchvision import transforms
2092+
import torchvision.transforms.functional
2093+
from torchvision.transforms import Resize
2094+
from torchvision.transforms.functional import resize
2095+
from torchvision import datasets
2096+
from torchvision.datasets import ImageNet
2097+
"""
2098+
assert_run_python_script(textwrap.dedent(source))

0 commit comments

Comments
 (0)