@@ -11888,7 +11888,7 @@ def test_fft_input_modification(self, device):
1188811888 _ = torch.irfft(half_spectrum_copy, 2, signal_sizes=(2, 2))
1188911889 self.assertEqual(half_spectrum, half_spectrum_copy)
1189011890
11891- @skipCUDAIfRocm
11891+ @onlyOnCPUAndCUDA
1189211892 @unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support")
1189311893 @dtypes(torch.double)
1189411894 def test_istft_round_trip_simple_cases(self, device, dtype):
@@ -11901,7 +11901,7 @@ def _test(input, n_fft, length):
1190111901 _test(torch.ones(4, dtype=dtype, device=device), 4, 4)
1190211902 _test(torch.zeros(4, dtype=dtype, device=device), 4, 4)
1190311903
11904- @skipCUDAIfRocm
11904+ @onlyOnCPUAndCUDA
1190511905 @unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support")
1190611906 @dtypes(torch.double)
1190711907 def test_istft_round_trip_various_params(self, device, dtype):
@@ -11979,6 +11979,7 @@ def _test_istft_is_inverse_of_stft(kwargs):
1197911979 for i, pattern in enumerate(patterns):
1198011980 _test_istft_is_inverse_of_stft(pattern)
1198111981
11982+ @onlyOnCPUAndCUDA
1198211983 def test_istft_throws(self, device):
1198311984 """istft should throw exception for invalid parameters"""
1198411985 stft = torch.zeros((3, 5, 2), device=device)
@@ -11994,7 +11995,7 @@ def test_istft_throws(self, device):
1199411995 self.assertRaises(AssertionError, torch.istft, torch.zeros((3, 0, 2)), 2)
1199511996 self.assertRaises(AssertionError, torch.istft, torch.zeros((0, 3, 2)), 2)
1199611997
11997- @skipCUDAIfRocm
11998+ @onlyOnCPUAndCUDA
1199811999 @dtypes(torch.double)
1199912000 def test_istft_of_sine(self, device, dtype):
1200012001 def _test(amplitude, L, n):
@@ -12027,6 +12028,7 @@ def _test(amplitude, L, n):
1202712028 _test(amplitude=80, L=9, n=6)
1202812029 _test(amplitude=99, L=10, n=7)
1202912030
12031+ @onlyOnCPUAndCUDA
1203012032 @dtypes(torch.double)
1203112033 def test_istft_linearity(self, device, dtype):
1203212034 num_trials = 100
@@ -12090,7 +12092,7 @@ def _test(data_size, kwargs):
1209012092 for data_size, kwargs in patterns:
1209112093 _test(data_size, kwargs)
1209212094
12093- @skipCUDAIfRocm
12095+ @onlyOnCPUAndCUDA
1209412096 def test_batch_istft(self, device):
1209512097 original = torch.tensor([
1209612098 [[4., 0.], [4., 0.], [4., 0.], [4., 0.], [4., 0.]],
0 commit comments