diff --git a/test/test_ops.py b/test/test_ops.py
index 1aee4b306787d..f2df6d8945b4e 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -25,14 +25,17 @@
 from torch.testing._internal import composite_compliance, opinfo
 from torch.testing._internal.common_cuda import with_tf32_off
 from torch.testing._internal.common_device_type import (
+    any_common_cpu_device_one,
     deviceCountAtLeast,
     instantiate_device_type_tests,
     onlyCPU,
-    onlyCUDA,
+    onlyGPU,
     onlyNativeDeviceTypesAnd,
     OpDTypes,
     ops,
     skipMeta,
+    skipXPU,
+    is_gpu_device,
 )
 from torch.testing._internal.common_dtype import (
     all_types_and_complex_and,
@@ -74,6 +77,10 @@
     TEST_WITH_UBSAN,
     TestCase,
     unMarkDynamoStrictTest,
+    GPU_TYPE,
+    GPU_TYPES,
+    HAS_GPU,
+    get_gpu_autocast,
 )
 from torch.utils._python_dispatch import TorchDispatchMode
 from torch.utils._pytree import tree_map
@@ -114,7 +121,7 @@ def reduction_dtype_filter(op):
 
 
 # Create a list of operators that are a subset of _ref_test_ops but don't have a
-# numpy ref to compare them too, If both CPU and CUDA are compared to numpy
+# numpy ref to compare them to. If both CPU and GPU are compared to numpy
 # then they do not need to be compared to each other
 _ops_and_refs_with_no_numpy_ref = [op for op in ops_and_refs if op.ref is None]
 
@@ -247,22 +254,22 @@ def tearDownClass(cls):
 
         assert len(filtered_ops) == 0, err_msg
 
-    # Validates that each OpInfo works correctly on different CUDA devices
-    @onlyCUDA
+    # Validates that each OpInfo works correctly on different GPU devices
+    @onlyGPU
     @deviceCountAtLeast(2)
     @ops(op_db, allowed_dtypes=(torch.float32, torch.long))
     def test_multiple_devices(self, devices, dtype, op):
-        for cuda_device_str in devices:
-            cuda_device = torch.device(cuda_device_str)
+        for gpu_device_str in devices:
+            gpu_device = torch.device(gpu_device_str)
             # NOTE: only tests on first sample
-            samples = op.sample_inputs(cuda_device, dtype)
+            samples = op.sample_inputs(gpu_device, dtype)
             sample = first_sample(self, samples)
 
             result = op(sample.input, *sample.args, **sample.kwargs)
             if isinstance(result, torch.Tensor):
-                self.assertTrue(result.device == cuda_device)
+                self.assertTrue(result.device == gpu_device)
             elif is_iterable_of_tensors(result):
-                self.assertTrue(all(t.device == cuda_device for t in result))
+                self.assertTrue(all(t.device == gpu_device for t in result))
             else:
                 self.skipTest(
                     "Skipped! Only supports single tensor or iterable of tensor outputs."
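
For orientation, here is a minimal, hedged sketch (not part of the diff) of how the onlyGPU decorator used above is meant to work in a device-generic test. TestGpuSmoke and test_add_on_gpu are made-up names; onlyGPU, instantiate_device_type_tests, TestCase, and run_tests are the real helpers this patch builds on:

import torch
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyGPU,  # added by this patch; accepts both "cuda" and "xpu"
)
from torch.testing._internal.common_utils import TestCase, run_tests


class TestGpuSmoke(TestCase):
    # Hypothetical test: skipped on CPU, run on whichever GPU backend exists.
    @onlyGPU
    def test_add_on_gpu(self, device):
        x = torch.ones(4, device=device)
        self.assertEqual((x + x).device.type, torch.device(device).type)


# Generates TestGpuSmokeCUDA / TestGpuSmokeXPU variants as appropriate.
instantiate_device_type_tests(TestGpuSmoke, globals())

if __name__ == "__main__":
    run_tests()
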
@@ -367,7 +374,7 @@ def test_numpy_ref(self, device, dtype, op):
             and op.formatted_name
             in ("signal_windows_exponential", "signal_windows_bartlett")
             and dtype == torch.float64
-            and "cuda" in device
+            and is_gpu_device(device)
             or "cpu" in device
         ):  # noqa: E121
             raise unittest.SkipTest("XXX: raises tensor-likes are not close.")
@@ -380,10 +387,10 @@ def test_numpy_ref(self, device, dtype, op):
         )
 
     # Tests that the cpu and gpu results are consistent
-    @onlyCUDA
+    @onlyGPU
     @suppress_warnings
     @slowTest
-    @ops(_ops_and_refs_with_no_numpy_ref, dtypes=OpDTypes.any_common_cpu_cuda_one)
+    @ops(_ops_and_refs_with_no_numpy_ref, dtypes=any_common_cpu_device_one())
     def test_compare_cpu(self, device, dtype, op):
         def to_cpu(arg):
             if isinstance(arg, torch.Tensor):
@@ -394,20 +401,20 @@ def to_cpu(arg):
 
         for sample in samples:
             cpu_sample = sample.transform(to_cpu)
-            cuda_results = op(sample.input, *sample.args, **sample.kwargs)
+            gpu_results = op(sample.input, *sample.args, **sample.kwargs)
             cpu_results = op(cpu_sample.input, *cpu_sample.args, **cpu_sample.kwargs)
 
             # output_process_fn_grad has a very unfortunate name
             # We use this function in linalg extensively to postprocess the inputs of functions
             # that are not completely well-defined. Think svd and muliplying the singular vectors by -1.
-            # CPU and CUDA implementations of the SVD can return valid SVDs that are different.
+            # CPU and GPU implementations of the SVD can return valid SVDs that are different.
             # We use this function to compare them.
-            cuda_results = sample.output_process_fn_grad(cuda_results)
+            gpu_results = sample.output_process_fn_grad(gpu_results)
             cpu_results = cpu_sample.output_process_fn_grad(cpu_results)
 
             # Lower tolerance because we are running this as a `@slowTest`
             # Don't want the periodic tests to fail frequently
-            self.assertEqual(cuda_results, cpu_results, atol=1e-3, rtol=1e-3)
+            self.assertEqual(gpu_results, cpu_results, atol=1e-3, rtol=1e-3)
 
     # Tests that experimental Python References can propagate shape, dtype,
     # and device metadata properly.
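
To make the CPU-vs-GPU comparison pattern of test_compare_cpu above concrete, here is a hedged, standalone sketch; compare_against_cpu is a made-up helper, and the loose tolerances mirror the atol/rtol used in the test:

import torch

def compare_against_cpu(fn, *tensors, device="cuda", atol=1e-3, rtol=1e-3):
    # Run fn on the accelerator and on CPU copies of the same inputs, then
    # compare with loose tolerances (assumes fn returns a single tensor).
    gpu_out = fn(*(t.to(device) for t in tensors))
    cpu_out = fn(*(t.cpu() for t in tensors))
    torch.testing.assert_close(gpu_out.cpu(), cpu_out, atol=atol, rtol=rtol)

# Example usage, assuming a CUDA (or XPU) device is present:
# compare_against_cpu(torch.sin, torch.randn(8), device="cuda")
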
@@ -637,7 +644,7 @@ def test_python_ref_torch_fallback(self, device, dtype, op):
         self._ref_test_helper(contextlib.nullcontext, device, dtype, op)
 
     @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
-    @onlyCUDA
+    @onlyGPU
     @ops(python_ref_db)
     @parametrize("executor", ["aten"])
     @skipIfTorchInductor("Takes too long for inductor")
@@ -677,6 +684,7 @@ def test_errors(self, device, op):
                 out = op(si.input, *si.args, **si.kwargs)
                 self.assertFalse(isinstance(out, type(NotImplemented)))
 
+    @skipXPU
     @skipMeta
     @onlyNativeDeviceTypesAnd(["hpu"])
     @ops(
@@ -879,10 +887,13 @@ def _extract_strides(out):
             return tuple(t.stride() for t in out)
 
         # Extracts data pointers from a tensor or iterable of tensors into a tuple
-        # NOTE: only extracts on the CPU and CUDA device types since some
+        # NOTE: only extracts on the CPU and GPU device types since some
         # device types don't have storage
         def _extract_data_ptrs(out):
-            if self.device_type != "cpu" and self.device_type != "cuda":
+            if (
+                self.device_type != "cpu"
+                and self.device_type not in GPU_TYPES
+            ):
                 return ()
 
             if isinstance(out, torch.Tensor):
@@ -1007,10 +1018,13 @@ def _extract_strides(out):
             return tuple(t.stride() for t in out)
 
         # Extracts data pointers from a tensor or iterable of tensors into a tuple
-        # NOTE: only extracts on the CPU and CUDA device types since some
+        # NOTE: only extracts on the CPU and GPU device types since some
        # device types don't have storage
         def _extract_data_ptrs(out):
-            if self.device_type != "cpu" and self.device_type != "cuda":
+            if (
+                self.device_type != "cpu"
+                and self.device_type not in GPU_TYPES
+            ):
                 return ()
 
             if isinstance(out, torch.Tensor):
@@ -1085,8 +1099,8 @@ def _case_two_transform(t):
         wrong_device = None
         if torch.device(device).type != "cpu":
             wrong_device = "cpu"
-        elif torch.cuda.is_available():
-            wrong_device = "cuda"
+        elif HAS_GPU:
+            wrong_device = GPU_TYPE
 
         factory_fn_msg = (
             "\n\nNOTE: If your op is a factory function (i.e., it accepts TensorOptions) you should mark its "
@@ -1469,8 +1483,9 @@ def convert_boolean_tensors(x):
         self.assertEqual(expect, actual)
 
     # Validates that each OpInfo specifies its forward and backward dtypes
-    # correctly for CPU and CUDA devices
+    # correctly for CPU and GPU devices
     @skipMeta
+    @skipXPU
     @onlyNativeDeviceTypesAnd(["hpu"])
     @ops(ops_and_refs, dtypes=OpDTypes.none)
     def test_dtypes(self, device, op):
@@ -1824,6 +1839,7 @@ def test_forward_ad(self, device, dtype, op):
             op.get_op(), args, kwargs, op.gradcheck_wrapper, self.assertEqual
         )
 
+    @skipXPU
     @ops(op_db, allowed_dtypes=(torch.float,))
     def test_cow_input(self, device, dtype, op):
         samples = op.sample_inputs(device, dtype, requires_grad=op.supports_autograd)
@@ -2140,6 +2156,7 @@ def clone_and_perform_view(input, **kwargs):
             self.assertEqual(tensor.grad, cloned1_tensor.grad)
 
     @ops(ops_and_refs, allowed_dtypes=(torch.cfloat,))
+    @skipXPU
     def test_conj_view(self, device, dtype, op):
         if not op.test_conjugated_samples:
             self.skipTest("Operation doesn't support conjugated inputs.")
@@ -2181,6 +2198,7 @@ def test_neg_view(self, device, dtype, op):
         )
 
     @ops(ops_and_refs, allowed_dtypes=(torch.cdouble,))
+    @skipXPU
    def test_neg_conj_view(self, device, dtype, op):
         if not op.test_neg_view:
             self.skipTest("Operation not tested with tensors with negative bit.")
@@ -2520,6 +2538,7 @@ def test_refs_are_in_decomp_table(self, op):
 # TODO: investigate/fix
 fake_autocast_device_skips["cpu"] = {"linalg.pinv"}
 fake_autocast_device_skips["cuda"] = {"linalg.pinv", "pinverse"}
+fake_autocast_device_skips["xpu"] = {"linalg.pinv", "pinverse"}
 
 
 dynamic_output_op_tests = (
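
A hedged sketch of the device gating used by the _extract_data_ptrs helpers above: data pointers are only inspected on CPU and on the native GPU backends, since device types such as "meta" expose no storage. GPU_TYPES mirrors the constant added to common_utils.py; has_inspectable_storage is a made-up name:

GPU_TYPES = ["cuda", "xpu"]

def has_inspectable_storage(device_type: str) -> bool:
    # Only CPU and the native GPU backends have real storage behind tensors.
    return device_type == "cpu" or device_type in GPU_TYPES

assert has_inspectable_storage("cpu")
assert has_inspectable_storage("xpu")
assert not has_inspectable_storage("meta")
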
@@ -2806,7 +2825,7 @@ def _test_fake_crossref_helper(self, device, dtype, op, context):
             except torch._subclasses.fake_tensor.UnsupportedOperatorException:
                 pass
 
-    @onlyCUDA
+    @onlyGPU
     @ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,))
     @skipOps(
         "TestFakeTensor", "test_fake_crossref_backward_no_amp", fake_backward_xfails
@@ -2814,7 +2833,7 @@ def _test_fake_crossref_helper(self, device, dtype, op, context):
     def test_fake_crossref_backward_no_amp(self, device, dtype, op):
         self._test_fake_crossref_helper(device, dtype, op, contextlib.nullcontext)
 
-    @onlyCUDA
+    @onlyGPU
     @ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,))
     @skipOps(
         "TestFakeTensor",
@@ -2822,7 +2841,7 @@ def test_fake_crossref_backward_no_amp(self, device, dtype, op):
         fake_backward_xfails | fake_autocast_backward_xfails,
     )
     def test_fake_crossref_backward_amp(self, device, dtype, op):
-        self._test_fake_crossref_helper(device, dtype, op, torch.cuda.amp.autocast)
+        self._test_fake_crossref_helper(device, dtype, op, get_gpu_autocast())
 
     @ops([op for op in ops_and_refs if op.is_factory_function])
     def test_strided_layout(self, device, dtype, op):
diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py
index 7309dd51843a0..09cbeed49fd7e 100644
--- a/torch/testing/_internal/common_device_type.py
+++ b/torch/testing/_internal/common_device_type.py
@@ -59,6 +59,8 @@
     TEST_WITH_UBSAN,
     TEST_XPU,
     TestCase,
+    GPU_TYPE,
+    GPU_TYPES,
 )
 
 
@@ -823,7 +825,7 @@ def get_desired_device_type_test_bases(
     test_bases = device_type_test_bases.copy()
     if allow_mps and TEST_MPS and MPSTestBase not in test_bases:
         test_bases.append(MPSTestBase)
-    if allow_xpu and TEST_XPU and XPUTestBase not in test_bases:
+    if (allow_xpu or only_for == "xpu") and TEST_XPU and XPUTestBase not in test_bases:
         test_bases.append(XPUTestBase)
     if TEST_HPU and HPUTestBase not in test_bases:
         test_bases.append(HPUTestBase)
@@ -992,6 +994,9 @@ class OpDTypes(Enum):
     any_common_cpu_cuda_one = (
         6  # Test precisely one supported dtype that is common to both cuda and cpu
     )
+    any_common_cpu_xpu_one = (
+        7  # Test precisely one supported dtype that is common to both xpu and cpu
+    )
 
 
 # Arbitrary order
@@ -1122,7 +1127,15 @@ def _parametrize_test(self, test, generic_cls, device_cls):
                     }
                 else:
                     dtypes = {}
-
+            elif self.opinfo_dtypes == OpDTypes.any_common_cpu_xpu_one:
+                # Tries to pick a dtype that supports both CPU and XPU
+                supported = set(op.dtypes).intersection(op.dtypesIfXPU)
+                if supported:
+                    dtypes = {
+                        next(dtype for dtype in ANY_DTYPE_ORDER if dtype in supported)
+                    }
+                else:
+                    dtypes = {}
             elif self.opinfo_dtypes == OpDTypes.none:
                 dtypes = {None}
             else:
@@ -1396,14 +1409,19 @@ def efail_fn(slf, *args, **kwargs):
 
 
 class onlyOn:
-    def __init__(self, device_type):
-        self.device_type = device_type
+    def __init__(self, device_type: Union[str, List[str]]):
+        self.device_types = []
+        if isinstance(device_type, str):
+            self.device_types.append(device_type)
+        else:
+            assert isinstance(device_type, list)
+            self.device_types = device_type
 
     def __call__(self, fn):
         @wraps(fn)
         def only_fn(slf, *args, **kwargs):
-            if self.device_type != slf.device_type:
-                reason = f"Only runs on {self.device_type}"
+            if slf.device_type not in self.device_types:
+                reason = f"Only runs on {self.device_types}"
                 raise unittest.SkipTest(reason)
 
             return fn(slf, *args, **kwargs)
@@ -1625,6 +1643,10 @@ def onlyHPU(fn):
     return onlyOn("hpu")(fn)
 
 
+def onlyGPU(fn):
+    return onlyOn(GPU_TYPES)(fn)
+
+
 def onlyPRIVATEUSE1(fn):
     device_type = torch._C._get_privateuse1_backend_name()
     device_mod = getattr(torch, device_type, None)
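
As a rough illustration of the updated onlyOn semantics (a single device type or a list of them, with onlyGPU being onlyOn(["cuda", "xpu"])), here is a simplified, hedged stand-in; only_on is a made-up function, and the real class above additionally normalizes str vs. list input in __init__:

import unittest
from functools import wraps

def only_on(device_types):
    # Accept one device type or a list of them, like the updated onlyOn.
    if isinstance(device_types, str):
        device_types = [device_types]

    def decorator(fn):
        @wraps(fn)
        def wrapper(self, *args, **kwargs):
            # self.device_type is set by the device-type test framework.
            if self.device_type not in device_types:
                raise unittest.SkipTest(f"Only runs on {device_types}")
            return fn(self, *args, **kwargs)
        return wrapper
    return decorator
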
@@ -1646,6 +1668,9 @@ def only_fn(self, *args, **kwargs):
         return only_fn
 
 
+
+
+
 def disablecuDNN(fn):
     @wraps(fn)
     def disable_cudnn(self, *args, **kwargs):
@@ -1943,6 +1968,10 @@ def skipMeta(fn):
     return skipMetaIf(True, "test doesn't work with meta tensors")(fn)
 
 
+def skipXPU(fn):
+    return skipXPUIf(True, "test doesn't work with XPU tensors")(fn)
+
+
 def skipXLA(fn):
     return skipXLAIf(True, "Marked as skipped for XLA")(fn)
 
@@ -1971,3 +2000,15 @@ def get_all_device_types() -> List[str]:
     and torch.cuda.get_device_capability() >= (8, 0),
     "Requires CUDA and Triton",
 )
+
+
+def any_common_cpu_device_one():
+    return (
+        OpDTypes.any_common_cpu_xpu_one
+        if TEST_XPU
+        else OpDTypes.any_common_cpu_cuda_one
+    )
+
+
+def is_gpu_device(device: str) -> bool:
+    return "cuda" in device or "xpu" in device
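
The dtype selection behind OpDTypes.any_common_cpu_xpu_one and any_common_cpu_device_one() boils down to intersecting the CPU and accelerator dtype sets and taking the first match in a fixed preference order. A hedged sketch follows; pick_common_dtype is a made-up helper and ANY_DTYPE_ORDER here is a shortened stand-in for the real tuple in common_device_type.py:

import torch

ANY_DTYPE_ORDER = (torch.float32, torch.float64, torch.complex64, torch.long, torch.bool)

def pick_common_dtype(cpu_dtypes, accel_dtypes):
    # Intersect the two supported sets and return the preferred common dtype, if any.
    supported = set(cpu_dtypes) & set(accel_dtypes)
    return next((dt for dt in ANY_DTYPE_ORDER if dt in supported), None)

print(pick_common_dtype({torch.float32, torch.long}, {torch.float32, torch.half}))
# torch.float32
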
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index d1acb9ff9f3c6..0d4f304459f02 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -24767,4 +24767,4 @@ def skipOps(test_case_name, base_test_name, to_skip):
     # This decorator doesn't modify fn in any way
     def wrapped(fn):
         return fn
-    return wrapped
+    return wrapped
\ No newline at end of file
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index c7b5f595e5f3e..251aff22b4ff9 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -5473,7 +5473,6 @@ def repl_frame(m):
     s = re.sub(r" +$", "", s, flags=re.MULTILINE)
     return s
 
-
 @contextmanager
 def check_leaked_tensors(limit=1, matched_type=torch.Tensor):
     """Wrap around operations you want to ensure are not leaking tensor memory.
@@ -5516,7 +5515,6 @@ def match_obj(obj):
     finally:
         gc.set_debug(0)
 
-
 def remove_cpp_extensions_build_root():
     """
     Removes the default root folder under which extensions are built.
@@ -5550,3 +5548,30 @@ def load_inline(*args, **kwargs):
         return func(*args, load_inline=load_inline, **kwargs)
 
     return wrapper
+
+
+GPU_TYPES = ["cuda", "xpu"]
+
+@functools.lru_cache(None)
+def get_gpu_type():
+    avail_gpus = [x for x in GPU_TYPES if getattr(torch, x).is_available()]
+    assert len(avail_gpus) <= 1
+    gpu_type = "cuda" if len(avail_gpus) == 0 else avail_gpus.pop()
+    return gpu_type
+
+HAS_CUDA = torch.cuda.is_available()
+
+HAS_XPU = torch.xpu.is_available()
+
+HAS_GPU = HAS_CUDA or HAS_XPU
+
+GPU_TYPE = get_gpu_type()
+
+HAS_MULTIGPU = any(
+    getattr(torch, gpu).is_available() and getattr(torch, gpu).device_count() >= 2
+    for gpu in GPU_TYPES
+)
+
+def get_gpu_autocast():
+    return torch.cuda.amp.autocast if HAS_CUDA else functools.partial(torch.autocast, device_type=GPU_TYPE)
+
diff --git a/torch/testing/_internal/opinfo/core.py b/torch/testing/_internal/opinfo/core.py
index 70d22e815cd29..cf6539d2bb304 100644
--- a/torch/testing/_internal/opinfo/core.py
+++ b/torch/testing/_internal/opinfo/core.py
@@ -115,7 +115,8 @@ def is_active(self, cls_name, test_name, device_type, dtype, param_kwargs):
             self.active_if
             and (self.cls_name is None or self.cls_name == cls_name)
             and (self.test_name is None or self.test_name == test_name)
-            and (self.device_type is None or self.device_type == device_type)
+            and (self.device_type is None or (self.device_type == device_type
+                if isinstance(self.device_type, str) else device_type in self.device_type))
             and (self.dtypes is None or dtype in self.dtypes)
             # Support callables over kwargs to determine if the decorator is active.
             and (
@@ -747,6 +748,9 @@ class OpInfo:
     # backward dtypes this function is expected to work with on CUDA
     backward_dtypesIfCUDA: _dispatch_dtypes = None
 
+    # backward dtypes this function is expected to work with on XPU
+    backward_dtypesIfXPU: _dispatch_dtypes = None
+
     # backward dtypes this function is expected to work with on ROCM
     backward_dtypesIfROCM: _dispatch_dtypes = None
 
@@ -971,6 +975,23 @@ def __post_init__(self):
                 else self.dtypes
             )
         )
+
+        self.backward_dtypesIfXPU = (
+            set(self.backward_dtypesIfXPU)
+            if self.backward_dtypesIfXPU is not None
+            else (
+                self.backward_dtypesIfCUDA
+                if self.backward_dtypesIfCUDA is not None
+                else self.backward_dtypes
+                if self.backward_dtypes is not None
+                else self.dtypesIfXPU
+                if self.dtypesIfXPU is not None
+                else self.dtypesIfCUDA
+                if self.dtypesIfCUDA is not None
+                else self.dtypes
+            )
+        )
+
         self.backward_dtypesIfHpu = (
             set(self.backward_dtypesIfHpu)
             if self.backward_dtypesIfHpu is not None
@@ -990,6 +1011,7 @@ def __post_init__(self):
         self.dtypesIfCUDA = (
             set(self.dtypesIfCUDA) if self.dtypesIfCUDA is not None else self.dtypes
         )
+
         self.dtypesIfROCM = (
             set(self.dtypesIfROCM)
             if self.dtypesIfROCM is not None
@@ -1533,6 +1555,8 @@ def supported_backward_dtypes(self, device_type):
                 if TEST_WITH_ROCM
                 else self.backward_dtypesIfCUDA
             )
+        elif device_type == "xpu":
+            backward_dtypes = self.backward_dtypesIfXPU
         elif device_type == "hpu":
             backward_dtypes = self.backward_dtypesIfHpu
         else:
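
Finally, a hedged usage sketch for the GPU helpers added to common_utils.py above; it assumes the patch is applied, that at least one of CUDA or XPU is available, and that the matmul dtype under autocast depends on the backend:

import torch
from torch.testing._internal.common_utils import GPU_TYPE, HAS_GPU, get_gpu_autocast

if HAS_GPU:
    autocast = get_gpu_autocast()
    x = torch.randn(8, 8, device=GPU_TYPE)
    with autocast():
        y = x @ x  # runs under autocast on whichever GPU backend is present
    print(GPU_TYPE, y.dtype)
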