Skip to content

Commit 45a146f

Browse files
committed
Run test only on CPU and CUDA
1 parent cc0bca4 commit 45a146f

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

test/test_torch.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11888,7 +11888,7 @@ def test_fft_input_modification(self, device):
1188811888
_ = torch.irfft(half_spectrum_copy, 2, signal_sizes=(2, 2))
1188911889
self.assertEqual(half_spectrum, half_spectrum_copy)
1189011890

11891-
@skipCUDAIfRocm
11891+
@onlyOnCPUAndCUDA
1189211892
@unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support")
1189311893
@dtypes(torch.double)
1189411894
def test_istft_round_trip_simple_cases(self, device, dtype):
@@ -11901,7 +11901,7 @@ def _test(input, n_fft, length):
1190111901
_test(torch.ones(4, dtype=dtype, device=device), 4, 4)
1190211902
_test(torch.zeros(4, dtype=dtype, device=device), 4, 4)
1190311903

11904-
@skipCUDAIfRocm
11904+
@onlyOnCPUAndCUDA
1190511905
@unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support")
1190611906
@dtypes(torch.double)
1190711907
def test_istft_round_trip_various_params(self, device, dtype):
@@ -11979,6 +11979,7 @@ def _test_istft_is_inverse_of_stft(kwargs):
1197911979
for i, pattern in enumerate(patterns):
1198011980
_test_istft_is_inverse_of_stft(pattern)
1198111981

11982+
@onlyOnCPUAndCUDA
1198211983
def test_istft_throws(self, device):
1198311984
"""istft should throw exception for invalid parameters"""
1198411985
stft = torch.zeros((3, 5, 2), device=device)
@@ -11994,7 +11995,7 @@ def test_istft_throws(self, device):
1199411995
self.assertRaises(AssertionError, torch.istft, torch.zeros((3, 0, 2)), 2)
1199511996
self.assertRaises(AssertionError, torch.istft, torch.zeros((0, 3, 2)), 2)
1199611997

11997-
@skipCUDAIfRocm
11998+
@onlyOnCPUAndCUDA
1199811999
@dtypes(torch.double)
1199912000
def test_istft_of_sine(self, device, dtype):
1200012001
def _test(amplitude, L, n):
@@ -12027,6 +12028,7 @@ def _test(amplitude, L, n):
1202712028
_test(amplitude=80, L=9, n=6)
1202812029
_test(amplitude=99, L=10, n=7)
1202912030

12031+
@onlyOnCPUAndCUDA
1203012032
@dtypes(torch.double)
1203112033
def test_istft_linearity(self, device, dtype):
1203212034
num_trials = 100
@@ -12090,7 +12092,7 @@ def _test(data_size, kwargs):
1209012092
for data_size, kwargs in patterns:
1209112093
_test(data_size, kwargs)
1209212094

12093-
@skipCUDAIfRocm
12095+
@onlyOnCPUAndCUDA
1209412096
def test_batch_istft(self, device):
1209512097
original = torch.tensor([
1209612098
[[4., 0.], [4., 0.], [4., 0.], [4., 0.], [4., 0.]],

0 commit comments

Comments
 (0)