diff --git a/docs/source/functional.rst b/docs/source/functional.rst index 9968bfadf7..faa278c80b 100644 --- a/docs/source/functional.rst +++ b/docs/source/functional.rst @@ -82,3 +82,28 @@ Functions to perform common audio operations. ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: highpass_biquad + +:hidden:`equalizer_biquad` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: equalizer_biquad + +:hidden:`mask_along_axis` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: mask_along_axis + +:hidden:`mask_along_axis_iid` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: mask_along_axis_iid + +:hidden:`compute_deltas` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: compute_deltas + +:hidden:`detect_pitch_frequency` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: detect_pitch_frequency diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index ac2c733ac6..856b0c84e6 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -64,3 +64,38 @@ Transforms are common audio transforms. They can be chained together using :clas .. autoclass:: Resample .. automethod:: torchaudio._docs.Resample.forward + +:hidden:`ComplexNorm` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: ComplexNorm + + .. automethod:: torchaudio._docs.ComplexNorm.forward + +:hidden:`ComputeDeltas` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: ComputeDeltas + + .. automethod:: torchaudio._docs.ComputeDeltas.forward + +:hidden:`TimeStretch` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: TimeStretch + + .. automethod:: torchaudio._docs.TimeStretch.forward + +:hidden:`FrequencyMasking` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: FrequencyMasking + + .. automethod:: torchaudio._docs.FrequencyMasking.forward + +:hidden:`TimeMasking` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: TimeMasking + + .. automethod:: torchaudio._docs.TimeMasking.forward diff --git a/test/test_transforms.py b/test/test_transforms.py index bc82f0a673..c22169741f 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -4,7 +4,6 @@ import torch import torchaudio -import torchaudio.augmentations as A import torchaudio.transforms as transforms import torchaudio.functional as F from torchaudio.common_utils import IMPORT_LIBROSA, IMPORT_SCIPY @@ -424,15 +423,15 @@ def test_scriptmodule_TimeStretch(self): hop_length = 512 fixed_rate = 1.3 tensor = torch.rand((10, 2, n_freq, 10, 2)) - _test_script_module(A.TimeStretch, tensor, n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate) + _test_script_module(transforms.TimeStretch, tensor, n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate) def test_scriptmodule_FrequencyMasking(self): tensor = torch.rand((10, 2, 50, 10, 2)) - _test_script_module(A.FrequencyMasking, tensor, freq_mask_param=60, iid_masks=False) + _test_script_module(transforms.FrequencyMasking, tensor, freq_mask_param=60, iid_masks=False) def test_scriptmodule_TimeMasking(self): tensor = torch.rand((10, 2, 50, 10, 2)) - _test_script_module(A.TimeMasking, tensor, time_mask_param=30, iid_masks=False) + _test_script_module(transforms.TimeMasking, tensor, time_mask_param=30, iid_masks=False) if __name__ == '__main__': diff --git a/torchaudio/augmentations.py b/torchaudio/augmentations.py deleted file mode 100644 index ed1573bdac..0000000000 --- a/torchaudio/augmentations.py +++ /dev/null @@ -1,128 +0,0 @@ -import math -import torch - -from . import functional as F - -__all__ = [ - 'TimeStretch', - 'FrequencyMasking', - 'TimeMasking' -] - - -class TimeStretch(torch.jit.ScriptModule): - r"""Stretch stft in time without modifying pitch for a given rate. - - Args: - hop_length (int): Number audio of frames between STFT columns. (Default: ``n_fft // 2``) - n_freq (int, optional): number of filter banks from stft. (Default: ``201``) - fixed_rate (float): rate to speed up or slow down by. - If None is provided, rate must be passed to the forward method. (Default: ``None``) - """ - __constants__ = ['fixed_rate'] - - def __init__(self, hop_length=None, n_freq=201, fixed_rate=None): - super(TimeStretch, self).__init__() - - self.fixed_rate = fixed_rate - - n_fft = (n_freq - 1) * 2 - hop_length = hop_length if hop_length is not None else n_fft // 2 - phase_advance = torch.linspace(0, math.pi * hop_length, n_freq)[..., None] - self.phase_advance = torch.jit.Attribute(phase_advance, torch.Tensor) - - def forward(self, complex_specgrams, overriding_rate=None): - # type: (Tensor, Optional[float]) -> Tensor - r""" - Args: - complex_specgrams (Tensor): complex spectrogram (*, channel, freq, time, complex=2) - overriding_rate (float or None): speed up to apply to this batch. - If no rate is passed, use ``self.fixed_rate`` - - Returns: - (Tensor): Stretched complex spectrogram of dimension (*, channel, freq, ceil(time/rate), complex=2) - """ - assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (*, complex=2)" - - if overriding_rate is None: - rate = self.fixed_rate - if rate is None: - raise ValueError("If no fixed_rate is specified" - ", must pass a valid rate to the forward method.") - else: - rate = overriding_rate - - if rate == 1.0: - return complex_specgrams - - shape = complex_specgrams.size() - complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-3:])) - complex_specgrams = F.phase_vocoder(complex_specgrams, rate, self.phase_advance) - - return complex_specgrams.reshape(shape[:-3] + complex_specgrams.shape[-3:]) - - -class _AxisMasking(torch.nn.Module): - r""" - Apply masking to a spectrogram. - Args: - mask_param (int): Maximum possible length of the mask - axis: What dimension the mask is applied on - iid_masks (bool): Applies iid masks to each of the examples in the batch dimension - """ - __constants__ = ['mask_param', 'axis', 'iid_masks'] - - def __init__(self, mask_param, axis, iid_masks): - - super(_AxisMasking, self).__init__() - self.mask_param = mask_param - self.axis = axis - self.iid_masks = iid_masks - - def forward(self, specgram, mask_value=0.): - # type: (Tensor, float) -> Tensor - r""" - Args: - specgram (torch.Tensor): Tensor of dimension (*, channel, freq, time) - - Returns: - torch.Tensor: Masked spectrogram of dimensions (*, channel, freq, time) - """ - - # if iid_masks flag marked and specgram has a batch dimension - if self.iid_masks and specgram.dim() == 4: - return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1) - else: - shape = specgram.size() - specgram = specgram.reshape([-1] + list(shape[-2:])) - specgram = F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis) - - return specgram.reshape(shape[:-2] + specgram.shape[-2:]) - - -class FrequencyMasking(_AxisMasking): - r""" - Apply masking to a spectrogram in the frequency domain. - Args: - freq_mask_param (int): maximum possible length of the mask. - Indices uniformly sampled from [0, freq_mask_param). - iid_masks (bool): weather to apply the same mask to all - the examples/channels in the batch. (Default: False) - """ - - def __init__(self, freq_mask_param, iid_masks=False): - super(FrequencyMasking, self).__init__(freq_mask_param, 1, iid_masks) - - -class TimeMasking(_AxisMasking): - r""" - Apply masking to a spectrogram in the time domain. - Args: - time_mask_param (int): maximum possible length of the mask. - Indices uniformly sampled from [0, time_mask_param). - iid_masks (bool): weather to apply the same mask to all - the examples/channels in the batch. Defaults to False. - """ - - def __init__(self, time_mask_param, iid_masks=False): - super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks) diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index c7e7bfa0ba..5f8fc59d38 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -16,7 +16,10 @@ 'MuLawEncoding', 'MuLawDecoding', 'Resample', - 'ComplexNorm' + 'ComplexNorm', + 'TimeStretch', + 'FrequencyMasking', + 'TimeMasking', ] @@ -408,3 +411,121 @@ def forward(self, specgram): deltas (torch.Tensor): Tensor of audio of dimension (channel, freq, time) """ return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode) + + +class TimeStretch(torch.jit.ScriptModule): + r"""Stretch stft in time without modifying pitch for a given rate. + + Args: + hop_length (int): Number audio of frames between STFT columns. (Default: ``n_fft // 2``) + n_freq (int, optional): number of filter banks from stft. (Default: ``201``) + fixed_rate (float): rate to speed up or slow down by. + If None is provided, rate must be passed to the forward method. (Default: ``None``) + """ + __constants__ = ['fixed_rate'] + + def __init__(self, hop_length=None, n_freq=201, fixed_rate=None): + super(TimeStretch, self).__init__() + + self.fixed_rate = fixed_rate + + n_fft = (n_freq - 1) * 2 + hop_length = hop_length if hop_length is not None else n_fft // 2 + phase_advance = torch.linspace(0, math.pi * hop_length, n_freq)[..., None] + self.phase_advance = torch.jit.Attribute(phase_advance, torch.Tensor) + + def forward(self, complex_specgrams, overriding_rate=None): + # type: (Tensor, Optional[float]) -> Tensor + r""" + Args: + complex_specgrams (Tensor): complex spectrogram (*, channel, freq, time, complex=2) + overriding_rate (float or None): speed up to apply to this batch. + If no rate is passed, use ``self.fixed_rate`` + + Returns: + (Tensor): Stretched complex spectrogram of dimension (*, channel, freq, ceil(time/rate), complex=2) + """ + assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (*, complex=2)" + + if overriding_rate is None: + rate = self.fixed_rate + if rate is None: + raise ValueError("If no fixed_rate is specified" + ", must pass a valid rate to the forward method.") + else: + rate = overriding_rate + + if rate == 1.0: + return complex_specgrams + + shape = complex_specgrams.size() + complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-3:])) + complex_specgrams = F.phase_vocoder(complex_specgrams, rate, self.phase_advance) + + return complex_specgrams.reshape(shape[:-3] + complex_specgrams.shape[-3:]) + + +class _AxisMasking(torch.nn.Module): + r""" + Apply masking to a spectrogram. + Args: + mask_param (int): Maximum possible length of the mask + axis: What dimension the mask is applied on + iid_masks (bool): Applies iid masks to each of the examples in the batch dimension + """ + __constants__ = ['mask_param', 'axis', 'iid_masks'] + + def __init__(self, mask_param, axis, iid_masks): + + super(_AxisMasking, self).__init__() + self.mask_param = mask_param + self.axis = axis + self.iid_masks = iid_masks + + def forward(self, specgram, mask_value=0.): + # type: (Tensor, float) -> Tensor + r""" + Args: + specgram (torch.Tensor): Tensor of dimension (*, channel, freq, time) + + Returns: + torch.Tensor: Masked spectrogram of dimensions (*, channel, freq, time) + """ + + # if iid_masks flag marked and specgram has a batch dimension + if self.iid_masks and specgram.dim() == 4: + return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1) + else: + shape = specgram.size() + specgram = specgram.reshape([-1] + list(shape[-2:])) + specgram = F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis) + + return specgram.reshape(shape[:-2] + specgram.shape[-2:]) + + +class FrequencyMasking(_AxisMasking): + r""" + Apply masking to a spectrogram in the frequency domain. + Args: + freq_mask_param (int): maximum possible length of the mask. + Indices uniformly sampled from [0, freq_mask_param). + iid_masks (bool): weather to apply the same mask to all + the examples/channels in the batch. (Default: False) + """ + + def __init__(self, freq_mask_param, iid_masks=False): + super(FrequencyMasking, self).__init__(freq_mask_param, 1, iid_masks) + + +class TimeMasking(_AxisMasking): + r""" + Apply masking to a spectrogram in the time domain. + Args: + time_mask_param (int): maximum possible length of the mask. + Indices uniformly sampled from [0, time_mask_param). + iid_masks (bool): weather to apply the same mask to all + the examples/channels in the batch. Defaults to False. + """ + + def __init__(self, time_mask_param, iid_masks=False): + super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks)