diff --git a/test/test_functional.py b/test/test_functional.py
index 04ef533f05..02b01620ed 100644
--- a/test/test_functional.py
+++ b/test/test_functional.py
@@ -2,6 +2,8 @@
 import torch
 import torchaudio
+import torchaudio.functional as F
+import pytest
 import unittest
 import test.common_utils
@@ -11,10 +13,6 @@
 import numpy as np
 import librosa
 
-import pytest
-import torchaudio.functional as F
-xfail = pytest.mark.xfail
-
 
 class TestFunctional(unittest.TestCase):
     data_sizes = [(2, 20), (3, 15), (4, 10)]
@@ -197,54 +195,6 @@ def _num_stft_bins(signal_len, fft_len, hop_length, pad):
     return (signal_len + 2 * pad - fft_len + hop_length) // hop_length
 
 
-@pytest.mark.parametrize('fft_length', [512])
-@pytest.mark.parametrize('hop_length', [256])
-@pytest.mark.parametrize('waveform', [
-    (torch.randn(1, 100000)),
-    (torch.randn(1, 2, 100000)),
-    pytest.param(torch.randn(1, 100), marks=xfail(raises=RuntimeError)),
-])
-@pytest.mark.parametrize('pad_mode', [
-    # 'constant',
-    'reflect',
-])
-@unittest.skipIf(not IMPORT_LIBROSA, 'Librosa is not available')
-def test_stft(waveform, fft_length, hop_length, pad_mode):
-    """
-    Test STFT for multi-channel signals.
-
-    Padding: Value in having padding outside of torch.stft?
-    """
-    pad = fft_length // 2
-    window = torch.hann_window(fft_length)
-    complex_spec = F.stft(waveform,
-                          fft_length=fft_length,
-                          hop_length=hop_length,
-                          window=window,
-                          pad_mode=pad_mode)
-    mag_spec, phase_spec = F.magphase(complex_spec)
-
-    # == Test shape
-    expected_size = list(waveform.size()[:-1])
-    expected_size += [fft_length // 2 + 1, _num_stft_bins(
-        waveform.size(-1), fft_length, hop_length, pad), 2]
-    assert complex_spec.dim() == waveform.dim() + 2
-    assert complex_spec.size() == torch.Size(expected_size)
-
-    # == Test values
-    fft_config = dict(n_fft=fft_length, hop_length=hop_length, pad_mode=pad_mode)
-    # note that librosa *automatically* pad with fft_length // 2.
-    expected_complex_spec = np.apply_along_axis(librosa.stft, -1,
-                                                waveform.numpy(), **fft_config)
-    expected_mag_spec, _ = librosa.magphase(expected_complex_spec)
-    # Convert torch to np.complex
-    complex_spec = complex_spec.numpy()
-    complex_spec = complex_spec[..., 0] + 1j * complex_spec[..., 1]
-
-    assert np.allclose(complex_spec, expected_complex_spec, atol=1e-5)
-    assert np.allclose(mag_spec.numpy(), expected_mag_spec, atol=1e-5)
-
-
 @pytest.mark.parametrize('rate', [0.5, 1.01, 1.3])
 @pytest.mark.parametrize('complex_specgrams', [
     torch.randn(1, 2, 1025, 400, 2),
diff --git a/test/test_jit.py b/test/test_jit.py
index d2652a9dc4..22113a295e 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -30,40 +30,18 @@ def _test_script_module(self, tensor, f, *args):
         self.assertTrue(torch.allclose(jit_out, py_out))
 
-    def test_torchscript_scale(self):
-        @torch.jit.script
-        def jit_method(tensor, factor):
-            # type: (Tensor, int) -> Tensor
-            return F.scale(tensor, factor)
-
-        tensor = torch.rand((10, 1))
-        factor = 2
-
-        jit_out = jit_method(tensor, factor)
-        py_out = F.scale(tensor, factor)
-
-        self.assertTrue(torch.allclose(jit_out, py_out))
-
-    @unittest.skipIf(not RUN_CUDA, "no CUDA")
-    def test_scriptmodule_scale(self):
-        tensor = torch.rand((10, 1), device="cuda")
-
-        self._test_script_module(tensor, transforms.Scale)
-
     def test_torchscript_pad_trim(self):
         @torch.jit.script
-        def jit_method(tensor, ch_dim, max_len, len_dim, fill_value):
-            # type: (Tensor, int, int, int, float) -> Tensor
-            return F.pad_trim(tensor, ch_dim, max_len, len_dim, fill_value)
+        def jit_method(tensor, max_len, fill_value):
+            # type: (Tensor, int, float) -> Tensor
+            return F.pad_trim(tensor, max_len, fill_value)
 
-        tensor = torch.rand((10, 1))
-        ch_dim = 1
+        tensor = torch.rand((1, 10))
         max_len = 5
-        len_dim = 0
         fill_value = 3.
 
-        jit_out = jit_method(tensor, ch_dim, max_len, len_dim, fill_value)
-        py_out = F.pad_trim(tensor, ch_dim, max_len, len_dim, fill_value)
+        jit_out = jit_method(tensor, max_len, fill_value)
+        py_out = F.pad_trim(tensor, max_len, fill_value)
 
         self.assertTrue(torch.allclose(jit_out, py_out))
 
@@ -74,45 +52,6 @@ def test_scriptmodule_pad_trim(self):
         self._test_script_module(tensor, transforms.PadTrim, max_len)
 
-    def test_torchscript_downmix_mono(self):
-        @torch.jit.script
-        def jit_method(tensor, ch_dim):
-            # type: (Tensor, int) -> Tensor
-            return F.downmix_mono(tensor, ch_dim)
-
-        tensor = torch.rand((10, 1))
-        ch_dim = 1
-
-        jit_out = jit_method(tensor, ch_dim)
-        py_out = F.downmix_mono(tensor, ch_dim)
-
-        self.assertTrue(torch.allclose(jit_out, py_out))
-
-    @unittest.skipIf(not RUN_CUDA, "no CUDA")
-    def test_scriptmodule_downmix_mono(self):
-        tensor = torch.rand((1, 10), device="cuda")
-
-        self._test_script_module(tensor, transforms.DownmixMono)
-
-    def test_torchscript_LC2CL(self):
-        @torch.jit.script
-        def jit_method(tensor):
-            # type: (Tensor) -> Tensor
-            return F.LC2CL(tensor)
-
-        tensor = torch.rand((10, 1))
-
-        jit_out = jit_method(tensor)
-        py_out = F.LC2CL(tensor)
-
-        self.assertTrue(torch.allclose(jit_out, py_out))
-
-    @unittest.skipIf(not RUN_CUDA, "no CUDA")
-    def test_scriptmodule_LC2CL(self):
-        tensor = torch.rand((10, 1), device="cuda")
-
-        self._test_script_module(tensor, transforms.LC2CL)
-
     def test_torchscript_spectrogram(self):
         @torch.jit.script
         def jit_method(sig, pad, window, n_fft, hop, ws, power, normalize):
@@ -167,7 +106,7 @@ def jit_method(spec, multiplier, amin, db_multiplier, top_db):
             # type: (Tensor, float, float, float, Optional[float]) -> Tensor
             return F.spectrogram_to_DB(spec, multiplier, amin, db_multiplier, top_db)
 
-        spec = torch.rand((10, 1))
+        spec = torch.rand((6, 201))
         multiplier = 10.
         amin = 1e-10
         db_multiplier = 0.
@@ -180,7 +119,7 @@ def jit_method(spec, multiplier, amin, db_multiplier, top_db):
 
     @unittest.skipIf(not RUN_CUDA, "no CUDA")
     def test_scriptmodule_SpectrogramToDB(self):
-        spec = torch.rand((10, 1), device="cuda")
+        spec = torch.rand((6, 201), device="cuda")
 
         self._test_script_module(spec, transforms.SpectrogramToDB)
 
@@ -211,32 +150,13 @@ def test_scriptmodule_MelSpectrogram(self):
         self._test_script_module(tensor, transforms.MelSpectrogram)
 
-    def test_torchscript_BLC2CBL(self):
-        @torch.jit.script
-        def jit_method(tensor):
-            # type: (Tensor) -> Tensor
-            return F.BLC2CBL(tensor)
-
-        tensor = torch.rand((10, 1000, 1))
-
-        jit_out = jit_method(tensor)
-        py_out = F.BLC2CBL(tensor)
-
-        self.assertTrue(torch.allclose(jit_out, py_out))
-
-    @unittest.skipIf(not RUN_CUDA, "no CUDA")
-    def test_scriptmodule_BLC2CBL(self):
-        tensor = torch.rand((10, 1000, 1), device="cuda")
-
-        self._test_script_module(tensor, transforms.BLC2CBL)
-
     def test_torchscript_mu_law_encoding(self):
         @torch.jit.script
         def jit_method(tensor, qc):
             # type: (Tensor, int) -> Tensor
             return F.mu_law_encoding(tensor, qc)
 
-        tensor = torch.rand((10, 1))
+        tensor = torch.rand((1, 10))
         qc = 256
 
         jit_out = jit_method(tensor, qc)
@@ -246,7 +166,7 @@ def jit_method(tensor, qc):
 
     @unittest.skipIf(not RUN_CUDA, "no CUDA")
     def test_scriptmodule_MuLawEncoding(self):
-        tensor = torch.rand((10, 1), device="cuda")
+        tensor = torch.rand((1, 10), device="cuda")
 
         self._test_script_module(tensor, transforms.MuLawEncoding)
 
@@ -256,7 +176,7 @@ def jit_method(tensor, qc):
             # type: (Tensor, int) -> Tensor
             return F.mu_law_expanding(tensor, qc)
 
-        tensor = torch.rand((10, 1))
+        tensor = torch.rand((1, 10))
         qc = 256
 
         jit_out = jit_method(tensor, qc)
@@ -266,7 +186,7 @@ def jit_method(tensor, qc):
 
     @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_scriptmodule_MuLawExpanding(self):
-        tensor = torch.rand((10, 1), device="cuda")
+        tensor = torch.rand((1, 10), device="cuda")
 
         self._test_script_module(tensor, transforms.MuLawExpanding)
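[Editor's note — not part of the patch] The test changes above encode this PR's convention
shift: audio tensors are channels-first, (channel, time), and F.pad_trim drops the ch_dim /
len_dim arguments. A minimal sketch of the new call, assuming only the signatures shown in
this diff (values illustrative):

    import torch
    import torchaudio.functional as F

    waveform = torch.rand(1, 10)            # (channel, time)
    padded = F.pad_trim(waveform, 15, 0.)   # pad the time axis out to 15 samples
    trimmed = F.pad_trim(waveform, 5, 0.)   # or trim it down to 5
    assert padded.size() == (1, 15) and trimmed.size() == (1, 5)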
diff --git a/test/test_transforms.py b/test/test_transforms.py
index 4c3a55565c..1d3a41a564 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -19,191 +19,123 @@ class Tester(unittest.TestCase):
 
     # create a sinewave signal for testing
-    sr = 16000
+    sample_rate = 16000
     freq = 440
     volume = .3
-    sig = (torch.cos(2 * math.pi * torch.arange(0, 4 * sr).float() * freq / sr))
-    sig.unsqueeze_(1)  # (64000, 1)
-    sig = (sig * volume * 2**31).long()
+    waveform = (torch.cos(2 * math.pi * torch.arange(0, 4 * sample_rate).float() * freq / sample_rate))
+    waveform.unsqueeze_(0)  # (1, 64000)
+    waveform = (waveform * volume * 2**31).long()
 
     # file for stereo stft test
     test_dirpath, test_dir = test.common_utils.create_temp_assets_dir()
-    test_filepath = os.path.join(test_dirpath, "assets",
-                                 "steam-train-whistle-daniel_simon.mp3")
+    test_filepath = os.path.join(test_dirpath, 'assets',
+                                 'steam-train-whistle-daniel_simon.mp3')
 
-    def test_scale(self):
-
-        audio_orig = self.sig.clone()
-        result = transforms.Scale()(audio_orig)
-        self.assertTrue(result.min() >= -1. and result.max() <= 1.)
-
-        maxminmax = max(abs(audio_orig.min()), abs(audio_orig.max())).item()
-        result = transforms.Scale(factor=maxminmax)(audio_orig)
-
-        self.assertTrue((result.min() == -1. or result.max() == 1.) and
-                        result.min() >= -1. and result.max() <= 1.)
-
-        repr_test = transforms.Scale()
-        self.assertTrue(repr_test.__repr__())
+    def scale(self, waveform, factor=float(2**31)):
+        # scales a waveform by a factor
+        if not waveform.is_floating_point():
+            waveform = waveform.to(torch.get_default_dtype())
+        return waveform / factor
 
     def test_pad_trim(self):
-        audio_orig = self.sig.clone()
-        length_orig = audio_orig.size(0)
+        waveform = self.waveform.clone()
+        length_orig = waveform.size(1)
         length_new = int(length_orig * 1.2)
 
-        result = transforms.PadTrim(max_len=length_new, channels_first=False)(audio_orig)
-        self.assertEqual(result.size(0), length_new)
-
-        result = transforms.PadTrim(max_len=length_new, channels_first=True)(audio_orig.transpose(0, 1))
+        result = transforms.PadTrim(max_len=length_new)(waveform)
         self.assertEqual(result.size(1), length_new)
 
-        audio_orig = self.sig.clone()
-        length_orig = audio_orig.size(0)
         length_new = int(length_orig * 0.8)
-        result = transforms.PadTrim(max_len=length_new, channels_first=False)(audio_orig)
-
-        self.assertEqual(result.size(0), length_new)
-
-        repr_test = transforms.PadTrim(max_len=length_new, channels_first=False)
-        self.assertTrue(repr_test.__repr__())
-
-    def test_downmix_mono(self):
-
-        audio_L = self.sig.clone()
-        audio_R = self.sig.clone()
-        R_idx = int(audio_R.size(0) * 0.1)
-        audio_R = torch.cat((audio_R[R_idx:], audio_R[:R_idx]))
-
-        audio_Stereo = torch.cat((audio_L, audio_R), dim=1)
-
-        self.assertTrue(audio_Stereo.size(1) == 2)
-
-        result = transforms.DownmixMono(channels_first=False)(audio_Stereo)
-
-        self.assertTrue(result.size(1) == 1)
-
-        repr_test = transforms.DownmixMono(channels_first=False)
-        self.assertTrue(repr_test.__repr__())
-
-    def test_lc2cl(self):
-
-        audio = self.sig.clone()
-        result = transforms.LC2CL()(audio)
-        self.assertTrue(result.size()[::-1] == audio.size())
-
-        repr_test = transforms.LC2CL()
-        self.assertTrue(repr_test.__repr__())
-
-    def test_compose(self):
-
-        audio_orig = self.sig.clone()
-        length_orig = audio_orig.size(0)
-        length_new = int(length_orig * 1.2)
-        maxminmax = max(abs(audio_orig.min()), abs(audio_orig.max())).item()
-
-        tset = (transforms.Scale(factor=maxminmax),
-                transforms.PadTrim(max_len=length_new, channels_first=False))
-        result = transforms.Compose(tset)(audio_orig)
-
-        self.assertTrue(max(abs(result.min()), abs(result.max())) == 1.)
-
-        self.assertTrue(result.size(0) == length_new)
-
-        repr_test = transforms.Compose(tset)
-        self.assertTrue(repr_test.__repr__())
+        result = transforms.PadTrim(max_len=length_new)(waveform)
+        self.assertEqual(result.size(1), length_new)
 
     def test_mu_law_companding(self):
 
         quantization_channels = 256
 
-        sig = self.sig.clone()
-        sig = sig / torch.abs(sig).max()
-        self.assertTrue(sig.min() >= -1. and sig.max() <= 1.)
-
-        sig_mu = transforms.MuLawEncoding(quantization_channels)(sig)
-        self.assertTrue(sig_mu.min() >= 0. and sig.max() <= quantization_channels)
+        waveform = self.waveform.clone()
+        waveform /= torch.abs(waveform).max()
+        self.assertTrue(waveform.min() >= -1. and waveform.max() <= 1.)
 
-        sig_exp = transforms.MuLawExpanding(quantization_channels)(sig_mu)
-        self.assertTrue(sig_exp.min() >= -1. and sig_exp.max() <= 1.)
+        waveform_mu = transforms.MuLawEncoding(quantization_channels)(waveform)
+        self.assertTrue(waveform_mu.min() >= 0. and waveform_mu.max() <= quantization_channels)
 
-        repr_test = transforms.MuLawEncoding(quantization_channels)
-        self.assertTrue(repr_test.__repr__())
-        repr_test = transforms.MuLawExpanding(quantization_channels)
-        self.assertTrue(repr_test.__repr__())
+        waveform_exp = transforms.MuLawExpanding(quantization_channels)(waveform_mu)
+        self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.)
 
     def test_mel2(self):
         top_db = 80.
-        s2db = transforms.SpectrogramToDB("power", top_db)
+        s2db = transforms.SpectrogramToDB('power', top_db)
 
-        audio_orig = self.sig.clone()  # (16000, 1)
-        audio_scaled = transforms.Scale()(audio_orig)  # (16000, 1)
-        audio_scaled = transforms.LC2CL()(audio_scaled)  # (1, 16000)
+        waveform = self.waveform.clone()  # (1, 16000)
+        waveform_scaled = self.scale(waveform)  # (1, 16000)
 
         mel_transform = transforms.MelSpectrogram()
         # check defaults
-        spectrogram_torch = s2db(mel_transform(audio_scaled))  # (1, 319, 40)
+        spectrogram_torch = s2db(mel_transform(waveform_scaled))  # (1, 128, 321)
         self.assertTrue(spectrogram_torch.dim() == 3)
         self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
-        self.assertEqual(spectrogram_torch.size(-1), mel_transform.n_mels)
+        self.assertEqual(spectrogram_torch.size(1), mel_transform.n_mels)
         # check correctness of filterbank conversion matrix
-        self.assertTrue(mel_transform.fm.fb.sum(1).le(1.).all())
-        self.assertTrue(mel_transform.fm.fb.sum(1).ge(0.).all())
+        self.assertTrue(mel_transform.mel_scale.fb.sum(1).le(1.).all())
+        self.assertTrue(mel_transform.mel_scale.fb.sum(1).ge(0.).all())
         # check options
-        kwargs = {"window": torch.hamming_window, "pad": 10, "ws": 500, "hop": 125, "n_fft": 800, "n_mels": 50}
+        kwargs = {'window_fn': torch.hamming_window, 'pad': 10, 'win_length': 500,
+                  'hop_length': 125, 'n_fft': 800, 'n_mels': 50}
         mel_transform2 = transforms.MelSpectrogram(**kwargs)
-        spectrogram2_torch = s2db(mel_transform2(audio_scaled))  # (1, 506, 50)
+        spectrogram2_torch = s2db(mel_transform2(waveform_scaled))  # (1, 50, 513)
         self.assertTrue(spectrogram2_torch.dim() == 3)
         self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
-        self.assertEqual(spectrogram2_torch.size(-1), mel_transform2.n_mels)
-        self.assertTrue(mel_transform2.fm.fb.sum(1).le(1.).all())
-        self.assertTrue(mel_transform2.fm.fb.sum(1).ge(0.).all())
+        self.assertEqual(spectrogram2_torch.size(1), mel_transform2.n_mels)
+        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).le(1.).all())
+        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.).all())
         # check on multi-channel audio
-        x_stereo, sr_stereo = torchaudio.load(self.test_filepath)
-        spectrogram_stereo = s2db(mel_transform(x_stereo))
+        x_stereo, sr_stereo = torchaudio.load(self.test_filepath)  # (2, 278756), 44100
+        spectrogram_stereo = s2db(mel_transform(x_stereo))  # (2, 128, 1394)
         self.assertTrue(spectrogram_stereo.dim() == 3)
         self.assertTrue(spectrogram_stereo.size(0) == 2)
         self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
-        self.assertEqual(spectrogram_stereo.size(-1), mel_transform.n_mels)
+        self.assertEqual(spectrogram_stereo.size(1), mel_transform.n_mels)
         # check filterbank matrix creation
-        fb_matrix_transform = transforms.MelScale(n_mels=100, sr=16000, f_max=None, f_min=0., n_stft=400)
+        fb_matrix_transform = transforms.MelScale(
+            n_mels=100, sample_rate=16000, f_min=0., f_max=None, n_stft=400)
         self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all())
         self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
         self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))
 
     def test_mfcc(self):
-        audio_orig = self.sig.clone()
-        audio_scaled = transforms.Scale()(audio_orig)  # (16000, 1)
-        audio_scaled = transforms.LC2CL()(audio_scaled)  # (1, 16000)
+        audio_orig = self.waveform.clone()
+        audio_scaled = self.scale(audio_orig)  # (1, 16000)
 
         sample_rate = 16000
         n_mfcc = 40
         n_mels = 128
-        mfcc_transform = torchaudio.transforms.MFCC(sr=sample_rate,
+        mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                     n_mfcc=n_mfcc,
                                                     norm='ortho')
 
         # check defaults
-        torch_mfcc = mfcc_transform(audio_scaled)
+        torch_mfcc = mfcc_transform(audio_scaled)  # (1, 40, 321)
         self.assertTrue(torch_mfcc.dim() == 3)
-        self.assertTrue(torch_mfcc.shape[2] == n_mfcc)
-        self.assertTrue(torch_mfcc.shape[1] == 321)
+        self.assertTrue(torch_mfcc.shape[1] == n_mfcc)
+        self.assertTrue(torch_mfcc.shape[2] == 321)
 
         # check melkwargs are passed through
-        melkwargs = {'ws': 200}
-        mfcc_transform2 = torchaudio.transforms.MFCC(sr=sample_rate,
+        melkwargs = {'win_length': 200}
+        mfcc_transform2 = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                      n_mfcc=n_mfcc,
                                                      norm='ortho',
                                                      melkwargs=melkwargs)
-        torch_mfcc2 = mfcc_transform2(audio_scaled)
-        self.assertTrue(torch_mfcc2.shape[1] == 641)
+        torch_mfcc2 = mfcc_transform2(audio_scaled)  # (1, 40, 641)
+        self.assertTrue(torch_mfcc2.shape[2] == 641)
 
         # check norms work correctly
-        mfcc_transform_norm_none = torchaudio.transforms.MFCC(sr=sample_rate,
+        mfcc_transform_norm_none = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                               n_mfcc=n_mfcc,
                                                               norm=None)
-        torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled)
+        torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled)  # (1, 40, 321)
 
         norm_check = torch_mfcc.clone()
-        norm_check[:, :, 0] *= math.sqrt(n_mels) * 2
-        norm_check[:, :, 1:] *= math.sqrt(n_mels / 2) * 2
+        norm_check[:, 0, :] *= math.sqrt(n_mels) * 2
+        norm_check[:, 1:, :] *= math.sqrt(n_mels / 2) * 2
 
         self.assertTrue(torch_mfcc_norm_none.allclose(norm_check))
 
@@ -212,45 +144,45 @@ def test_librosa_consistency(self):
         def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
             input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
             sound, sample_rate = torchaudio.load(input_path)
-            sound_librosa = sound.cpu().numpy().squeeze().T  # squeeze batch and channel first
+            sound_librosa = sound.cpu().numpy().squeeze()  # (64000)
 
             # test core spectrogram
-            spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop=hop_length, power=2)
+            spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop_length=hop_length, power=2)
             out_librosa, _ = librosa.core.spectrum._spectrogram(y=sound_librosa, n_fft=n_fft,
                                                                 hop_length=hop_length, power=2)
-            out_torch = spect_transform(sound).squeeze().cpu().t()
+            out_torch = spect_transform(sound).squeeze().cpu()
             self.assertTrue(torch.allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5))
 
             # test mel spectrogram
-            melspect_transform = torchaudio.transforms.MelSpectrogram(sr=sample_rate, window=torch.hann_window,
-                                                                      hop=hop_length, n_mels=n_mels, n_fft=n_fft)
+            melspect_transform = torchaudio.transforms.MelSpectrogram(
+                sample_rate=sample_rate, window_fn=torch.hann_window,
+                hop_length=hop_length, n_mels=n_mels, n_fft=n_fft)
             librosa_mel = librosa.feature.melspectrogram(y=sound_librosa, sr=sample_rate,
                                                          n_fft=n_fft, hop_length=hop_length,
                                                          n_mels=n_mels, htk=True, norm=None)
             librosa_mel_tensor = torch.from_numpy(librosa_mel)
-            torch_mel = melspect_transform(sound).squeeze().cpu().t()
+            torch_mel = melspect_transform(sound).squeeze().cpu()
             self.assertTrue(torch.allclose(torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3))
 
             # test s2db
-
-            db_transform = torchaudio.transforms.SpectrogramToDB("power", 80.)
-            db_torch = db_transform(spect_transform(sound)).squeeze().cpu().t()
+            db_transform = torchaudio.transforms.SpectrogramToDB('power', 80.)
+            db_torch = db_transform(spect_transform(sound)).squeeze().cpu()
             db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
             self.assertTrue(torch.allclose(db_torch, torch.from_numpy(db_librosa), atol=5e-3))
 
-            db_torch = db_transform(melspect_transform(sound)).squeeze().cpu().t()
+            db_torch = db_transform(melspect_transform(sound)).squeeze().cpu()
             db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)
             db_librosa_tensor = torch.from_numpy(db_librosa)
             self.assertTrue(torch.allclose(db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3))
 
             # test MFCC
-            melkwargs = {'hop': hop_length, 'n_fft': n_fft}
-            mfcc_transform = torchaudio.transforms.MFCC(sr=sample_rate,
+            melkwargs = {'hop_length': hop_length, 'n_fft': n_fft}
+            mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                         n_mfcc=n_mfcc,
                                                         norm='ortho',
                                                         melkwargs=melkwargs)
@@ -271,7 +203,7 @@ def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, s
             librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc]
             librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc)
-            torch_mfcc = mfcc_transform(sound).squeeze().cpu().t()
+            torch_mfcc = mfcc_transform(sound).squeeze().cpu()
 
             self.assertTrue(torch.allclose(torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3))
 
@@ -308,27 +240,27 @@ def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, s
     def test_resample_size(self):
         input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
-        sound, sample_rate = torchaudio.load(input_path)
+        waveform, sample_rate = torchaudio.load(input_path)
 
         upsample_rate = sample_rate * 2
         downsample_rate = sample_rate // 2
         invalid_resample = torchaudio.transforms.Resample(sample_rate, upsample_rate, resampling_method='foo')
-        self.assertRaises(ValueError, invalid_resample, sound)
+        self.assertRaises(ValueError, invalid_resample, waveform)
 
         upsample_resample = torchaudio.transforms.Resample(
             sample_rate, upsample_rate, resampling_method='sinc_interpolation')
-        up_sampled = upsample_resample(sound)
+        up_sampled = upsample_resample(waveform)
 
         # we expect the upsampled signal to have twice as many samples
-        self.assertTrue(up_sampled.size(-1) == sound.size(-1) * 2)
+        self.assertTrue(up_sampled.size(-1) == waveform.size(-1) * 2)
 
         downsample_resample = torchaudio.transforms.Resample(
             sample_rate, downsample_rate, resampling_method='sinc_interpolation')
-        down_sampled = downsample_resample(sound)
+        down_sampled = downsample_resample(waveform)
 
         # we expect the downsampled signal to have half as many samples
-        self.assertTrue(down_sampled.size(-1) == sound.size(-1) // 2)
+        self.assertTrue(down_sampled.size(-1) == waveform.size(-1) // 2)
 
 
 if __name__ == '__main__':
     unittest.main()
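[Editor's note — not part of the patch] With transform outputs now laid out as
(channel, frequency, time), the librosa comparisons above no longer transpose. A condensed
version of that check, using only calls exercised by the tests (file path illustrative):

    import torch
    import torchaudio
    import librosa

    waveform, sample_rate = torchaudio.load('sinewave.wav')   # (1, n)
    mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate, n_fft=400, hop_length=200, n_mels=128)(waveform)
    ref = librosa.feature.melspectrogram(y=waveform.squeeze().numpy(), sr=sample_rate,
                                         n_fft=400, hop_length=200, n_mels=128,
                                         htk=True, norm=None)
    ref_tensor = torch.from_numpy(ref)
    assert torch.allclose(mel.squeeze().type(ref_tensor.dtype), ref_tensor, atol=5e-3)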
diff --git a/torchaudio/functional.py b/torchaudio/functional.py
index 301a3ebf86..065de43a5e 100644
--- a/torchaudio/functional.py
+++ b/torchaudio/functional.py
@@ -3,109 +3,48 @@
 __all__ = [
-    'scale',
     'pad_trim',
-    'downmix_mono',
-    'LC2CL',
     'istft',
     'spectrogram',
     'create_fb_matrix',
     'spectrogram_to_DB',
     'create_dct',
-    'BLC2CBL',
     'mu_law_encoding',
-    'mu_law_expanding'
+    'mu_law_expanding',
+    'complex_norm',
+    'angle',
+    'magphase',
+    'phase_vocoder',
 ]
 
 
 @torch.jit.script
-def scale(tensor, factor):
-    # type: (Tensor, int) -> Tensor
-    r"""Scale audio tensor from a 16-bit integer (represented as a
-    :class:`torch.FloatTensor`) to a floating point number between -1.0 and 1.0.
-    Note the 16-bit number is called the "bit depth" or "precision", not to be
-    confused with "bit rate".
-
-    Args:
-        tensor (torch.Tensor): Tensor of audio of size (n, c) or (c, n)
-        factor (int): Maximum value of input tensor
-
-    Returns:
-        torch.Tensor: Scaled by the scale factor
-    """
-    if not tensor.is_floating_point():
-        tensor = tensor.to(torch.float32)
-
-    return tensor / factor
-
-
-@torch.jit.script
-def pad_trim(tensor, ch_dim, max_len, len_dim, fill_value):
-    # type: (Tensor, int, int, int, float) -> Tensor
-    r"""Pad/trim a 2D tensor (signal or labels).
+def pad_trim(waveform, max_len, fill_value):
+    # type: (Tensor, int, float) -> Tensor
+    r"""Pad/trim a 2D tensor
 
     Args:
-        tensor (torch.Tensor): Tensor of audio of size (n, c) or (c, n)
-        ch_dim (int): Dimension of channel (not size)
-        max_len (int): Length to which the tensor will be padded
-        len_dim (int): Dimension of length (not size)
+        waveform (torch.Tensor): Tensor of audio of size (c, n)
+        max_len (int): Length to which the waveform will be padded
         fill_value (float): Value to fill in
 
     Returns:
         torch.Tensor: Padded/trimmed tensor
     """
-    if max_len > tensor.size(len_dim):
-        # array of [padding_left, padding_right, padding_top, padding_bottom]
-        # so pad similar to append (aka only right/bottom) and do not pad
-        # the length dimension. assumes equal sizes of padding.
-        padding = [max_len - tensor.size(len_dim)
-                   if (i % 2 == 1) and (i // 2 != len_dim)
-                   else 0
-                   for i in [0, 1, 2, 3]]
+    n = waveform.size(1)
+    if max_len > n:
         # TODO add "with torch.no_grad():" back when JIT supports it
-        tensor = torch.nn.functional.pad(tensor, padding, "constant", fill_value)
-    elif max_len < tensor.size(len_dim):
-        tensor = tensor.narrow(len_dim, 0, max_len)
-    return tensor
-
-
-@torch.jit.script
-def downmix_mono(tensor, ch_dim):
-    # type: (Tensor, int) -> Tensor
-    r"""Downmix any stereo signals to mono.
-
-    Args:
-        tensor (torch.Tensor): Tensor of audio of size (c, n) or (n, c)
-        ch_dim (int): Dimension of channel (not size)
-
-    Returns:
-        torch.Tensor: Mono signal
-    """
-    if not tensor.is_floating_point():
-        tensor = tensor.to(torch.float32)
-
-    tensor = torch.mean(tensor, ch_dim, True)
-    return tensor
-
-
-@torch.jit.script
-def LC2CL(tensor):
-    # type: (Tensor) -> Tensor
-    r"""Permute a 2D tensor from samples (n, c) to (c, n).
-
-    Args:
-        tensor (torch.Tensor): Tensor of audio signal with shape (n, c)
+        waveform = torch.nn.functional.pad(waveform, (0, max_len - n), 'constant', fill_value)
+    else:
+        waveform = waveform[:, :max_len]
+    return waveform
 
-    Returns:
-        torch.Tensor: Tensor of audio signal with shape (c, n)
-    """
-    return tensor.transpose(0, 1).contiguous()
 
 # TODO: remove this once https://github.com/pytorch/pytorch/issues/21478 gets solved
 @torch.jit.ignore
-def _stft(input, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided):
+def _stft(waveform, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided):
     # type: (Tensor, int, Optional[int], Optional[int], Optional[Tensor], bool, str, bool, bool) -> Tensor
-    return torch.stft(input, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided)
+    return torch.stft(waveform, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided)
 
 
 def istft(stft_matrix,          # type: Tensor
@@ -149,8 +88,8 @@ def istft(stft_matrix,          # type: Tensor
            IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984.
 
     Args:
-        stft_matrix (torch.Tensor): Output of stft where each row of a batch is a frequency and each
-            column is a window. it has a shape of either (batch, fft_size, n_frames, 2) or (
+        stft_matrix (torch.Tensor): Output of stft where each row of a channel is a frequency and each
+            column is a window. It has a shape of either (channel, fft_size, n_frames, 2) or (
             fft_size, n_frames, 2)
         n_fft (int): Size of Fourier transform
         hop_length (Optional[int]): The distance between neighboring sliding window frames.
@@ -168,20 +107,20 @@ def istft(stft_matrix,          # type: Tensor
 
     Returns:
         torch.Tensor: Least squares estimation of the original signal of size
-        (batch, signal_length) or (signal_length)
+        (channel, signal_length) or (signal_length)
     """
     stft_matrix_dim = stft_matrix.dim()
     assert 3 <= stft_matrix_dim <= 4, ('Incorrect stft dimension: %d' % (stft_matrix_dim))
 
     if stft_matrix_dim == 3:
-        # add a batch dimension
+        # add a channel dimension
         stft_matrix = stft_matrix.unsqueeze(0)
 
     device = stft_matrix.device
     fft_size = stft_matrix.size(1)
     assert (onesided and n_fft // 2 + 1 == fft_size) or (not onesided and n_fft == fft_size), (
-        'one_sided implies that n_fft // 2 + 1 == fft_size and not one_sided implies n_fft == fft_size. '
-        + 'Given values were onesided: %s, n_fft: %d, fft_size: %d' % ('True' if onesided else False, n_fft, fft_size))
+        'one_sided implies that n_fft // 2 + 1 == fft_size and not one_sided implies n_fft == fft_size. ' +
+        'Given values were onesided: %s, n_fft: %d, fft_size: %d' % ('True' if onesided else False, n_fft, fft_size))
 
     # use stft defaults for Optionals
     if win_length is None:
@@ -206,16 +145,16 @@ def istft(stft_matrix,          # type: Tensor
     assert window.size(0) == n_fft
 
     # win_length and n_fft are synonymous from here on
-    stft_matrix = stft_matrix.transpose(1, 2)  # size (batch, n_frames, fft_size, 2)
+    stft_matrix = stft_matrix.transpose(1, 2)  # size (channel, n_frames, fft_size, 2)
     stft_matrix = torch.irfft(stft_matrix, 1, normalized,
-                              onesided, signal_sizes=(n_fft,))  # size (batch, n_frames, n_fft)
+                              onesided, signal_sizes=(n_fft,))  # size (channel, n_frames, n_fft)
     assert stft_matrix.size(2) == n_fft
     n_frames = stft_matrix.size(1)
 
-    ytmp = stft_matrix * window.view(1, 1, n_fft)  # size (batch, n_frames, n_fft)
-    # each column of a batch is a frame which needs to be overlap added at the right place
-    ytmp = ytmp.transpose(1, 2)  # size (batch, n_fft, n_frames)
+    ytmp = stft_matrix * window.view(1, 1, n_fft)  # size (channel, n_frames, n_fft)
+    # each column of a channel is a frame which needs to be overlap added at the right place
+    ytmp = ytmp.transpose(1, 2)  # size (channel, n_fft, n_frames)
 
     eye = torch.eye(n_fft, requires_grad=False,
                     device=device).unsqueeze(1)  # size (n_fft, 1, n_fft)
@@ -223,7 +162,7 @@ def istft(stft_matrix,          # type: Tensor
     # this does overlap add where the frames of ytmp are added such that the i'th frame of
     # ytmp is added starting at i*hop_length in the output
     y = torch.nn.functional.conv_transpose1d(
-        ytmp, eye, stride=hop_length, padding=0)  # size (batch, 1, expected_signal_len)
+        ytmp, eye, stride=hop_length, padding=0)  # size (channel, 1, expected_signal_len)
 
     # do the same for the window function
     window_sq = window.pow(2).view(n_fft, 1).repeat((1, n_frames)).unsqueeze(0)  # size (1, n_fft, n_frames)
@@ -246,67 +185,70 @@ def istft(stft_matrix,          # type: Tensor
     window_envelop_lowest = window_envelop.abs().min()
     assert window_envelop_lowest > 1e-11, ('window overlap add min: %f' % (window_envelop_lowest))
 
-    y = (y / window_envelop).squeeze(1)  # size (batch, expected_signal_len)
+    y = (y / window_envelop).squeeze(1)  # size (channel, expected_signal_len)
 
-    if stft_matrix_dim == 3:  # remove the batch dimension
+    if stft_matrix_dim == 3:  # remove the channel dimension
         y = y.squeeze(0)
     return y
 
 
 @torch.jit.script
-def spectrogram(sig, pad, window, n_fft, hop, ws, power, normalize):
+def spectrogram(waveform, pad, window, n_fft, hop_length, win_length, power, normalized):
     # type: (Tensor, int, Tensor, int, int, int, int, bool) -> Tensor
     r"""Create a spectrogram from a raw audio signal.
 
     Args:
-        sig (torch.Tensor): Tensor of audio of size (c, n)
+        waveform (torch.Tensor): Tensor of audio of size (c, n)
         pad (int): Two sided padding of signal
-        window (torch.Tensor): Window_tensor
+        window (torch.Tensor): Window tensor that is applied/multiplied to each frame/window
         n_fft (int): Size of fft
-        hop (int): Length of hop between STFT windows
-        ws (int): Window size
-        power (int) : Exponent for the magnitude spectrogram,
+        hop_length (int): Length of hop between STFT windows
+        win_length (int): Window size
+        power (int): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc.
-        normalize (bool) : Whether to normalize by magnitude after stft
+        normalized (bool): Whether to normalize by magnitude after stft
 
     Returns:
-        torch.Tensor: Channels x hops x n_fft (c, l, f), where channels
-        is unchanged, hops is the number of hops, and n_fft is the
-        number of fourier bins, which should be the window size divided
-        by 2 plus 1.
+        torch.Tensor: Channels x frequency x time (c, f, t), where channels
+        is unchanged, frequency is `n_fft // 2 + 1` where `n_fft` is the number of
+        fourier bins, and time is the number of window hops (n_frames).
     """
-    assert sig.dim() == 2
+    assert waveform.dim() == 2
 
     if pad > 0:
         # TODO add "with torch.no_grad():" back when JIT supports it
-        sig = torch.nn.functional.pad(sig, (pad, pad), "constant")
+        waveform = torch.nn.functional.pad(waveform, (pad, pad), "constant")
 
     # default values are consistent with librosa.core.spectrum._spectrogram
-    spec_f = _stft(sig, n_fft, hop, ws, window,
-                   True, 'reflect', False, True).transpose(1, 2)
+    spec_f = _stft(waveform, n_fft, hop_length, win_length, window,
+                   True, 'reflect', False, True)
 
-    if normalize:
+    if normalized:
         spec_f /= window.pow(2).sum().sqrt()
-    spec_f = spec_f.pow(power).sum(-1)  # get power of "complex" tensor (c, l, n_fft)
+    spec_f = spec_f.pow(power).sum(-1)  # get power of "complex" tensor
     return spec_f
 
 
 @torch.jit.script
-def create_fb_matrix(n_stft, f_min, f_max, n_mels):
+def create_fb_matrix(n_freqs, f_min, f_max, n_mels):
     # type: (int, float, float, int) -> Tensor
     r""" Create a frequency bin conversion matrix.
 
     Args:
-        n_stft (int): Number of filter banks from spectrogram
+        n_freqs (int): Number of frequencies to highlight/apply
         f_min (float): Minimum frequency
         f_max (float): Maximum frequency
-        n_mels (int): Number of mel bins
+        n_mels (int): Number of mel filterbanks
 
     Returns:
-        torch.Tensor: Triangular filter banks (fb matrix)
+        torch.Tensor: Triangular filter banks (fb matrix) of size (`n_freqs`, `n_mels`),
+        i.e. the number of frequencies to highlight/apply by the number of filterbanks.
+        Each column is a filterbank so that assuming there is a matrix A of
+        size (..., `n_freqs`), the applied result would be
+        `A * create_fb_matrix(A.size(-1), ...)`.
    """
-    # get stft freq bins
-    stft_freqs = torch.linspace(f_min, f_max, n_stft)
+    # freq bins
+    freqs = torch.linspace(f_min, f_max, n_freqs)
     # calculate mel freq bins
     # hertz to mel(f) is 2595. * math.log10(1. + (f / 700.))
     m_min = 0. if f_min == 0 else 2595. * math.log10(1. + (f_min / 700.))
@@ -316,17 +258,17 @@ def create_fb_matrix(n_stft, f_min, f_max, n_mels):
     f_pts = 700. * (10**(m_pts / 2595.) - 1.)
     # calculate the difference between each mel point and each stft freq point in hertz
     f_diff = f_pts[1:] - f_pts[:-1]  # (n_mels + 1)
-    slopes = f_pts.unsqueeze(0) - stft_freqs.unsqueeze(1)  # (n_stft, n_mels + 2)
+    slopes = f_pts.unsqueeze(0) - freqs.unsqueeze(1)  # (n_freqs, n_mels + 2)
     # create overlapping triangles
-    z = torch.zeros(1)
-    down_slopes = (-1. * slopes[:, :-2]) / f_diff[:-1]  # (n_stft, n_mels)
-    up_slopes = slopes[:, 2:] / f_diff[1:]  # (n_stft, n_mels)
-    fb = torch.max(z, torch.min(down_slopes, up_slopes))
+    zero = torch.zeros(1)
+    down_slopes = (-1. * slopes[:, :-2]) / f_diff[:-1]  # (n_freqs, n_mels)
+    up_slopes = slopes[:, 2:] / f_diff[1:]  # (n_freqs, n_mels)
+    fb = torch.max(zero, torch.min(down_slopes, up_slopes))
     return fb
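[Editor's note — not part of the patch] The renamed create_fb_matrix returns an
(n_freqs, n_mels) matrix applied along the frequency axis, which is exactly how the new
MelScale.forward uses it. A brief sketch (shapes illustrative):

    import torch
    import torchaudio.functional as F

    specgram = torch.rand(1, 201, 100)            # (channel, n_freqs, time)
    fb = F.create_fb_matrix(201, 0., 8000., 128)  # (n_freqs, n_mels)
    mel = torch.matmul(specgram.transpose(1, 2), fb).transpose(1, 2)  # (1, 128, 100)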
 
 
 @torch.jit.script
-def spectrogram_to_DB(spec, multiplier, amin, db_multiplier, top_db=None):
+def spectrogram_to_DB(specgram, multiplier, amin, db_multiplier, top_db=None):
     # type: (Tensor, float, float, float, Optional[float]) -> Tensor
     r"""Turns a spectrogram from the power/amplitude scale to the decibel scale.
@@ -335,72 +277,57 @@ def spectrogram_to_DB(spec, multiplier, amin, db_multiplier, top_db=None):
 
     This output depends on the maximum value in the input spectrogram, and so
     may return different values for an audio clip split into snippets vs. a
     full clip.
 
     Args:
-        spec (torch.Tensor): Normal STFT
+        specgram (torch.Tensor): Normal STFT of size (c, f, t)
         multiplier (float): Use 10. for power and 20. for amplitude
-        amin (float): Number to clamp spec
+        amin (float): Number to clamp specgram
         db_multiplier (float): Log10(max(reference value and amin))
-        top_db (Optional[float]): Minimum negative cut-off in decibels. A reasonable number
+        top_db (Optional[float]): Minimum negative cut-off in decibels. A reasonable number
             is 80.
 
     Returns:
-        torch.Tensor: Spectrogram in DB
+        torch.Tensor: Spectrogram in DB of size (c, f, t)
     """
-    spec_db = multiplier * torch.log10(torch.clamp(spec, min=amin))
-    spec_db -= multiplier * db_multiplier
+    specgram_db = multiplier * torch.log10(torch.clamp(specgram, min=amin))
+    specgram_db -= multiplier * db_multiplier
 
     if top_db is not None:
-        new_spec_db_max = torch.tensor(float(spec_db.max()) - top_db, dtype=spec_db.dtype, device=spec_db.device)
-        spec_db = torch.max(spec_db, new_spec_db_max)
+        new_spec_db_max = torch.tensor(float(specgram_db.max()) - top_db,
+                                       dtype=specgram_db.dtype, device=specgram_db.device)
+        specgram_db = torch.max(specgram_db, new_spec_db_max)
 
-    return spec_db
+    return specgram_db
 
 
 @torch.jit.script
 def create_dct(n_mfcc, n_mels, norm):
     # type: (int, int, Optional[str]) -> Tensor
-    r"""Creates a DCT transformation matrix with shape (num_mels, num_mfcc),
+    r"""Creates a DCT transformation matrix with shape (`n_mels`, `n_mfcc`),
     normalized depending on norm.
 
     Args:
-        n_mfcc (int) : Number of mfc coefficients to retain
-        n_mels (int): Number of MEL bins
-        norm (Optional[str]) : Norm to use (either 'ortho' or None)
+        n_mfcc (int): Number of mfc coefficients to retain
+        n_mels (int): Number of mel filterbanks
+        norm (Optional[str]): Norm to use (either 'ortho' or None)
 
     Returns:
-        torch.Tensor: The transformation matrix, to be right-multiplied to row-wise data.
+        torch.Tensor: The transformation matrix, to be right-multiplied to
+        row-wise data of size (`n_mels`, `n_mfcc`).
     """
-    outdim = n_mfcc
-    dim = n_mels
     # http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
-    n = torch.arange(dim)
-    k = torch.arange(outdim)[:, None]
-    dct = torch.cos(math.pi / float(dim) * (n + 0.5) * k)
+    n = torch.arange(float(n_mels))
+    k = torch.arange(float(n_mfcc)).unsqueeze(1)
+    dct = torch.cos(math.pi / float(n_mels) * (n + 0.5) * k)  # size (n_mfcc, n_mels)
     if norm is None:
         dct *= 2.0
     else:
         assert norm == 'ortho'
         dct[0] *= 1.0 / math.sqrt(2.0)
-        dct *= math.sqrt(2.0 / float(dim))
+        dct *= math.sqrt(2.0 / float(n_mels))
     return dct.t()
 
 
 @torch.jit.script
-def BLC2CBL(tensor):
-    # type: (Tensor) -> Tensor
-    r"""Permute a 3D tensor from Bands x Sample length x Channels to Channels x
-    Bands x Samples length.
-
-    Args:
-        tensor (torch.Tensor): Tensor of spectrogram with shape (b, l, c)
-
-    Returns:
-        torch.Tensor: Tensor of spectrogram with shape (c, b, l)
-    """
-    return tensor.permute(2, 0, 1).contiguous()
-
-
-@torch.jit.script
-def mu_law_encoding(x, qc):
+def mu_law_encoding(x, quantization_channels):
     # type: (Tensor, int) -> Tensor
     r"""Encode signal based on mu-law companding.  For more info see the
     `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
@@ -410,13 +337,12 @@ def mu_law_encoding(x, qc):
 
     Args:
         x (torch.Tensor): Input tensor
-        qc (int): Number of channels (i.e. quantization channels)
+        quantization_channels (int): Number of channels
 
     Returns:
         torch.Tensor: Input after mu-law companding
     """
-    assert isinstance(x, torch.Tensor), 'mu_law_encoding expects a Tensor'
-    mu = qc - 1.
+    mu = quantization_channels - 1.
     if not x.is_floating_point():
         x = x.to(torch.float)
     mu = torch.tensor(mu, dtype=x.dtype)
@@ -427,7 +353,7 @@
 
 @torch.jit.script
-def mu_law_expanding(x_mu, qc):
+def mu_law_expanding(x_mu, quantization_channels):
     # type: (Tensor, int) -> Tensor
     r"""Decode mu-law encoded signal.  For more info see the
     `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
@@ -437,13 +363,12 @@ def mu_law_expanding(x_mu, qc):
 
     Args:
         x_mu (torch.Tensor): Input tensor
-        qc (int): Number of channels (i.e. quantization channels)
+        quantization_channels (int): Number of channels
 
     Returns:
         torch.Tensor: Input after decoding
     """
-    assert isinstance(x_mu, torch.Tensor), 'mu_law_expanding expects a Tensor'
-    mu = qc - 1.
+    mu = quantization_channels - 1.
     if not x_mu.is_floating_point():
         x_mu = x_mu.to(torch.float)
     mu = torch.tensor(mu, dtype=x_mu.dtype)
@@ -452,71 +377,15 @@
     return x
 
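[Editor's note — not part of the patch] The two functions above are inverses up to
quantization error, which is what test_mu_law_companding checks. A small round-trip
sketch under the stated assumption that the input is scaled to [-1, 1]:

    import torch
    import torchaudio.functional as F

    waveform = torch.rand(1, 100) * 2 - 1           # scaled to [-1, 1]
    encoded = F.mu_law_encoding(waveform, 256)      # integer bins in [0, 256)
    decoded = F.mu_law_expanding(encoded, 256)      # back to [-1, 1], approximately
    assert (waveform - decoded).abs().max() < 0.1   # small quantization error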
-def stft(waveforms, fft_length, hop_length=None, win_length=None, window=None,
-         center=True, pad_mode='reflect', normalized=False, onesided=True):
-    """Compute a short time Fourier transform of the input waveform(s).
-    It wraps `torch.stft` after reshaping the input audio to allow for `waveforms` that `.dim()` >= 3.
-    It follows most of the `torch.stft` default values, but for `window`, which defaults to hann window.
-
-    Args:
-        waveforms (torch.Tensor): Audio signal of size `(*, channel, time)`
-        fft_length (int): FFT size [sample].
-        hop_length (int): Hop size [sample] between STFT frames.
-            (Defaults to `fft_length // 4`, 75%-overlapping windows by `torch.stft`).
-        win_length (int): Size of STFT window. (Defaults to `fft_length` by `torch.stft`).
-        window (torch.Tensor): window function. (Defaults to Hann Window of size `win_length` *unlike* `torch.stft`).
-        center (bool): Whether to pad `waveforms` on both sides so that the `t`-th frame is centered
-            at time `t * hop_length`. (Defaults to `True` by `torch.stft`)
-        pad_mode (str): padding method (see `torch.nn.functional.pad`). (Defaults to `'reflect'` by `torch.stft`).
-        normalized (bool): Whether the results are normalized. (Defaults to `False` by `torch.stft`).
-        onesided (bool): Whether the half + 1 frequency bins are returned to removethe symmetric part of STFT
-            of real-valued signal. (Defaults to `True` by `torch.stft`).
-
-    Returns:
-        torch.Tensor: `(*, channel, num_freqs, time, complex=2)`
-
-    Example:
-        >>> waveforms = torch.randn(16, 2, 10000)  # (batch, channel, time)
-        >>> x = stft(waveforms, 2048, 512)
-        >>> x.shape
-        torch.Size([16, 2, 1025, 20])
-    """
-    leading_dims = waveforms.shape[:-1]
-
-    waveforms = waveforms.reshape(-1, waveforms.size(-1))
-
-    if window is None:
-        if win_length is None:
-            window = torch.hann_window(fft_length)
-        else:
-            window = torch.hann_window(win_length)
-
-    complex_specgrams = torch.stft(waveforms,
-                                   n_fft=fft_length,
-                                   hop_length=hop_length,
-                                   win_length=win_length,
-                                   window=window,
-                                   center=center,
-                                   pad_mode=pad_mode,
-                                   normalized=normalized,
-                                   onesided=onesided)
-
-    complex_specgrams = complex_specgrams.reshape(
-        leading_dims +
-        complex_specgrams.shape[1:])
-
-    return complex_specgrams
-
-
 def complex_norm(complex_tensor, power=1.0):
-    """Compute the norm of complex tensor input
+    r"""Compute the norm of complex tensor input.
 
     Args:
-        complex_tensor (Tensor): Tensor shape of `(*, complex=2)`
-        power (float): Power of the norm. Defaults to `1.0`.
+        complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)`
+        power (float): Power of the norm. (Default: `1.0`).
 
     Returns:
-        Tensor: power of the normed input tensor, shape of `(*, )`
+        torch.Tensor: Power of the normed input tensor. Shape of `(*, )`
     """
     if power == 1.0:
         return torch.norm(complex_tensor, 2, -1)
@@ -524,16 +393,26 @@
 
 def angle(complex_tensor):
-    """
-    Return angle of a complex tensor with shape (*, 2).
+    r"""Compute the angle of complex tensor input.
+
+    Args:
+        complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)`
+
+    Return:
+        torch.Tensor: Angle of a complex tensor. Shape of `(*, )`
     """
     return torch.atan2(complex_tensor[..., 1], complex_tensor[..., 0])
 
 
 def magphase(complex_tensor, power=1.):
-    """
-    Separate a complex-valued spectrogram with shape (*,2)
-    into its magnitude and phase.
+    r"""Separate a complex-valued spectrogram with shape (*,2) into its magnitude and phase.
+
+    Args:
+        complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)`
+        power (float): Power of the norm. (Default: `1.0`)
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The magnitude and phase of the complex_tensor
     """
     mag = complex_norm(complex_tensor, power)
     phase = angle(complex_tensor)
@@ -541,20 +420,16 @@
 
 def phase_vocoder(complex_specgrams, rate, phase_advance):
-    """
-    Phase vocoder. Given a STFT tensor, speed up in time
-    without modifying pitch by a factor of `rate`.
+    r"""Given a STFT tensor, speed up in time without modifying pitch by a
+    factor of `rate`.
 
     Args:
-        complex_specgrams (Tensor):
-            (*, channel, num_freqs, time, complex=2)
-        rate (float): Speed-up factor.
-        phase_advance (Tensor): Expected phase advance in
-            each bin. (num_freqs, 1).
+        complex_specgrams (torch.Tensor): Size of (*, c, f, t, complex=2)
+        rate (float): Speed-up factor
+        phase_advance (torch.Tensor): Expected phase advance in each bin. Size of (f, 1)
 
     Returns:
-        complex_specgrams_stretch (Tensor):
-            (*, channel, num_freqs, ceil(time/rate), complex=2).
+        complex_specgrams_stretch (torch.Tensor): Size of (*, c, f, ceil(t/rate), complex=2)
 
     Example:
         >>> num_freqs, hop_length = 1025, 512
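[Editor's note — not part of the patch] complex_norm, angle and magphase all consume the
(*, complex=2) layout documented above, and magphase is just the pair of the other two.
A brief consistency sketch:

    import torch
    import torchaudio.functional as F

    complex_specgram = torch.randn(1, 201, 100, 2)  # (..., complex=2) real/imag pairs
    mag, phase = F.magphase(complex_specgram)       # both (1, 201, 100)
    assert torch.allclose(mag, F.complex_norm(complex_specgram))
    assert torch.allclose(phase, F.angle(complex_specgram))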
diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py
index 9176540910..e1c821ea2f 100644
--- a/torchaudio/transforms.py
+++ b/torchaudio/transforms.py
@@ -7,314 +7,205 @@
 from .compliance import kaldi
 
 
-# TODO remove this class
-class Compose(object):
-    """Composes several transforms together.
-
-    Args:
-        transforms (list of ``Transform`` objects): list of transforms to compose.
-
-    Example:
-        >>> transforms.Compose([
-        >>>     transforms.Scale(),
-        >>>     transforms.PadTrim(max_len=16000),
-        >>> ])
-    """
-
-    def __init__(self, transforms):
-        self.transforms = transforms
-
-    def __call__(self, audio):
-        for t in self.transforms:
-            audio = t(audio)
-        return audio
-
-    def __repr__(self):
-        format_string = self.__class__.__name__ + '('
-        for t in self.transforms:
-            format_string += '\n'
-            format_string += '    {0}'.format(t)
-        format_string += '\n)'
-        return format_string
-
-
-class Scale(torch.jit.ScriptModule):
-    """Scale audio tensor from a 16-bit integer (represented as a FloatTensor)
-    to a floating point number between -1.0 and 1.0.  Note the 16-bit number is
-    called the "bit depth" or "precision", not to be confused with "bit rate".
-
-    Args:
-        factor (int): maximum value of input tensor. default: 16-bit depth
-
-    """
-    __constants__ = ['factor']
-
-    def __init__(self, factor=2**31):
-        super(Scale, self).__init__()
-        self.factor = factor
-
-    @torch.jit.script_method
-    def forward(self, tensor):
-        """
-
-        Args:
-            tensor (Tensor): Tensor of audio of size (Samples x Channels)
-
-        Returns:
-            Tensor: Scaled by the scale factor. (default between -1.0 and 1.0)
-
-        """
-        return F.scale(tensor, self.factor)
-
-    def __repr__(self):
-        return self.__class__.__name__ + '()'
-
-
 class PadTrim(torch.jit.ScriptModule):
-    """Pad/Trim a 2d-Tensor (Signal or Labels)
+    r"""Pad/Trim a 2D tensor
 
     Args:
-        tensor (Tensor): Tensor of audio of size (n x c) or (c x n)
-        max_len (int): Length to which the tensor will be padded
-        channels_first (bool): Pad for channels first tensors.  Default: `True`
-
+        max_len (int): Length to which the waveform will be padded
+        fill_value (float): Value to fill in
     """
-    __constants__ = ['max_len', 'fill_value', 'len_dim', 'ch_dim']
+    __constants__ = ['max_len', 'fill_value']
 
-    def __init__(self, max_len, fill_value=0., channels_first=True):
+    def __init__(self, max_len, fill_value=0.):
         super(PadTrim, self).__init__()
         self.max_len = max_len
         self.fill_value = fill_value
-        self.len_dim, self.ch_dim = int(channels_first), int(not channels_first)
-
-    @torch.jit.script_method
-    def forward(self, tensor):
-        """
-
-        Returns:
-            Tensor: (c x n) or (n x c)
-
-        """
-        return F.pad_trim(tensor, self.ch_dim, self.max_len, self.len_dim, self.fill_value)
-
-    def __repr__(self):
-        return self.__class__.__name__ + '(max_len={0})'.format(self.max_len)
-
-
-class DownmixMono(torch.jit.ScriptModule):
-    """Downmix any stereo signals to mono.  Consider using a `SoxEffectsChain` with
-    the `channels` effect instead of this transformation.
-
-    Inputs:
-        tensor (Tensor): Tensor of audio of size (c x n) or (n x c)
-        channels_first (bool): Downmix across channels dimension.  Default: `True`
-
-    Returns:
-        tensor (Tensor) (Samples x 1):
-
-    """
-    __constants__ = ['ch_dim']
-
-    def __init__(self, channels_first=None):
-        super(DownmixMono, self).__init__()
-        self.ch_dim = int(not channels_first)
 
     @torch.jit.script_method
-    def forward(self, tensor):
-        return F.downmix_mono(tensor, self.ch_dim)
-
-    def __repr__(self):
-        return self.__class__.__name__ + '()'
-
-
-class LC2CL(torch.jit.ScriptModule):
-    """Permute a 2d tensor from samples (n x c) to (c x n)
-    """
-
-    def __init__(self):
-        super(LC2CL, self).__init__()
-
-    @torch.jit.script_method
-    def forward(self, tensor):
-        """
-
+    def forward(self, waveform):
+        r"""
         Args:
-            tensor (Tensor): Tensor of audio signal with shape (LxC)
+            waveform (torch.Tensor): Tensor of audio of size (c, n)
 
         Returns:
-            tensor (Tensor): Tensor of audio signal with shape (CxL)
+            Tensor: Tensor of size (c, `max_len`)
         """
-        return F.LC2CL(tensor)
-
-    def __repr__(self):
-        return self.__class__.__name__ + '()'
-
-
-def SPECTROGRAM(*args, **kwargs):
-    warn("SPECTROGRAM has been renamed to Spectrogram")
-    return Spectrogram(*args, **kwargs)
+        return F.pad_trim(waveform, self.max_len, self.fill_value)
 
 
 class Spectrogram(torch.jit.ScriptModule):
-    """Create a spectrogram from a raw audio signal
+    r"""Create a spectrogram from an audio signal
 
     Args:
-        n_fft (int, optional): size of fft, creates n_fft // 2 + 1 bins
-        ws (int): window size. default: n_fft
-        hop (int, optional): length of hop between STFT windows. default: ws // 2
-        pad (int): two sided padding of signal
-        window (torch windowing function): default: torch.hann_window
-        power (int > 0 ) : Exponent for the magnitude spectrogram,
-                        e.g., 1 for energy, 2 for power, etc.
-        normalize (bool) : whether to normalize by magnitude after stft
-        wkwargs (dict, optional): arguments for window function
+        n_fft (int, optional): Size of fft, creates `n_fft // 2 + 1` bins
+        win_length (int): Window size. (Default: `n_fft`)
+        hop_length (int, optional): Length of hop between STFT windows. (
+            Default: `win_length // 2`)
+        pad (int): Two sided padding of signal. (Default: 0)
+        window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor
+            that is applied/multiplied to each frame/window. (Default: `torch.hann_window`)
+        power (int): Exponent for the magnitude spectrogram,
+            (must be > 0) e.g., 1 for energy, 2 for power, etc.
+        normalized (bool): Whether to normalize by magnitude after stft. (Default: `False`)
+        wkwargs (Dict[..., ...]): Arguments for window function. (Default: `None`)
     """
-    __constants__ = ['n_fft', 'ws', 'hop', 'pad', 'power', 'normalize']
+    __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']
 
-    def __init__(self, n_fft=400, ws=None, hop=None,
-                 pad=0, window=torch.hann_window,
-                 power=2, normalize=False, wkwargs=None):
+    def __init__(self, n_fft=400, win_length=None, hop_length=None,
+                 pad=0, window_fn=torch.hann_window,
+                 power=2, normalized=False, wkwargs=None):
         super(Spectrogram, self).__init__()
         self.n_fft = n_fft
         # number of fft bins. the returned STFT result will have n_fft // 2 + 1
         # number of frequencies due to onesided=True in torch.stft
-        self.ws = ws if ws is not None else n_fft
-        self.hop = hop if hop is not None else self.ws // 2
-        window = window(self.ws) if wkwargs is None else window(self.ws, **wkwargs)
+        self.win_length = win_length if win_length is not None else n_fft
+        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
+        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
         self.window = torch.jit.Attribute(window, torch.Tensor)
         self.pad = pad
         self.power = power
-        self.normalize = normalize
+        self.normalized = normalized
 
     @torch.jit.script_method
-    def forward(self, sig):
-        """
+    def forward(self, waveform):
+        r"""
         Args:
-            sig (Tensor): Tensor of audio of size (c, n)
+            waveform (torch.Tensor): Tensor of audio of size (c, n)
 
         Returns:
-            spec_f (Tensor): channels x hops x n_fft (c, l, f), where channels
-            is unchanged, hops is the number of hops, and n_fft is the
-            number of fourier bins, which should be the window size divided
-            by 2 plus 1.
-
+            torch.Tensor: Channels x frequency x time (c, f, t), where channels
+            is unchanged, frequency is `n_fft // 2 + 1` where `n_fft` is the number of
+            fourier bins, and time is the number of window hops (n_frames).
         """
-        return F.spectrogram(sig, self.pad, self.window, self.n_fft, self.hop,
-                             self.ws, self.power, self.normalize)
-
-
-def F2M(*args, **kwargs):
-    warn("F2M has been renamed to MelScale")
-    return MelScale(*args, **kwargs)
+        return F.spectrogram(waveform, self.pad, self.window, self.n_fft, self.hop_length,
+                             self.win_length, self.power, self.normalized)
 
 
 class MelScale(torch.jit.ScriptModule):
-    """This turns a normal STFT into a mel frequency STFT, using a conversion
+    r"""This turns a normal STFT into a mel frequency STFT, using a conversion
     matrix.  This uses triangular filter banks.
 
     User can control which device the filter bank (`fb`) is (e.g. fb.to(spec_f.device)).
 
     Args:
-        n_mels (int): number of mel bins
-        sr (int): sample rate of audio signal
-        f_max (float, optional): maximum frequency. default: `sr` // 2
-        f_min (float): minimum frequency. default: 0
-        n_stft (int, optional): number of filter banks from stft. Calculated from first input
+        n_mels (int): Number of mel filterbanks. (Default: 128)
+        sample_rate (int): Sample rate of audio signal. (Default: 16000)
+        f_min (float): Minimum frequency. (Default: 0.)
+        f_max (float, optional): Maximum frequency. (Default: `sample_rate // 2`)
+        n_stft (int, optional): Number of bins in STFT. Calculated from first input
             if `None` is given.  See `n_fft` in `Spectrogram`.
     """
-    __constants__ = ['n_mels', 'sr', 'f_min', 'f_max']
+    __constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']
 
-    def __init__(self, n_mels=128, sr=16000, f_max=None, f_min=0., n_stft=None):
+    def __init__(self, n_mels=128, sample_rate=16000, f_min=0., f_max=None, n_stft=None):
         super(MelScale, self).__init__()
         self.n_mels = n_mels
-        self.sr = sr
-        self.f_max = f_max if f_max is not None else float(sr // 2)
+        self.sample_rate = sample_rate
+        self.f_max = f_max if f_max is not None else float(sample_rate // 2)
+        assert f_min <= self.f_max, 'Require f_min: %f < f_max: %f' % (f_min, self.f_max)
         self.f_min = f_min
         fb = torch.empty(0) if n_stft is None else F.create_fb_matrix(
             n_stft, self.f_min, self.f_max, self.n_mels)
         self.fb = torch.jit.Attribute(fb, torch.Tensor)
 
     @torch.jit.script_method
-    def forward(self, spec_f):
+    def forward(self, specgram):
+        r"""
+        Args:
+            specgram (torch.Tensor): a spectrogram STFT of size (c, f, t)
+
+        Returns:
+            torch.Tensor: mel frequency spectrogram of size (c, `n_mels`, t)
+        """
         if self.fb.numel() == 0:
-            tmp_fb = F.create_fb_matrix(spec_f.size(2), self.f_min, self.f_max, self.n_mels)
+            tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max, self.n_mels)
             # Attributes cannot be reassigned outside __init__ so workaround
             self.fb.resize_(tmp_fb.size())
             self.fb.copy_(tmp_fb)
-        spec_m = torch.matmul(spec_f, self.fb)  # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
-        return spec_m
+
+        # (c, f, t).transpose(...) dot (f, n_mels) -> (c, t, n_mels).transpose(...)
+        mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2)
+        return mel_specgram
+ self.ref_value = 1.0 self.db_multiplier = math.log10(max(self.amin, self.ref_value)) @torch.jit.script_method - def forward(self, spec): - # numerically stable implementation from librosa - # https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html - return F.spectrogram_to_DB(spec, self.multiplier, self.amin, self.db_multiplier, self.top_db) + def forward(self, specgram): + r"""Numerically stable implementation from Librosa + https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html + + Args: + specgram (torch.Tensor): STFT of size (c, f, t) + + Returns: + torch.Tensor: STFT after changing scale of size (c, f, t) + """ + return F.spectrogram_to_DB(specgram, self.multiplier, self.amin, self.db_multiplier, self.top_db) class MFCC(torch.jit.ScriptModule): - """Create the Mel-frequency cepstrum coefficients from an audio signal + r"""Create the Mel-frequency cepstrum coefficients from an audio signal - By default, this calculates the MFCC on the DB-scaled Mel spectrogram. - This is not the textbook implementation, but is implemented here to - give consistency with librosa. + By default, this calculates the MFCC on the DB-scaled Mel spectrogram. + This is not the textbook implementation, but is implemented here to + give consistency with librosa. - This output depends on the maximum value in the input spectrogram, and so - may return different values for an audio clip split into snippets vs. a - a full clip. + This output depends on the maximum value in the input spectrogram, and so + may return different values for an audio clip split into snippets vs. a + a full clip. - Args: - sr (int) : sample rate of audio signal - n_mfcc (int) : number of mfc coefficients to retain - dct_type (int) : type of DCT (discrete cosine transform) to use - norm (string, optional) : norm to use - log_mels (bool) : whether to use log-mel spectrograms instead of db-scaled + Args: + sample_rate (int): Sample rate of audio signal. (Default: 16000) + n_mfcc (int): Number of mfc coefficients to retain + dct_type (int): type of DCT (discrete cosine transform) to use + norm (string, optional): norm to use + log_mels (bool): whether to use log-mel spectrograms instead of db-scaled melkwargs (dict, optional): arguments for MelSpectrogram """ - __constants__ = ['sr', 'n_mfcc', 'dct_type', 'top_db', 'log_mels'] + __constants__ = ['sample_rate', 'n_mfcc', 'dct_type', 'top_db', 'log_mels'] - def __init__(self, sr=16000, n_mfcc=40, dct_type=2, norm='ortho', log_mels=False, + def __init__(self, sample_rate=16000, n_mfcc=40, dct_type=2, norm='ortho', log_mels=False, melkwargs=None): super(MFCC, self).__init__() supported_dct_types = [2] if dct_type not in supported_dct_types: raise ValueError('DCT type not supported'.format(dct_type)) - self.sr = sr + self.sample_rate = sample_rate self.n_mfcc = n_mfcc self.dct_type = dct_type self.norm = torch.jit.Attribute(norm, Optional[str]) - self.top_db = 80. 


 class MFCC(torch.jit.ScriptModule):
-    """Create the Mel-frequency cepstrum coefficients from an audio signal
+    r"""Create the Mel-frequency cepstrum coefficients from an audio signal

-        By default, this calculates the MFCC on the DB-scaled Mel spectrogram.
-        This is not the textbook implementation, but is implemented here to
-        give consistency with librosa.
+    By default, this calculates the MFCC on the DB-scaled Mel spectrogram.
+    This is not the textbook implementation, but is implemented here to
+    give consistency with librosa.

-        This output depends on the maximum value in the input spectrogram, and so
-        may return different values for an audio clip split into snippets vs. a
-        a full clip.
+    This output depends on the maximum value in the input spectrogram, and so
+    may return different values for an audio clip split into snippets vs. a
+    full clip.

-        Args:
-            sr (int) : sample rate of audio signal
-            n_mfcc (int) : number of mfc coefficients to retain
-            dct_type (int) : type of DCT (discrete cosine transform) to use
-            norm (string, optional) : norm to use
-            log_mels (bool) : whether to use log-mel spectrograms instead of db-scaled
+    Args:
+        sample_rate (int): Sample rate of audio signal. (Default: 16000)
+        n_mfcc (int): Number of MFC coefficients to retain. (Default: 40)
+        dct_type (int): Type of DCT (discrete cosine transform) to use. (Default: 2)
+        norm (str, optional): Norm to use. (Default: 'ortho')
+        log_mels (bool): Whether to use log-mel spectrograms instead of db-scaled. (Default: False)
         melkwargs (dict, optional): arguments for MelSpectrogram
     """
-    __constants__ = ['sr', 'n_mfcc', 'dct_type', 'top_db', 'log_mels']
+    __constants__ = ['sample_rate', 'n_mfcc', 'dct_type', 'top_db', 'log_mels']

-    def __init__(self, sr=16000, n_mfcc=40, dct_type=2, norm='ortho', log_mels=False,
+    def __init__(self, sample_rate=16000, n_mfcc=40, dct_type=2, norm='ortho', log_mels=False,
                  melkwargs=None):
         super(MFCC, self).__init__()
         supported_dct_types = [2]
         if dct_type not in supported_dct_types:
             raise ValueError('DCT type not supported: {}'.format(dct_type))
-        self.sr = sr
+        self.sample_rate = sample_rate
         self.n_mfcc = n_mfcc
         self.dct_type = dct_type
         self.norm = torch.jit.Attribute(norm, Optional[str])
-        self.top_db = 80.
-        self.s2db = SpectrogramToDB("power", self.top_db)
+        self.top_db = 80.0
+        self.spectrogram_to_DB = SpectrogramToDB('power', self.top_db)
         if melkwargs is not None:
-            self.MelSpectrogram = MelSpectrogram(sr=self.sr, **melkwargs)
+            self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate, **melkwargs)
         else:
-            self.MelSpectrogram = MelSpectrogram(sr=self.sr)
+            self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate)

         if self.n_mfcc > self.MelSpectrogram.n_mels:
             raise ValueError('Cannot select more MFCC coefficients than # mel bins')
@@ -323,29 +214,28 @@ def __init__(self, sr=16000, n_mfcc=40, dct_type=2, norm='ortho', log_mels=False
         self.log_mels = log_mels

     @torch.jit.script_method
-    def forward(self, sig):
-        """
+    def forward(self, waveform):
+        r"""
         Args:
-            sig (Tensor): Tensor of audio of size (channels [c], samples [n])
+            waveform (torch.Tensor): Tensor of audio of size (c, n)

         Returns:
-            spec_mel_db (Tensor): channels x hops x n_mels (c, l, m), where channels
-                is unchanged, hops is the number of hops, and n_mels is the
-                number of mel bins.
+            torch.Tensor: specgram_mel_db of size (c, `n_mfcc`, t)
         """
-        mel_spect = self.MelSpectrogram(sig)
+        mel_specgram = self.MelSpectrogram(waveform)
         if self.log_mels:
             log_offset = 1e-6
-            mel_spect = torch.log(mel_spect + log_offset)
+            mel_specgram = torch.log(mel_specgram + log_offset)
         else:
-            mel_spect = self.s2db(mel_spect)
-        mfcc = torch.matmul(mel_spect, self.dct_mat)
+            mel_specgram = self.spectrogram_to_DB(mel_specgram)
+        # (c, `n_mels`, t).transpose(...) dot (`n_mels`, `n_mfcc`) -> (c, t, `n_mfcc`).transpose(...)
+        mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
         return mfcc
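A shape-level sketch of MFCC with the renamed arguments, assuming this patch is applied (one second of random audio at 16 kHz stands in for a real signal):

    import torch
    import torchaudio.transforms as transforms

    waveform = torch.rand(1, 16000)  # (c, n)

    mfcc = transforms.MFCC(sample_rate=16000, n_mfcc=40)
    coeffs = mfcc(waveform)

    # (c, n_mfcc, t): channel count is unchanged and 40 coefficients are kept per frame.
    assert coeffs.size(0) == 1 and coeffs.size(1) == 40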


 class MelSpectrogram(torch.jit.ScriptModule):
-    """Create MEL Spectrograms from a raw audio signal using the stft
-    function in PyTorch.
+    r"""Create MelSpectrogram for a raw audio signal. This is a composition of Spectrogram
+    and MelScale.

     Sources:
         * https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe
@@ -353,87 +243,58 @@ class MelSpectrogram(torch.jit.ScriptModule):
         * http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html

     Args:
-        sr (int): sample rate of audio signal
-        ws (int): window size
-        hop (int, optional): length of hop between STFT windows. default: `ws` // 2
-        n_fft (int, optional): number of fft bins. default: `ws` // 2 + 1
-        f_max (float, optional): maximum frequency. default: `sr` // 2
-        f_min (float): minimum frequency. default: 0
-        pad (int): two sided padding of signal
-        n_mels (int): number of MEL bins
-        window (torch windowing function): default: `torch.hann_window`
-        wkwargs (dict, optional): arguments for window function
+        sample_rate (int): Sample rate of audio signal. (Default: 16000)
+        win_length (int): Window size. (Default: `n_fft`)
+        hop_length (int, optional): Length of hop between STFT windows. (Default: `win_length // 2`)
+        n_fft (int, optional): Size of fft, creates `n_fft // 2 + 1` bins. (Default: 400)
+        f_min (float): Minimum frequency. (Default: 0.)
+        f_max (float, optional): Maximum frequency. (Default: `None`)
+        pad (int): Two sided padding of signal. (Default: 0)
+        n_mels (int): Number of mel filterbanks. (Default: 128)
+        window_fn (Callable[..., torch.Tensor]): A function to create a window tensor
+            that is applied/multiplied to each frame/window. (Default: `torch.hann_window`)
+        wkwargs (dict, optional): Arguments for window function. (Default: `None`)

     Example:
-        >>> sig, sr = torchaudio.load("test.wav", normalization=True)
-        >>> spec_mel = transforms.MelSpectrogram(sr)(sig)  # (c, l, m)
+        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
+        >>> mel_specgram = transforms.MelSpectrogram(sample_rate)(waveform)  # (c, n_mels, t)
     """
-    __constants__ = ['sr', 'n_fft', 'ws', 'hop', 'pad', 'n_mels', 'f_min']
+    __constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad', 'n_mels', 'f_min']

-    def __init__(self, sr=16000, n_fft=400, ws=None, hop=None, f_min=0., f_max=None,
-                 pad=0, n_mels=128, window=torch.hann_window, wkwargs=None):
+    def __init__(self, sample_rate=16000, n_fft=400, win_length=None, hop_length=None, f_min=0., f_max=None,
+                 pad=0, n_mels=128, window_fn=torch.hann_window, wkwargs=None):
         super(MelSpectrogram, self).__init__()
-        self.sr = sr
+        self.sample_rate = sample_rate
         self.n_fft = n_fft
-        self.ws = ws if ws is not None else n_fft
-        self.hop = hop if hop is not None else self.ws // 2
+        self.win_length = win_length if win_length is not None else n_fft
+        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
         self.pad = pad
         self.n_mels = n_mels  # number of mel frequency bins
         self.f_max = torch.jit.Attribute(f_max, Optional[float])
         self.f_min = f_min
-        self.spec = Spectrogram(n_fft=self.n_fft, ws=self.ws, hop=self.hop,
-                                pad=self.pad, window=window, power=2,
-                                normalize=False, wkwargs=wkwargs)
-        self.fm = MelScale(self.n_mels, self.sr, self.f_max, self.f_min)
+        self.spectrogram = Spectrogram(n_fft=self.n_fft, win_length=self.win_length,
+                                       hop_length=self.hop_length,
+                                       pad=self.pad, window_fn=window_fn, power=2,
+                                       normalized=False, wkwargs=wkwargs)
+        self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_min, self.f_max)

     @torch.jit.script_method
-    def forward(self, sig):
-        """
+    def forward(self, waveform):
+        r"""
         Args:
-            sig (Tensor): Tensor of audio of size (channels [c], samples [n])
+            waveform (torch.Tensor): Tensor of audio of size (c, n)

         Returns:
-            spec_mel (Tensor): channels x hops x n_mels (c, l, m), where channels
-                is unchanged, hops is the number of hops, and n_mels is the
-                number of mel bins.
-
+            torch.Tensor: mel frequency spectrogram of size (c, `n_mels`, t)
         """
-        spec = self.spec(sig)
-        spec_mel = self.fm(spec)
-        return spec_mel
-
-
-def MEL(*args, **kwargs):
-    raise DeprecationWarning("MEL has been removed from the library please use MelSpectrogram or librosa")
-
-
-class BLC2CBL(torch.jit.ScriptModule):
-    """Permute a 3d tensor from Bands x Sample length x Channels to Channels x
-    Bands x Samples length
-    """
-
-    def __init__(self):
-        super(BLC2CBL, self).__init__()
-
-    @torch.jit.script_method
-    def forward(self, tensor):
-        """
-
-        Args:
-            tensor (Tensor): Tensor of spectrogram with shape (BxLxC)
-
-        Returns:
-            tensor (Tensor): Tensor of spectrogram with shape (CxBxL)
-
-        """
-        return F.BLC2CBL(tensor)
-
-    def __repr__(self):
-        return self.__class__.__name__ + '()'
+        specgram = self.spectrogram(waveform)
+        mel_specgram = self.mel_scale(specgram)
+        return mel_specgram
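The docstring now describes MelSpectrogram as Spectrogram followed by MelScale; a sketch of that equivalence, assuming this patch is applied and that Spectrogram's own win_length/hop_length defaults match the values wired up in __init__ above:

    import torch
    import torchaudio.transforms as transforms

    waveform = torch.rand(1, 16000)

    direct = transforms.MelSpectrogram(sample_rate=16000, n_fft=400, n_mels=128)(waveform)
    composed = transforms.MelScale(n_mels=128, sample_rate=16000)(
        transforms.Spectrogram(n_fft=400, power=2, normalized=False)(waveform))

    assert torch.allclose(direct, composed)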


 class MuLawEncoding(torch.jit.ScriptModule):
-    """Encode signal based on mu-law companding. For more info see the
+    r"""Encode signal based on mu-law companding. For more info see the
     `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

     This algorithm assumes the signal has been scaled to between -1 and 1 and
@@ -441,33 +302,27 @@ class MuLawEncoding(torch.jit.ScriptModule):

     Args:
         quantization_channels (int): Number of channels. default: 256
-
     """
-    __constants__ = ['qc']
+    __constants__ = ['quantization_channels']

     def __init__(self, quantization_channels=256):
         super(MuLawEncoding, self).__init__()
-        self.qc = quantization_channels
+        self.quantization_channels = quantization_channels

     @torch.jit.script_method
     def forward(self, x):
-        """
-
+        r"""
         Args:
-            x (FloatTensor/LongTensor)
+            x (torch.Tensor): A signal to be encoded

         Returns:
-            x_mu (LongTensor)
-
+            x_mu (torch.Tensor): An encoded signal
         """
-        return F.mu_law_encoding(x, self.qc)
-
-    def __repr__(self):
-        return self.__class__.__name__ + '()'
+        return F.mu_law_encoding(x, self.quantization_channels)


 class MuLawExpanding(torch.jit.ScriptModule):
-    """Decode mu-law encoded signal. For more info see the
+    r"""Decode mu-law encoded signal. For more info see the
     `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

     This expects an input with values between 0 and quantization_channels - 1
@@ -475,33 +330,27 @@ class MuLawExpanding(torch.jit.ScriptModule):

     Args:
         quantization_channels (int): Number of channels. default: 256
-
     """
-    __constants__ = ['qc']
+    __constants__ = ['quantization_channels']

     def __init__(self, quantization_channels=256):
         super(MuLawExpanding, self).__init__()
-        self.qc = quantization_channels
+        self.quantization_channels = quantization_channels

     @torch.jit.script_method
     def forward(self, x_mu):
-        """
-
+        r"""
         Args:
-            x_mu (Tensor)
+            x_mu (torch.Tensor): A mu-law encoded signal to be decoded

         Returns:
-            x (Tensor)
-
+            torch.Tensor: The decoded signal
         """
-        return F.mu_law_expanding(x_mu, self.qc)
-
-    def __repr__(self):
-        return self.__class__.__name__ + '()'
+        return F.mu_law_expanding(x_mu, self.quantization_channels)


 class Resample(torch.nn.Module):
-    """Resamples a signal from one frequency to another. A resampling method can
+    r"""Resamples a signal from one frequency to another. A resampling method can
     be given.

     Args:
@@ -516,15 +365,15 @@ def __init__(self, orig_freq, new_freq, resampling_method='sinc_interpolation'):
         self.new_freq = new_freq
         self.resampling_method = resampling_method

-    def forward(self, sig):
-        """
+    def forward(self, waveform):
+        r"""
         Args:
-            sig (Tensor): the input signal of size (c, n)
+            waveform (torch.Tensor): The input signal of size (c, n)

         Returns:
-            Tensor: output signal of size (c, m)
+            torch.Tensor: Output signal of size (c, m)
         """
         if self.resampling_method == 'sinc_interpolation':
-            return kaldi.resample_waveform(sig, self.orig_freq, self.new_freq)
+            return kaldi.resample_waveform(waveform, self.orig_freq, self.new_freq)

         raise ValueError('Invalid resampling method: %s' % (self.resampling_method))
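To close out the renamed mu-law pair and Resample, a round-trip sketch assuming this patch is applied (mu-law is lossy by design, so the reconstruction is only approximate):

    import torch
    import torchaudio.transforms as transforms

    waveform = torch.rand(1, 16000) * 2 - 1  # scaled to [-1, 1], as the docstrings require

    encode = transforms.MuLawEncoding(quantization_channels=256)
    decode = transforms.MuLawExpanding(quantization_channels=256)

    x_mu = encode(waveform)   # integer codes in [0, quantization_channels - 1]
    restored = decode(x_mu)   # back in [-1, 1], close to but not equal to the input

    # Downsample 16 kHz -> 8 kHz with the default sinc interpolation.
    downsampled = transforms.Resample(orig_freq=16000, new_freq=8000)(waveform)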