Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 8 additions & 14 deletions test/test_functional_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,9 @@ def test_lfilter_basic(self):
def test_lfilter_basic_double(self):
self._test_lfilter_basic(torch.float64, torch.device("cpu"))

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_lfilter_basic_gpu(self):
if torch.cuda.is_available():
self._test_lfilter_basic(torch.float32, torch.device("cuda:0"))
else:
print("skipping GPU test for lfilter_basic because device not available")
pass
self._test_lfilter_basic(torch.float32, torch.device("cuda:0"))

def _test_lfilter(self, waveform, device):
"""
Expand Down Expand Up @@ -87,16 +84,13 @@ def test_lfilter(self):

self._test_lfilter(waveform, torch.device("cpu"))

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_lfilter_gpu(self):
if torch.cuda.is_available():
filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
cuda0 = torch.device("cuda:0")
cuda_waveform = waveform.cuda(device=cuda0)
self._test_lfilter(cuda_waveform, cuda0)
else:
print("skipping GPU test for lfilter because device not available")
pass
filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
cuda0 = torch.device("cuda:0")
cuda_waveform = waveform.cuda(device=cuda0)
self._test_lfilter(cuda_waveform, cuda0)

@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
Expand Down