Skip to content

Commit 876334e

Browse files
committed
Fix GPU test skip logic
1 parent bc1ffb1 commit 876334e

File tree

1 file changed

+8
-14
lines changed

1 file changed

+8
-14
lines changed

test/test_functional_filtering.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,9 @@ def test_lfilter_basic(self):
3333
def test_lfilter_basic_double(self):
3434
self._test_lfilter_basic(torch.float64, torch.device("cpu"))
3535

36+
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
3637
def test_lfilter_basic_gpu(self):
37-
if torch.cuda.is_available():
38-
self._test_lfilter_basic(torch.float32, torch.device("cuda:0"))
39-
else:
40-
print("skipping GPU test for lfilter_basic because device not available")
41-
pass
38+
self._test_lfilter_basic(torch.float32, torch.device("cuda:0"))
4239

4340
def _test_lfilter(self, waveform, device):
4441
"""
@@ -87,16 +84,13 @@ def test_lfilter(self):
8784

8885
self._test_lfilter(waveform, torch.device("cpu"))
8986

87+
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
9088
def test_lfilter_gpu(self):
91-
if torch.cuda.is_available():
92-
filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.wav")
93-
waveform, _ = torchaudio.load(filepath, normalization=True)
94-
cuda0 = torch.device("cuda:0")
95-
cuda_waveform = waveform.cuda(device=cuda0)
96-
self._test_lfilter(cuda_waveform, cuda0)
97-
else:
98-
print("skipping GPU test for lfilter because device not available")
99-
pass
89+
filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.wav")
90+
waveform, _ = torchaudio.load(filepath, normalization=True)
91+
cuda0 = torch.device("cuda:0")
92+
cuda_waveform = waveform.cuda(device=cuda0)
93+
self._test_lfilter(cuda_waveform, cuda0)
10094

10195
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
10296
@AudioBackendScope("sox")

0 commit comments

Comments
 (0)