88from  torchaudio_unittest .common_utils  import  (
99    TestBaseMixin ,
1010    get_whitenoise ,
11+     get_spectrogram ,
12+     nested_params ,
1113)
1214
1315
@@ -23,8 +25,12 @@ def assert_grad(
2325
2426        inputs_  =  []
2527        for  i  in  inputs :
26-             i .requires_grad  =  True 
27-             inputs_ .append (i .to (dtype = torch .float64 , device = self .device ))
28+             if  torch .is_tensor (i ):
29+                 i  =  i .to (
30+                     dtype = torch .cdouble  if  i .is_complex () else  torch .double ,
31+                     device = self .device )
32+                 i .requires_grad  =  True 
33+             inputs_ .append (i )
2834        assert  gradcheck (transform , inputs_ )
2935        assert  gradgradcheck (transform , inputs_ , nondet_tol = nondet_tol )
3036
@@ -103,3 +109,23 @@ def test_spectral_centroid(self):
103109        transform  =  T .SpectralCentroid (sample_rate = sample_rate )
104110        waveform  =  get_whitenoise (sample_rate = sample_rate , duration = 0.05 , n_channels = 2 )
105111        self .assert_grad (transform , [waveform ], nondet_tol = 1e-10 )
112+ 
113+     @nested_params ( 
114+         [0.7 , 0.8 , 0.9 , 1.0 , 1.3 ], 
115+         [True , False ], 
116+     ) 
117+     def  test_timestretch (self , rate , test_complex ):
118+         transform  =  T .TimeStretch (fixed_rate = rate )
119+         waveform  =  get_whitenoise (sample_rate = 8000 , duration = 0.05 , n_channels = 2 )
120+         spectrogram  =  get_spectrogram (waveform , n_fft = 400 , power = 1  if  test_complex  else  None )
121+         self .assert_grad (transform , [spectrogram ])
122+ 
123+     @nested_params ( 
124+         [0.7 , 0.8 , 0.9 , 1.0 , 1.3 ], 
125+         [True , False ], 
126+     ) 
127+     def  test_timestretch_override (self , rate , test_complex ):
128+         transform  =  T .TimeStretch ()
129+         waveform  =  get_whitenoise (sample_rate = 8000 , duration = 0.05 , n_channels = 2 )
130+         spectrogram  =  get_spectrogram (waveform , n_fft = 400 , power = 1  if  test_complex  else  None )
131+         self .assert_grad (transform , [spectrogram , rate ])
0 commit comments