|
11 | 11 | from .functional_impl import Lfilter |
12 | 12 |
|
13 | 13 |
|
14 | | -def random_float_tensor(seed, size, a=22695477, c=1, m=2 ** 32): |
15 | | - """ Generates random tensors given a seed and size |
16 | | - https://en.wikipedia.org/wiki/Linear_congruential_generator |
17 | | - X_{n + 1} = (a * X_n + c) % m |
18 | | - Using Borland C/C++ values |
19 | | -
|
20 | | - The tensor will have values between [0,1) |
21 | | - Inputs: |
22 | | - seed (int): an int |
23 | | - size (Tuple[int]): the size of the output tensor |
24 | | - a (int): the multiplier constant to the generator |
25 | | - c (int): the additive constant to the generator |
26 | | - m (int): the modulus constant to the generator |
27 | | - """ |
28 | | - num_elements = 1 |
29 | | - for s in size: |
30 | | - num_elements *= s |
31 | | - |
32 | | - arr = [(a * seed + c) % m] |
33 | | - for i in range(num_elements - 1): |
34 | | - arr.append((a * arr[i] + c) % m) |
35 | | - |
36 | | - return torch.tensor(arr).float().view(size) / m |
37 | | - |
38 | | - |
39 | 14 | class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase): |
40 | 15 | dtype = torch.float32 |
41 | 16 | device = torch.device('cpu') |
@@ -63,242 +38,6 @@ def test_two_channels(self): |
63 | 38 | torch.testing.assert_allclose(computed, expected) |
64 | 39 |
|
65 | 40 |
|
66 | | -def _compare_estimate(sound, estimate, atol=1e-6, rtol=1e-8): |
67 | | - # trim sound for case when constructed signal is shorter than original |
68 | | - sound = sound[..., :estimate.size(-1)] |
69 | | - torch.testing.assert_allclose(estimate, sound, atol=atol, rtol=rtol) |
70 | | - |
71 | | - |
72 | | -def _test_istft_is_inverse_of_stft(kwargs): |
73 | | - # generates a random sound signal for each tril and then does the stft/istft |
74 | | - # operation to check whether we can reconstruct signal |
75 | | - for data_size in [(2, 20), (3, 15), (4, 10)]: |
76 | | - for i in range(100): |
77 | | - |
78 | | - sound = random_float_tensor(i, data_size) |
79 | | - |
80 | | - stft = torch.stft(sound, **kwargs) |
81 | | - estimate = torchaudio.functional.istft(stft, length=sound.size(1), **kwargs) |
82 | | - |
83 | | - _compare_estimate(sound, estimate) |
84 | | - |
85 | | - |
86 | | -class TestIstft(common_utils.TorchaudioTestCase): |
87 | | - """Test suite for correctness of istft with various input""" |
88 | | - number_of_trials = 100 |
89 | | - |
90 | | - def test_istft_is_inverse_of_stft1(self): |
91 | | - # hann_window, centered, normalized, onesided |
92 | | - kwargs1 = { |
93 | | - 'n_fft': 12, |
94 | | - 'hop_length': 4, |
95 | | - 'win_length': 12, |
96 | | - 'window': torch.hann_window(12), |
97 | | - 'center': True, |
98 | | - 'pad_mode': 'reflect', |
99 | | - 'normalized': True, |
100 | | - 'onesided': True, |
101 | | - } |
102 | | - _test_istft_is_inverse_of_stft(kwargs1) |
103 | | - |
104 | | - def test_istft_is_inverse_of_stft2(self): |
105 | | - # hann_window, centered, not normalized, not onesided |
106 | | - kwargs2 = { |
107 | | - 'n_fft': 12, |
108 | | - 'hop_length': 2, |
109 | | - 'win_length': 8, |
110 | | - 'window': torch.hann_window(8), |
111 | | - 'center': True, |
112 | | - 'pad_mode': 'reflect', |
113 | | - 'normalized': False, |
114 | | - 'onesided': False, |
115 | | - } |
116 | | - _test_istft_is_inverse_of_stft(kwargs2) |
117 | | - |
118 | | - def test_istft_is_inverse_of_stft3(self): |
119 | | - # hamming_window, centered, normalized, not onesided |
120 | | - kwargs3 = { |
121 | | - 'n_fft': 15, |
122 | | - 'hop_length': 3, |
123 | | - 'win_length': 11, |
124 | | - 'window': torch.hamming_window(11), |
125 | | - 'center': True, |
126 | | - 'pad_mode': 'constant', |
127 | | - 'normalized': True, |
128 | | - 'onesided': False, |
129 | | - } |
130 | | - _test_istft_is_inverse_of_stft(kwargs3) |
131 | | - |
132 | | - def test_istft_is_inverse_of_stft4(self): |
133 | | - # hamming_window, not centered, not normalized, onesided |
134 | | - # window same size as n_fft |
135 | | - kwargs4 = { |
136 | | - 'n_fft': 5, |
137 | | - 'hop_length': 2, |
138 | | - 'win_length': 5, |
139 | | - 'window': torch.hamming_window(5), |
140 | | - 'center': False, |
141 | | - 'pad_mode': 'constant', |
142 | | - 'normalized': False, |
143 | | - 'onesided': True, |
144 | | - } |
145 | | - _test_istft_is_inverse_of_stft(kwargs4) |
146 | | - |
147 | | - def test_istft_is_inverse_of_stft5(self): |
148 | | - # hamming_window, not centered, not normalized, not onesided |
149 | | - # window same size as n_fft |
150 | | - kwargs5 = { |
151 | | - 'n_fft': 3, |
152 | | - 'hop_length': 2, |
153 | | - 'win_length': 3, |
154 | | - 'window': torch.hamming_window(3), |
155 | | - 'center': False, |
156 | | - 'pad_mode': 'reflect', |
157 | | - 'normalized': False, |
158 | | - 'onesided': False, |
159 | | - } |
160 | | - _test_istft_is_inverse_of_stft(kwargs5) |
161 | | - |
162 | | - def test_istft_of_ones(self): |
163 | | - # stft = torch.stft(torch.ones(4), 4) |
164 | | - stft = torch.tensor([ |
165 | | - [[4., 0.], [4., 0.], [4., 0.], [4., 0.], [4., 0.]], |
166 | | - [[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]], |
167 | | - [[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]] |
168 | | - ]) |
169 | | - |
170 | | - estimate = torchaudio.functional.istft(stft, n_fft=4, length=4) |
171 | | - _compare_estimate(torch.ones(4), estimate) |
172 | | - |
173 | | - def test_istft_of_zeros(self): |
174 | | - # stft = torch.stft(torch.zeros(4), 4) |
175 | | - stft = torch.zeros((3, 5, 2)) |
176 | | - |
177 | | - estimate = torchaudio.functional.istft(stft, n_fft=4, length=4) |
178 | | - _compare_estimate(torch.zeros(4), estimate) |
179 | | - |
180 | | - def test_istft_requires_overlap_windows(self): |
181 | | - # the window is size 1 but it hops 20 so there is a gap which throw an error |
182 | | - stft = torch.zeros((3, 5, 2)) |
183 | | - self.assertRaises(RuntimeError, torchaudio.functional.istft, stft, n_fft=4, |
184 | | - hop_length=20, win_length=1, window=torch.ones(1)) |
185 | | - |
186 | | - def test_istft_requires_nola(self): |
187 | | - stft = torch.zeros((3, 5, 2)) |
188 | | - kwargs_ok = { |
189 | | - 'n_fft': 4, |
190 | | - 'win_length': 4, |
191 | | - 'window': torch.ones(4), |
192 | | - } |
193 | | - |
194 | | - kwargs_not_ok = { |
195 | | - 'n_fft': 4, |
196 | | - 'win_length': 4, |
197 | | - 'window': torch.zeros(4), |
198 | | - } |
199 | | - |
200 | | - # A window of ones meets NOLA but a window of zeros does not. This should |
201 | | - # throw an error. |
202 | | - torchaudio.functional.istft(stft, **kwargs_ok) |
203 | | - self.assertRaises(RuntimeError, torchaudio.functional.istft, stft, **kwargs_not_ok) |
204 | | - |
205 | | - def test_istft_requires_non_empty(self): |
206 | | - self.assertRaises(RuntimeError, torchaudio.functional.istft, torch.zeros((3, 0, 2)), 2) |
207 | | - self.assertRaises(RuntimeError, torchaudio.functional.istft, torch.zeros((0, 3, 2)), 2) |
208 | | - |
209 | | - def _test_istft_of_sine(self, amplitude, L, n): |
210 | | - # stft of amplitude*sin(2*pi/L*n*x) with the hop length and window size equaling L |
211 | | - x = torch.arange(2 * L + 1, dtype=torch.get_default_dtype()) |
212 | | - sound = amplitude * torch.sin(2 * math.pi / L * x * n) |
213 | | - # stft = torch.stft(sound, L, hop_length=L, win_length=L, |
214 | | - # window=torch.ones(L), center=False, normalized=False) |
215 | | - stft = torch.zeros((L // 2 + 1, 2, 2)) |
216 | | - stft_largest_val = (amplitude * L) / 2.0 |
217 | | - if n < stft.size(0): |
218 | | - stft[n, :, 1] = -stft_largest_val |
219 | | - |
220 | | - if 0 <= L - n < stft.size(0): |
221 | | - # symmetric about L // 2 |
222 | | - stft[L - n, :, 1] = stft_largest_val |
223 | | - |
224 | | - estimate = torchaudio.functional.istft(stft, L, hop_length=L, win_length=L, |
225 | | - window=torch.ones(L), center=False, normalized=False) |
226 | | - # There is a larger error due to the scaling of amplitude |
227 | | - _compare_estimate(sound, estimate, atol=1e-3) |
228 | | - |
229 | | - def test_istft_of_sine(self): |
230 | | - self._test_istft_of_sine(amplitude=123, L=5, n=1) |
231 | | - self._test_istft_of_sine(amplitude=150, L=5, n=2) |
232 | | - self._test_istft_of_sine(amplitude=111, L=5, n=3) |
233 | | - self._test_istft_of_sine(amplitude=160, L=7, n=4) |
234 | | - self._test_istft_of_sine(amplitude=145, L=8, n=5) |
235 | | - self._test_istft_of_sine(amplitude=80, L=9, n=6) |
236 | | - self._test_istft_of_sine(amplitude=99, L=10, n=7) |
237 | | - |
238 | | - def _test_linearity_of_istft(self, data_size, kwargs, atol=1e-6, rtol=1e-8): |
239 | | - for i in range(self.number_of_trials): |
240 | | - tensor1 = random_float_tensor(i, data_size) |
241 | | - tensor2 = random_float_tensor(i * 2, data_size) |
242 | | - a, b = torch.rand(2) |
243 | | - istft1 = torchaudio.functional.istft(tensor1, **kwargs) |
244 | | - istft2 = torchaudio.functional.istft(tensor2, **kwargs) |
245 | | - istft = a * istft1 + b * istft2 |
246 | | - estimate = torchaudio.functional.istft(a * tensor1 + b * tensor2, **kwargs) |
247 | | - _compare_estimate(istft, estimate, atol, rtol) |
248 | | - |
249 | | - def test_linearity_of_istft1(self): |
250 | | - # hann_window, centered, normalized, onesided |
251 | | - kwargs1 = { |
252 | | - 'n_fft': 12, |
253 | | - 'window': torch.hann_window(12), |
254 | | - 'center': True, |
255 | | - 'pad_mode': 'reflect', |
256 | | - 'normalized': True, |
257 | | - 'onesided': True, |
258 | | - } |
259 | | - data_size = (2, 7, 7, 2) |
260 | | - self._test_linearity_of_istft(data_size, kwargs1) |
261 | | - |
262 | | - def test_linearity_of_istft2(self): |
263 | | - # hann_window, centered, not normalized, not onesided |
264 | | - kwargs2 = { |
265 | | - 'n_fft': 12, |
266 | | - 'window': torch.hann_window(12), |
267 | | - 'center': True, |
268 | | - 'pad_mode': 'reflect', |
269 | | - 'normalized': False, |
270 | | - 'onesided': False, |
271 | | - } |
272 | | - data_size = (2, 12, 7, 2) |
273 | | - self._test_linearity_of_istft(data_size, kwargs2) |
274 | | - |
275 | | - def test_linearity_of_istft3(self): |
276 | | - # hamming_window, centered, normalized, not onesided |
277 | | - kwargs3 = { |
278 | | - 'n_fft': 12, |
279 | | - 'window': torch.hamming_window(12), |
280 | | - 'center': True, |
281 | | - 'pad_mode': 'constant', |
282 | | - 'normalized': True, |
283 | | - 'onesided': False, |
284 | | - } |
285 | | - data_size = (2, 12, 7, 2) |
286 | | - self._test_linearity_of_istft(data_size, kwargs3) |
287 | | - |
288 | | - def test_linearity_of_istft4(self): |
289 | | - # hamming_window, not centered, not normalized, onesided |
290 | | - kwargs4 = { |
291 | | - 'n_fft': 12, |
292 | | - 'window': torch.hamming_window(12), |
293 | | - 'center': False, |
294 | | - 'pad_mode': 'constant', |
295 | | - 'normalized': False, |
296 | | - 'onesided': True, |
297 | | - } |
298 | | - data_size = (2, 7, 3, 2) |
299 | | - self._test_linearity_of_istft(data_size, kwargs4, atol=1e-5, rtol=1e-8) |
300 | | - |
301 | | - |
302 | 41 | class TestDetectPitchFrequency(common_utils.TorchaudioTestCase): |
303 | 42 | @parameterized.expand([(100,), (440,)]) |
304 | 43 | def test_pitch(self, frequency): |
|
0 commit comments