Merged
56 commits
d0d8271
make skipXPU work
daisyden May 11, 2024
c791db9
enabled torch-xpu ops in op_db
daisyden May 13, 2024
f5cbd50
clean up code
daisyden May 13, 2024
fa6c8ae
refine the xpu ops switch method
daisyden May 22, 2024
22ad9f8
remove unused comments
daisyden May 22, 2024
d879083
added xpu_op_db.yaml to enable xpu op UT. Refined onlyOn interface to…
daisyden May 26, 2024
d7af50b
remove unused skip_device field in opInfo
daisyden May 26, 2024
d99c6a9
code clean up
daisyden May 29, 2024
0a22770
further cleanup
daisyden May 29, 2024
4705581
instead of do unittest.skip for xpu unsupported op in OpInfo, move it…
daisyden May 31, 2024
1c4f479
merge from stock pytorch
daisyden May 31, 2024
d9513ec
refine function naming
daisyden May 31, 2024
70283c5
skip unsupported xpu test by two means: define the unsupported dtypes…
daisyden Jun 2, 2024
9c3b81d
update skipOps decorators for XPU
daisyden Jun 3, 2024
fd3d73b
update according to comments
daisyden Jun 3, 2024
f53f4d3
pass lintrunner
daisyden Jun 4, 2024
5a2382f
skip mul and where
daisyden Jun 4, 2024
b1c0fff
rollback a change in ops()
daisyden Jun 6, 2024
775db6e
remove unused comments
daisyden Jun 6, 2024
10fd731
fix lint issue
daisyden Jun 7, 2024
efbea1f
Merge branch 'daisyden/baseline' of https://github.com/daisyden/pytor…
daisyden Jun 12, 2024
557aa73
disable bernoulli
daisyden Jun 13, 2024
c6965a1
disable nn.funcitonal.embedding
daisyden Jun 14, 2024
9c6a1cf
rebased
daisyden Jun 24, 2024
5607564
refine format
daisyden Jun 24, 2024
f80a486
lint format
daisyden Jun 24, 2024
4c7ac90
fix an mkldnn blas error message
daisyden Jun 26, 2024
681f20e
Merge remote-tracking branch 'origin/daisyden/baseline' into daisyden…
daisyden Jul 3, 2024
4c3828d
Merge remote-tracking branch 'origin/daisyden/baseline' into daisyden…
daisyden Jul 4, 2024
c1d5299
Merge remote-tracking branch 'origin/daisyden/baseline' into daisyden…
daisyden Jul 5, 2024
0b24979
Merge remote-tracking branch 'origin/daisyden/baseline' into daisyden…
daisyden Jul 9, 2024
5cfb034
Merge remote-tracking branch 'origin/daisyden/baseline' into daisyden…
daisyden Jul 14, 2024
b7e7b0a
Merge remote-tracking branch 'origin/daisyden/baseline' into daisyden…
daisyden Jul 19, 2024
1dd0f27
Merge remote-tracking branch 'origin/daisyden/baseline' into daisyden…
daisyden Jul 22, 2024
a3c3b03
reverted skipOps interface and renamed get_backend_op_dict
daisyden Jul 23, 2024
aa26c03
Merge remote-tracking branch 'origin/daisyden/baseline' into daisyden…
daisyden Sep 29, 2024
0b3d57b
retrigger checks
daisyden Oct 8, 2024
76d2cee
Merge remote-tracking branch 'origin/daisyden/baseline' into daisyden…
huaiyuzh Oct 10, 2024
d70825c
Merge remote-tracking branch 'origin/daisyden/baseline' into daisyden…
daisyden Oct 18, 2024
5afdd8e
Merge remote-tracking branch 'origin/daisyden/baseline' into daisyden…
daisyden Oct 24, 2024
af96524
unified the UT infrastructure api to support diffrent platform, passe…
daisyden Oct 29, 2024
4e2d28c
fix GPU_TYPES import
daisyden Nov 15, 2024
af203f2
fix GPU_TYPES import
daisyden Nov 15, 2024
dccc9f4
Merge remote-tracking branch 'origin/daisyden/baseline' into daisyden…
daisyden Nov 15, 2024
a775908
fix typo
daisyden Nov 18, 2024
9c591c6
remove xpu backend specific code in
daisyden Nov 18, 2024
7be04c0
Merge branch 'daisyden/stock_pt' of https://github.com/daisyden/pytor…
daisyden Nov 18, 2024
af351f0
further remove xpu backend specific skips
daisyden Nov 18, 2024
8381e18
further remove xpu backend specific skips
daisyden Nov 18, 2024
5f5d50f
further remove xpu backend specific skips
daisyden Nov 18, 2024
bd2f0b8
remove yaml dependency
daisyden Nov 18, 2024
6f513cb
fix HAS_GPU import issue
daisyden Nov 18, 2024
42832a0
Merge remote-tracking branch 'origin/daisyden/baseline' into daisyden…
daisyden Nov 18, 2024
d1757a0
Merge remote-tracking branch 'origin/daisyden/baseline' into daisyden…
daisyden Nov 19, 2024
b29b593
remove unused function get_backend_ops as design changes, return torc…
daisyden Nov 21, 2024
b87ca3f
Merge remote-tracking branch 'origin/daisyden/baseline' into daisyden…
daisyden Nov 21, 2024
72 changes: 45 additions & 27 deletions test/test_ops.py
@@ -25,14 +25,17 @@
from torch.testing._internal import composite_compliance, opinfo
from torch.testing._internal.common_cuda import with_tf32_off
from torch.testing._internal.common_device_type import (
any_common_cpu_device_one,
deviceCountAtLeast,
instantiate_device_type_tests,
onlyCPU,
onlyCUDA,
onlyGPU,
onlyNativeDeviceTypesAnd,
OpDTypes,
ops,
skipMeta,
skipXPU,
is_gpu_device,
)
from torch.testing._internal.common_dtype import (
all_types_and_complex_and,
@@ -74,6 +77,9 @@
TEST_WITH_UBSAN,
TestCase,
unMarkDynamoStrictTest,
GPU_TYPE,
GPU_TYPES,
HAS_GPU,
get_gpu_autocast,
)
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map
@@ -114,7 +120,7 @@ def reduction_dtype_filter(op):


# Create a list of operators that are a subset of _ref_test_ops but don't have a
# numpy ref to compare them too, If both CPU and CUDA are compared to numpy
# numpy ref to compare them to. If both CPU and GPU are compared to numpy
# then they do not need to be compared to each other
_ops_and_refs_with_no_numpy_ref = [op for op in ops_and_refs if op.ref is None]

@@ -247,22 +253,22 @@ def tearDownClass(cls):

assert len(filtered_ops) == 0, err_msg

# Validates that each OpInfo works correctly on different CUDA devices
@onlyCUDA
# Validates that each OpInfo works correctly on different GPU devices
@onlyGPU
@deviceCountAtLeast(2)
@ops(op_db, allowed_dtypes=(torch.float32, torch.long))
def test_multiple_devices(self, devices, dtype, op):
for cuda_device_str in devices:
cuda_device = torch.device(cuda_device_str)
for gpu_device_str in devices:
gpu_device = torch.device(gpu_device_str)
# NOTE: only tests on first sample
samples = op.sample_inputs(cuda_device, dtype)
samples = op.sample_inputs(gpu_device, dtype)
sample = first_sample(self, samples)
result = op(sample.input, *sample.args, **sample.kwargs)

if isinstance(result, torch.Tensor):
self.assertTrue(result.device == cuda_device)
self.assertTrue(result.device == gpu_device)
elif is_iterable_of_tensors(result):
self.assertTrue(all(t.device == cuda_device for t in result))
self.assertTrue(all(t.device == gpu_device for t in result))
else:
self.skipTest(
"Skipped! Only supports single tensor or iterable of tensor outputs."
@@ -367,7 +373,7 @@ def test_numpy_ref(self, device, dtype, op):
and op.formatted_name
in ("signal_windows_exponential", "signal_windows_bartlett")
and dtype == torch.float64
and "cuda" in device
and is_gpu_device(device)
or "cpu" in device
): # noqa: E121
raise unittest.SkipTest("XXX: raises tensor-likes are not close.")
@@ -380,10 +386,10 @@
)

# Tests that the cpu and gpu results are consistent
@onlyCUDA
@onlyGPU
@suppress_warnings
@slowTest
@ops(_ops_and_refs_with_no_numpy_ref, dtypes=OpDTypes.any_common_cpu_cuda_one)
@ops(_ops_and_refs_with_no_numpy_ref, dtypes=any_common_cpu_device_one())
def test_compare_cpu(self, device, dtype, op):
def to_cpu(arg):
if isinstance(arg, torch.Tensor):
@@ -394,20 +400,20 @@ def to_cpu(arg):

for sample in samples:
cpu_sample = sample.transform(to_cpu)
cuda_results = op(sample.input, *sample.args, **sample.kwargs)
gpu_results = op(sample.input, *sample.args, **sample.kwargs)
cpu_results = op(cpu_sample.input, *cpu_sample.args, **cpu_sample.kwargs)

# output_process_fn_grad has a very unfortunate name
# We use this function in linalg extensively to postprocess the inputs of functions
# that are not completely well-defined. Think svd and multiplying the singular vectors by -1.
# CPU and CUDA implementations of the SVD can return valid SVDs that are different.
# CPU and GPU implementations of the SVD can return valid SVDs that are different.
# We use this function to compare them.
cuda_results = sample.output_process_fn_grad(cuda_results)
gpu_results = sample.output_process_fn_grad(gpu_results)
cpu_results = cpu_sample.output_process_fn_grad(cpu_results)

# Lower tolerance because we are running this as a `@slowTest`
# Don't want the periodic tests to fail frequently
self.assertEqual(cuda_results, cpu_results, atol=1e-3, rtol=1e-3)
self.assertEqual(gpu_results, cpu_results, atol=1e-3, rtol=1e-3)

# Tests that experimental Python References can propagate shape, dtype,
# and device metadata properly.
@@ -637,7 +643,7 @@ def test_python_ref_torch_fallback(self, device, dtype, op):
self._ref_test_helper(contextlib.nullcontext, device, dtype, op)

@unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
@onlyCUDA
@onlyGPU
@ops(python_ref_db)
@parametrize("executor", ["aten"])
@skipIfTorchInductor("Takes too long for inductor")
@@ -677,6 +683,7 @@ def test_errors(self, device, op):
out = op(si.input, *si.args, **si.kwargs)
self.assertFalse(isinstance(out, type(NotImplemented)))

@skipXPU
@skipMeta
@onlyNativeDeviceTypesAnd(["hpu"])
@ops(
@@ -879,10 +886,13 @@ def _extract_strides(out):
return tuple(t.stride() for t in out)

# Extracts data pointers from a tensor or iterable of tensors into a tuple
# NOTE: only extracts on the CPU and CUDA device types since some
# NOTE: only extracts on the CPU and GPU device types since some
# device types don't have storage
def _extract_data_ptrs(out):
if self.device_type != "cpu" and self.device_type != "cuda":
if (
self.device_type != "cpu"
and self.device_type not in GPU_TYPES
):
return ()

if isinstance(out, torch.Tensor):
@@ -1007,10 +1017,13 @@ def _extract_strides(out):
return tuple(t.stride() for t in out)

# Extracts data pointers from a tensor or iterable of tensors into a tuple
# NOTE: only extracts on the CPU and CUDA device types since some
# NOTE: only extracts on the CPU and GPU device types since some
# device types don't have storage
def _extract_data_ptrs(out):
if self.device_type != "cpu" and self.device_type != "cuda":
if (
self.device_type != "cpu"
and self.device_type not in GPU_TYPES
):
return ()

if isinstance(out, torch.Tensor):
@@ -1085,8 +1098,8 @@ def _case_two_transform(t):
wrong_device = None
if torch.device(device).type != "cpu":
wrong_device = "cpu"
elif torch.cuda.is_available():
wrong_device = "cuda"
elif HAS_GPU:
wrong_device = GPU_TYPE

factory_fn_msg = (
"\n\nNOTE: If your op is a factory function (i.e., it accepts TensorOptions) you should mark its "
@@ -1469,8 +1482,9 @@ def convert_boolean_tensors(x):
self.assertEqual(expect, actual)

# Validates that each OpInfo specifies its forward and backward dtypes
# correctly for CPU and CUDA devices
# correctly for CPU and GPU devices
@skipMeta
@skipXPU
@onlyNativeDeviceTypesAnd(["hpu"])
@ops(ops_and_refs, dtypes=OpDTypes.none)
def test_dtypes(self, device, op):
@@ -1824,6 +1838,7 @@ def test_forward_ad(self, device, dtype, op):
op.get_op(), args, kwargs, op.gradcheck_wrapper, self.assertEqual
)

@skipXPU
@ops(op_db, allowed_dtypes=(torch.float,))
def test_cow_input(self, device, dtype, op):
samples = op.sample_inputs(device, dtype, requires_grad=op.supports_autograd)
@@ -2140,6 +2155,7 @@ def clone_and_perform_view(input, **kwargs):
self.assertEqual(tensor.grad, cloned1_tensor.grad)

@ops(ops_and_refs, allowed_dtypes=(torch.cfloat,))
@skipXPU
def test_conj_view(self, device, dtype, op):
if not op.test_conjugated_samples:
self.skipTest("Operation doesn't support conjugated inputs.")
@@ -2181,6 +2197,7 @@ def test_neg_view(self, device, dtype, op):
)

@ops(ops_and_refs, allowed_dtypes=(torch.cdouble,))
@skipXPU
def test_neg_conj_view(self, device, dtype, op):
if not op.test_neg_view:
self.skipTest("Operation not tested with tensors with negative bit.")
@@ -2520,6 +2537,7 @@ def test_refs_are_in_decomp_table(self, op):
# TODO: investigate/fix
fake_autocast_device_skips["cpu"] = {"linalg.pinv"}
fake_autocast_device_skips["cuda"] = {"linalg.pinv", "pinverse"}
fake_autocast_device_skips["xpu"] = {"linalg.pinv", "pinverse"}


dynamic_output_op_tests = (
@@ -2806,23 +2824,23 @@ def _test_fake_crossref_helper(self, device, dtype, op, context):
except torch._subclasses.fake_tensor.UnsupportedOperatorException:
pass

@onlyCUDA
@onlyGPU
@ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,))
@skipOps(
"TestFakeTensor", "test_fake_crossref_backward_no_amp", fake_backward_xfails
)
def test_fake_crossref_backward_no_amp(self, device, dtype, op):
self._test_fake_crossref_helper(device, dtype, op, contextlib.nullcontext)

@onlyCUDA
@onlyGPU
@ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,))
@skipOps(
"TestFakeTensor",
"test_fake_crossref_backward_amp",
fake_backward_xfails | fake_autocast_backward_xfails,
)
def test_fake_crossref_backward_amp(self, device, dtype, op):
self._test_fake_crossref_helper(device, dtype, op, torch.cuda.amp.autocast)
self._test_fake_crossref_helper(device, dtype, op, get_gpu_autocast())

@ops([op for op in ops_and_refs if op.is_factory_function])
def test_strided_layout(self, device, dtype, op):
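For reference, a minimal sketch of how a device-generic test could use the decorators and helpers introduced in this file (onlyGPU, skipXPU, any_common_cpu_device_one). The test class, test names, and assertions below are illustrative only and are not part of this diff:

import torch
from torch.testing._internal.common_device_type import (
    any_common_cpu_device_one,
    instantiate_device_type_tests,
    onlyGPU,
    ops,
    skipXPU,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_utils import run_tests, TestCase


class TestDeviceGenericExample(TestCase):
    # Runs on whichever GPU backend is present (CUDA or XPU), using one
    # dtype that the op supports on both CPU and that backend.
    @onlyGPU
    @ops(op_db, dtypes=any_common_cpu_device_one())
    def test_cpu_gpu_consistency(self, device, dtype, op):
        sample = next(iter(op.sample_inputs(device, dtype)))
        gpu_result = op(sample.input, *sample.args, **sample.kwargs)
        cpu_sample = sample.transform(
            lambda t: t.cpu() if isinstance(t, torch.Tensor) else t
        )
        cpu_result = op(cpu_sample.input, *cpu_sample.args, **cpu_sample.kwargs)
        self.assertEqual(gpu_result, cpu_result, atol=1e-3, rtol=1e-3)

    # Opts a single test out of the XPU backend while leaving the other
    # device types untouched.
    @skipXPU
    @ops(op_db, allowed_dtypes=(torch.float,))
    def test_not_on_xpu(self, device, dtype, op):
        list(op.sample_inputs(device, dtype))


instantiate_device_type_tests(TestDeviceGenericExample, globals())

if __name__ == "__main__":
    run_tests()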
53 changes: 47 additions & 6 deletions torch/testing/_internal/common_device_type.py
@@ -59,6 +59,8 @@
TEST_WITH_UBSAN,
TEST_XPU,
TestCase,
GPU_TYPE,
GPU_TYPES,
)


@@ -823,7 +825,7 @@ def get_desired_device_type_test_bases(
test_bases = device_type_test_bases.copy()
if allow_mps and TEST_MPS and MPSTestBase not in test_bases:
test_bases.append(MPSTestBase)
if allow_xpu and TEST_XPU and XPUTestBase not in test_bases:
if (allow_xpu or only_for == "xpu") and TEST_XPU and XPUTestBase not in test_bases:
test_bases.append(XPUTestBase)
if TEST_HPU and HPUTestBase not in test_bases:
test_bases.append(HPUTestBase)
@@ -992,6 +994,9 @@ class OpDTypes(Enum):
any_common_cpu_cuda_one = (
6 # Test precisely one supported dtype that is common to both cuda and cpu
)
any_common_cpu_xpu_one = (
7 # Test precisely one supported dtype that is common to both xpu and cpu
)


# Arbitrary order
@@ -1122,7 +1127,15 @@ def _parametrize_test(self, test, generic_cls, device_cls):
}
else:
dtypes = {}

elif self.opinfo_dtypes == OpDTypes.any_common_cpu_xpu_one:
# Tries to pick a dtype that is supported on both CPU and XPU
supported = set(op.dtypes).intersection(op.dtypesIfXPU)
if supported:
dtypes = {
next(dtype for dtype in ANY_DTYPE_ORDER if dtype in supported)
}
else:
dtypes = {}
elif self.opinfo_dtypes == OpDTypes.none:
dtypes = {None}
else:
@@ -1396,14 +1409,19 @@ def efail_fn(slf, *args, **kwargs):


class onlyOn:
def __init__(self, device_type):
self.device_type = device_type
def __init__(self, device_type: Union[str, List[str]]):
self.device_types = []
if isinstance(device_type, str):
self.device_types.append(device_type)
else:
assert isinstance(device_type, list)
self.device_types = device_type

def __call__(self, fn):
@wraps(fn)
def only_fn(slf, *args, **kwargs):
if self.device_type != slf.device_type:
reason = f"Only runs on {self.device_type}"
if slf.device_type not in self.device_types:
reason = f"Only runs on {self.device_types}"
raise unittest.SkipTest(reason)

return fn(slf, *args, **kwargs)
@@ -1625,6 +1643,10 @@ def onlyHPU(fn):
return onlyOn("hpu")(fn)


def onlyGPU(fn):
return onlyOn(GPU_TYPES)(fn)


def onlyPRIVATEUSE1(fn):
device_type = torch._C._get_privateuse1_backend_name()
device_mod = getattr(torch, device_type, None)
@@ -1646,6 +1668,9 @@ def only_fn(self, *args, **kwargs):
return only_fn





def disablecuDNN(fn):
@wraps(fn)
def disable_cudnn(self, *args, **kwargs):
@@ -1943,6 +1968,10 @@ def skipMeta(fn):
return skipMetaIf(True, "test doesn't work with meta tensors")(fn)


def skipXPU(fn):
return skipXPUIf(True, "test doesn't work with XPU tensors")(fn)


def skipXLA(fn):
return skipXLAIf(True, "Marked as skipped for XLA")(fn)

@@ -1971,3 +2000,15 @@ def get_all_device_types() -> List[str]:
and torch.cuda.get_device_capability() >= (8, 0),
"Requires CUDA and Triton",
)


def any_common_cpu_device_one():
return (
OpDTypes.any_common_cpu_xpu_one
if TEST_XPU
else OpDTypes.any_common_cpu_cuda_one
)


def is_gpu_device(device: str) -> bool:
return "cuda" in device or "xpu" in device
2 changes: 1 addition & 1 deletion torch/testing/_internal/common_methods_invocations.py
@@ -24767,4 +24767,4 @@ def skipOps(test_case_name, base_test_name, to_skip):
# This decorator doesn't modify fn in any way
def wrapped(fn):
return fn
return wrapped
return wrapped