|
7 | 7 | import pytest |
8 | 8 | import unittest |
9 | 9 | import common_utils |
| 10 | +import os |
10 | 11 |
|
11 | 12 | from torchaudio.common_utils import IMPORT_LIBROSA |
12 | 13 |
|
@@ -247,6 +248,19 @@ def test_create_fb(self): |
247 | 248 | self._test_create_fb(n_mels=56, fmin=1900.0, fmax=900.0) |
248 | 249 | self._test_create_fb(n_mels=10, fmin=1900.0, fmax=900.0) |
249 | 250 |
|
| 251 | + def test_gain(self): |
| 252 | + test_dirpath, test_dir = common_utils.create_temp_assets_dir() |
| 253 | + test_filepath = os.path.join(test_dirpath, "assets", "sinewave.wav") |
| 254 | + |
| 255 | + waveform, sample_rate = torchaudio.load(test_filepath) |
| 256 | + waveform_gain = F.gain(waveform, 5) |
| 257 | + self.assertTrue(waveform_gain.abs().max().item(), 1.) |
| 258 | + |
| 259 | + test_filepath_sox_gain = os.path.join(test_dirpath, "assets", "sinewave_soxgain5.wav") |
| 260 | + sox_gain_waveform, sox_sr = torchaudio.load(test_filepath_sox_gain) |
| 261 | + |
| 262 | + self.assertTrue(torch.allclose(waveform_gain, sox_gain_waveform)) |
| 263 | + |
250 | 264 |
|
251 | 265 | def _num_stft_bins(signal_len, fft_len, hop_length, pad): |
252 | 266 | return (signal_len + 2 * pad - fft_len + hop_length) // hop_length |
@@ -346,5 +360,6 @@ def test_mask_along_axis_iid(specgrams, mask_param, mask_value, axis): |
346 | 360 | assert (num_masked_columns < mask_param).sum() == num_masked_columns.numel() |
347 | 361 |
|
348 | 362 |
|
| 363 | + |
349 | 364 | if __name__ == '__main__': |
350 | 365 | unittest.main() |
0 commit comments