@@ -33,12 +33,9 @@ def test_lfilter_basic(self):
3333 def test_lfilter_basic_double (self ):
3434 self ._test_lfilter_basic (torch .float64 , torch .device ("cpu" ))
3535
36+ @unittest .skipIf (not torch .cuda .is_available (), "CUDA not available" )
3637 def test_lfilter_basic_gpu (self ):
37- if torch .cuda .is_available ():
38- self ._test_lfilter_basic (torch .float32 , torch .device ("cuda:0" ))
39- else :
40- print ("skipping GPU test for lfilter_basic because device not available" )
41- pass
38+ self ._test_lfilter_basic (torch .float32 , torch .device ("cuda:0" ))
4239
4340 def _test_lfilter (self , waveform , device ):
4441 """
@@ -87,16 +84,13 @@ def test_lfilter(self):
8784
8885 self ._test_lfilter (waveform , torch .device ("cpu" ))
8986
87+ @unittest .skipIf (not torch .cuda .is_available (), "CUDA not available" )
9088 def test_lfilter_gpu (self ):
91- if torch .cuda .is_available ():
92- filepath = os .path .join (self .test_dirpath , "assets" , "whitenoise.wav" )
93- waveform , _ = torchaudio .load (filepath , normalization = True )
94- cuda0 = torch .device ("cuda:0" )
95- cuda_waveform = waveform .cuda (device = cuda0 )
96- self ._test_lfilter (cuda_waveform , cuda0 )
97- else :
98- print ("skipping GPU test for lfilter because device not available" )
99- pass
89+ filepath = os .path .join (self .test_dirpath , "assets" , "whitenoise.wav" )
90+ waveform , _ = torchaudio .load (filepath , normalization = True )
91+ cuda0 = torch .device ("cuda:0" )
92+ cuda_waveform = waveform .cuda (device = cuda0 )
93+ self ._test_lfilter (cuda_waveform , cuda0 )
10094
10195 @unittest .skipIf ("sox" not in BACKENDS , "sox not available" )
10296 @AudioBackendScope ("sox" )
0 commit comments