@@ -93,73 +93,17 @@ def test_spectogram_grad_at_zero(self, power):
         spec.sum().backward()
         assert not x.grad.isnan().sum()

-
-class FunctionalComplex(TestBaseMixin):
-    complex_dtype = None
-    real_dtype = None
-    device = None
-
-    @nested_params(
-        [0.5, 1.01, 1.3],
-        [True, False],
-    )
-    def test_phase_vocoder_shape(self, rate, test_pseudo_complex):
-        """Verify the output shape of phase vocoder"""
-        hop_length = 256
-        num_freq = 1025
-        num_frames = 400
-        batch_size = 2
-
-        torch.random.manual_seed(42)
-        spec = torch.randn(
-            batch_size, num_freq, num_frames, dtype=self.complex_dtype, device=self.device)
-        if test_pseudo_complex:
-            spec = torch.view_as_real(spec)
-
-        phase_advance = torch.linspace(
-            0,
-            np.pi * hop_length,
-            num_freq,
-            dtype=self.real_dtype, device=self.device)[..., None]
-
-        spec_stretch = F.phase_vocoder(spec, rate=rate, phase_advance=phase_advance)
-
-        assert spec.dim() == spec_stretch.dim()
-        expected_shape = torch.Size([batch_size, num_freq, int(np.ceil(num_frames / rate))])
-        output_shape = (torch.view_as_complex(spec_stretch) if test_pseudo_complex else spec_stretch).shape
-        assert output_shape == expected_shape
-
-
-class FunctionalCPUOnly(TestBaseMixin):
-    def test_create_fb_matrix_no_warning_high_n_freq(self):
-        with warnings.catch_warnings(record=True) as w:
-            warnings.simplefilter("always")
-            F.create_fb_matrix(288, 0, 8000, 128, 16000)
-        assert len(w) == 0
-
-    def test_create_fb_matrix_no_warning_low_n_mels(self):
-        with warnings.catch_warnings(record=True) as w:
-            warnings.simplefilter("always")
-            F.create_fb_matrix(201, 0, 8000, 89, 16000)
-        assert len(w) == 0
-
-    def test_create_fb_matrix_warning(self):
-        with warnings.catch_warnings(record=True) as w:
-            warnings.simplefilter("always")
-            F.create_fb_matrix(201, 0, 8000, 128, 16000)
-        assert len(w) == 1
-
     def test_compute_deltas_one_channel(self):
-        specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]])
-        expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5]]])
+        specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]], dtype=self.dtype, device=self.device)
+        expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5]]], dtype=self.dtype, device=self.device)
         computed = F.compute_deltas(specgram, win_length=3)
         self.assertEqual(computed, expected)

     def test_compute_deltas_two_channels(self):
         specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0],
-                                  [1.0, 2.0, 3.0, 4.0]]])
+                                  [1.0, 2.0, 3.0, 4.0]]], dtype=self.dtype, device=self.device)
         expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5],
-                                  [0.5, 1.0, 1.0, 0.5]]])
+                                  [0.5, 1.0, 1.0, 0.5]]], dtype=self.dtype, device=self.device)
         computed = F.compute_deltas(specgram, win_length=3)
         self.assertEqual(computed, expected)

@@ -190,7 +134,7 @@ def test_amplitude_to_DB_reversible(self, shape):
         db_mult = math.log10(max(amin, ref))

         torch.manual_seed(0)
-        spec = torch.rand(*shape) * 200
+        spec = torch.rand(*shape, dtype=self.dtype, device=self.device) * 200

         # Spectrogram amplitude -> DB -> amplitude
         db = F.amplitude_to_DB(spec, amplitude_mult, amin, db_mult, top_db=None)
@@ -218,7 +162,7 @@ def test_amplitude_to_DB_top_db_clamp(self, shape):
         # each spectrogram still need to be predictable. The max determines the
         # decibel cutoff, and the distance from the min must be large enough
         # that it triggers a clamp.
-        spec = torch.rand(*shape)
+        spec = torch.rand(*shape, dtype=self.dtype, device=self.device)
         # Ensure each spectrogram has a min of 0 and a max of 1.
         spec -= spec.amin([-2, -1])[..., None, None]
         spec /= spec.amax([-2, -1])[..., None, None]
@@ -245,7 +189,7 @@ def test_amplitude_to_DB_top_db_clamp(self, shape):
     )
     def test_complex_norm(self, shape, power):
         torch.random.manual_seed(42)
-        complex_tensor = torch.randn(*shape)
+        complex_tensor = torch.randn(*shape, dtype=self.dtype, device=self.device)
         expected_norm_tensor = complex_tensor.pow(2).sum(-1).pow(power / 2)
         norm_tensor = F.complex_norm(complex_tensor, power)
         self.assertEqual(norm_tensor, expected_norm_tensor, atol=1e-5, rtol=1e-5)
@@ -255,7 +199,7 @@ def test_complex_norm(self, shape, power):
     )
     def test_mask_along_axis(self, shape, mask_param, mask_value, axis):
         torch.random.manual_seed(42)
-        specgram = torch.randn(*shape)
+        specgram = torch.randn(*shape, dtype=self.dtype, device=self.device)
         mask_specgram = F.mask_along_axis(specgram, mask_param, mask_value, axis)

         other_axis = 1 if axis == 2 else 2
@@ -271,7 +215,7 @@ def test_mask_along_axis(self, shape, mask_param, mask_value, axis):
     @parameterized.expand(list(itertools.product([100], [0., 30.], [2, 3])))
     def test_mask_along_axis_iid(self, mask_param, mask_value, axis):
         torch.random.manual_seed(42)
-        specgrams = torch.randn(4, 2, 1025, 400)
+        specgrams = torch.randn(4, 2, 1025, 400, dtype=self.dtype, device=self.device)

         mask_specgrams = F.mask_along_axis_iid(specgrams, mask_param, mask_value, axis)

@@ -282,3 +226,59 @@ def test_mask_along_axis_iid(self, mask_param, mask_value, axis):

         assert mask_specgrams.size() == specgrams.size()
         assert (num_masked_columns < mask_param).sum() == num_masked_columns.numel()
+
+
+class FunctionalComplex(TestBaseMixin):
+    complex_dtype = None
+    real_dtype = None
+    device = None
+
+    @nested_params(
+        [0.5, 1.01, 1.3],
+        [True, False],
+    )
+    def test_phase_vocoder_shape(self, rate, test_pseudo_complex):
+        """Verify the output shape of phase vocoder"""
+        hop_length = 256
+        num_freq = 1025
+        num_frames = 400
+        batch_size = 2
+
+        torch.random.manual_seed(42)
+        spec = torch.randn(
+            batch_size, num_freq, num_frames, dtype=self.complex_dtype, device=self.device)
+        if test_pseudo_complex:
+            spec = torch.view_as_real(spec)
+
+        phase_advance = torch.linspace(
+            0,
+            np.pi * hop_length,
+            num_freq,
+            dtype=self.real_dtype, device=self.device)[..., None]
+
+        spec_stretch = F.phase_vocoder(spec, rate=rate, phase_advance=phase_advance)
+
+        assert spec.dim() == spec_stretch.dim()
+        expected_shape = torch.Size([batch_size, num_freq, int(np.ceil(num_frames / rate))])
+        output_shape = (torch.view_as_complex(spec_stretch) if test_pseudo_complex else spec_stretch).shape
+        assert output_shape == expected_shape
+
+
+class FunctionalCPUOnly(TestBaseMixin):
+    def test_create_fb_matrix_no_warning_high_n_freq(self):
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            F.create_fb_matrix(288, 0, 8000, 128, 16000)
+        assert len(w) == 0
+
+    def test_create_fb_matrix_no_warning_low_n_mels(self):
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            F.create_fb_matrix(201, 0, 8000, 89, 16000)
+        assert len(w) == 0
+
+    def test_create_fb_matrix_warning(self):
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            F.create_fb_matrix(201, 0, 8000, 128, 16000)
+        assert len(w) == 1
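
Note on the dtype/device arguments threaded through the tests above: the shared test bodies are mixed into device- and dtype-specific subclasses, so every tensor is built with dtype=self.dtype, device=self.device instead of the defaults. A minimal sketch of that mixin pattern, assuming hypothetical class names (_DeltaTests, DeltaTestsFloat32CPU, DeltaTestsFloat64CPU) and a stand-in centered-difference delta in place of F.compute_deltas:

import unittest

import torch


class _DeltaTests:
    # Subclasses pin these; the shared test body builds its tensors with them,
    # so the same assertions run once per dtype/device combination.
    dtype = None
    device = None

    def test_compute_deltas_one_channel(self):
        specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]], dtype=self.dtype, device=self.device)
        expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5]]], dtype=self.dtype, device=self.device)
        # Stand-in for F.compute_deltas with win_length=3: a centered first-order
        # difference with replicated edges, used here only to keep the sketch
        # self-contained.
        padded = torch.nn.functional.pad(specgram, (1, 1), mode="replicate")
        computed = (padded[..., 2:] - padded[..., :-2]) / 2
        assert torch.allclose(computed, expected)


class DeltaTestsFloat32CPU(_DeltaTests, unittest.TestCase):
    dtype = torch.float32
    device = torch.device("cpu")


class DeltaTestsFloat64CPU(_DeltaTests, unittest.TestCase):
    dtype = torch.float64
    device = torch.device("cpu")


if __name__ == "__main__":
    unittest.main()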
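
A quick arithmetic check of the shape assertion in test_phase_vocoder_shape: the stretched spectrogram is expected to have ceil(num_frames / rate) frames, so a rate below 1.0 (slow-down) yields more frames and a rate above 1.0 (speed-up) yields fewer. This sketch only evaluates the formula from the test, not F.phase_vocoder itself:

import numpy as np

num_frames = 400
# Rates taken from the test's nested_params above.
for rate in (0.5, 1.01, 1.3):
    expected_frames = int(np.ceil(num_frames / rate))
    print(f"rate={rate}: {expected_frames} frames")
# rate=0.5:  800 frames (400 / 0.5 = 800)
# rate=1.01: 397 frames (ceil(396.04...) = 397)
# rate=1.3:  308 frames (ceil(307.69...) = 308)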