1- import dataclasses
21import functools
32import itertools
43import math
5- from collections import defaultdict
6- from typing import Any , Callable , Dict , Iterable , List , Optional , Sequence , Tuple
74
85import numpy as np
96import pytest
107import torch .testing
118import torchvision .ops
129import torchvision .prototype .transforms .functional as F
13-
14- from _pytest .mark .structures import MarkDecorator
1510from common_utils import cycle_over
1611from datasets_utils import combinations_grid
1712from prototype_common_utils import (
1813 ArgsKwargs ,
14+ InfoBase ,
1915 make_bounding_box_loaders ,
2016 make_image_loader ,
2117 make_image_loaders ,
2218 make_mask_loaders ,
2319 make_video_loaders ,
20+ mark_framework_limitation ,
21+ TestMark ,
2422 VALID_EXTRA_DIMS ,
2523)
2624from torchvision .prototype import features
2927__all__ = ["KernelInfo" , "KERNEL_INFOS" ]
3028
3129
32- TestID = Tuple [Optional [str ], str ]
33-
34-
35- @dataclasses .dataclass
36- class TestMark :
37- test_id : TestID
38- mark : MarkDecorator
39- condition : Callable [[ArgsKwargs ], bool ] = lambda args_kwargs : True
40-
41-
42- @dataclasses .dataclass
43- class KernelInfo :
44- kernel : Callable
45- # Most common tests use these inputs to check the kernel. As such it should cover all valid code paths, but should
46- # not include extensive parameter combinations to keep to overall test count moderate.
47- sample_inputs_fn : Callable [[], Iterable [ArgsKwargs ]]
48- # Defaults to `kernel.__name__`. Should be set if the function is exposed under a different name
49- # TODO: This can probably be removed after roll-out since we shouldn't have any aliasing then
50- kernel_name : str = dataclasses .field (default = None )
51- # This function should mirror the kernel. It should have the same signature as the `kernel` and as such also take
52- # tensors as inputs. Any conversion into another object type, e.g. PIL images or numpy arrays, should happen
53- # inside the function. It should return a tensor or to be more precise an object that can be compared to a
54- # tensor by `assert_close`. If omitted, no reference test will be performed.
55- reference_fn : Optional [Callable ] = None
56- # These inputs are only used for the reference tests and thus can be comprehensive with regard to the parameter
57- # values to be tested. If not specified, `sample_inputs_fn` will be used.
58- reference_inputs_fn : Optional [Callable [[], Iterable [ArgsKwargs ]]] = None
59- # Additional parameters, e.g. `rtol=1e-3`, passed to `assert_close`.
60- closeness_kwargs : Dict [str , Any ] = dataclasses .field (default_factory = dict )
61- test_marks : Sequence [TestMark ] = dataclasses .field (default_factory = list )
62- _test_marks_map : Dict [str , List [TestMark ]] = dataclasses .field (default = None , init = False )
63-
64- def __post_init__ (self ):
65- self .kernel_name = self .kernel_name or self .kernel .__name__
66- self .reference_inputs_fn = self .reference_inputs_fn or self .sample_inputs_fn
67-
68- test_marks_map = defaultdict (list )
69- for test_mark in self .test_marks :
70- test_marks_map [test_mark .test_id ].append (test_mark )
71- self ._test_marks_map = dict (test_marks_map )
72-
73- def get_marks (self , test_id , args_kwargs ):
74- return [
75- test_mark .mark for test_mark in self ._test_marks_map .get (test_id , []) if test_mark .condition (args_kwargs )
76- ]
30+ class KernelInfo (InfoBase ):
31+ def __init__ (
32+ self ,
33+ kernel ,
34+ * ,
35+ # Defaults to `kernel.__name__`. Should be set if the function is exposed under a different name
36+ # TODO: This can probably be removed after roll-out since we shouldn't have any aliasing then
37+ kernel_name = None ,
38+ # Most common tests use these inputs to check the kernel. As such it should cover all valid code paths, but
39+ # should not include extensive parameter combinations to keep to overall test count moderate.
40+ sample_inputs_fn ,
41+ # This function should mirror the kernel. It should have the same signature as the `kernel` and as such also
42+ # take tensors as inputs. Any conversion into another object type, e.g. PIL images or numpy arrays, should
43+ # happen inside the function. It should return a tensor or to be more precise an object that can be compared to
44+ # a tensor by `assert_close`. If omitted, no reference test will be performed.
45+ reference_fn = None ,
46+ # These inputs are only used for the reference tests and thus can be comprehensive with regard to the parameter
47+ # values to be tested. If not specified, `sample_inputs_fn` will be used.
48+ reference_inputs_fn = None ,
49+ # See InfoBase
50+ test_marks = None ,
51+ # See InfoBase
52+ closeness_kwargs = None ,
53+ ):
54+ super ().__init__ (id = kernel_name or kernel .__name__ , test_marks = test_marks , closeness_kwargs = closeness_kwargs )
55+ self .kernel = kernel
56+ self .sample_inputs_fn = sample_inputs_fn
57+ self .reference_fn = reference_fn
58+ self .reference_inputs_fn = reference_inputs_fn
7759
7860
7961DEFAULT_IMAGE_CLOSENESS_KWARGS = dict (
@@ -97,16 +79,6 @@ def wrapper(image_tensor, *other_args, **kwargs):
9779 return wrapper
9880
9981
100- def mark_framework_limitation (test_id , reason ):
101- # The purpose of this function is to have a single entry point for skip marks that are only there, because the test
102- # framework cannot handle the kernel in general or a specific parameter combination.
103- # As development progresses, we can change the `mark.skip` to `mark.xfail` from time to time to see if the skip is
104- # still justified.
105- # We don't want to use `mark.xfail` all the time, because that actually runs the test until an error happens. Thus,
106- # we are wasting CI resources for no reason for most of the time.
107- return TestMark (test_id , pytest .mark .skip (reason = reason ))
108-
109-
11082def xfail_jit_python_scalar_arg (name , * , reason = None ):
11183 reason = reason or f"Python scalar int or float for `{ name } ` is not supported when scripting"
11284 return TestMark (
0 commit comments