Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions docs/source/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,28 @@ Functions to perform common audio operations.
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: highpass_biquad

:hidden:`equalizer_biquad`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: equalizer_biquad

:hidden:`mask_along_axis`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: mask_along_axis

:hidden:`mask_along_axis_iid`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: mask_along_axis_iid

:hidden:`compute_deltas`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: compute_deltas

:hidden:`detect_pitch_frequency`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: detect_pitch_frequency
35 changes: 35 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,38 @@ Transforms are common audio transforms. They can be chained together using :clas
.. autoclass:: Resample

.. automethod:: torchaudio._docs.Resample.forward

:hidden:`ComplexNorm`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: ComplexNorm

.. automethod:: torchaudio._docs.ComplexNorm.forward

:hidden:`ComputeDeltas`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: ComputeDeltas

.. automethod:: torchaudio._docs.ComputeDeltas.forward

:hidden:`TimeStretch`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: TimeStretch

.. automethod:: torchaudio._docs.TimeStretch.forward

:hidden:`FrequencyMasking`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: FrequencyMasking

.. automethod:: torchaudio._docs.FrequencyMasking.forward

:hidden:`TimeMasking`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: TimeMasking

.. automethod:: torchaudio._docs.TimeMasking.forward
7 changes: 3 additions & 4 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import torch
import torchaudio
import torchaudio.augmentations as A
import torchaudio.transforms as transforms
import torchaudio.functional as F
from torchaudio.common_utils import IMPORT_LIBROSA, IMPORT_SCIPY
Expand Down Expand Up @@ -424,15 +423,15 @@ def test_scriptmodule_TimeStretch(self):
hop_length = 512
fixed_rate = 1.3
tensor = torch.rand((10, 2, n_freq, 10, 2))
_test_script_module(A.TimeStretch, tensor, n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate)
_test_script_module(transforms.TimeStretch, tensor, n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate)

def test_scriptmodule_FrequencyMasking(self):
tensor = torch.rand((10, 2, 50, 10, 2))
_test_script_module(A.FrequencyMasking, tensor, freq_mask_param=60, iid_masks=False)
_test_script_module(transforms.FrequencyMasking, tensor, freq_mask_param=60, iid_masks=False)

def test_scriptmodule_TimeMasking(self):
tensor = torch.rand((10, 2, 50, 10, 2))
_test_script_module(A.TimeMasking, tensor, time_mask_param=30, iid_masks=False)
_test_script_module(transforms.TimeMasking, tensor, time_mask_param=30, iid_masks=False)


if __name__ == '__main__':
Expand Down
128 changes: 0 additions & 128 deletions torchaudio/augmentations.py

This file was deleted.

123 changes: 122 additions & 1 deletion torchaudio/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
'MuLawEncoding',
'MuLawDecoding',
'Resample',
'ComplexNorm'
'ComplexNorm',
'TimeStretch',
'FrequencyMasking',
'TimeMasking',
]


Expand Down Expand Up @@ -408,3 +411,121 @@ def forward(self, specgram):
deltas (torch.Tensor): Tensor of audio of dimension (channel, freq, time)
"""
return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode)


class TimeStretch(torch.jit.ScriptModule):
    r"""Stretch stft in time without modifying pitch for a given rate.

    Args:
        hop_length (int or None): Number of audio frames between STFT columns.
            (Default: ``n_fft // 2``, with ``n_fft = (n_freq - 1) * 2``)
        n_freq (int, optional): number of filter banks from stft. (Default: ``201``)
        fixed_rate (float or None): rate to speed up or slow down by.
            If None is provided, rate must be passed to the forward method. (Default: ``None``)
    """
    __constants__ = ['fixed_rate']

    def __init__(self, hop_length=None, n_freq=201, fixed_rate=None):
        super(TimeStretch, self).__init__()

        self.fixed_rate = fixed_rate

        # Fall back to half of the FFT size implied by the number of frequency bins.
        if hop_length is None:
            n_fft = (n_freq - 1) * 2
            hop_length = n_fft // 2

        # n_freq evenly spaced values from 0 to pi * hop_length, stored as a
        # column of shape (n_freq, 1).
        bin_phases = torch.linspace(0, math.pi * hop_length, n_freq)
        self.phase_advance = torch.jit.Attribute(bin_phases[..., None], torch.Tensor)

    def forward(self, complex_specgrams, overriding_rate=None):
        # type: (Tensor, Optional[float]) -> Tensor
        r"""
        Args:
            complex_specgrams (Tensor): complex spectrogram (*, channel, freq, time, complex=2)
            overriding_rate (float or None): speed up to apply to this batch.
                If no rate is passed, use ``self.fixed_rate``

        Returns:
            (Tensor): Stretched complex spectrogram of dimension (*, channel, freq, ceil(time/rate), complex=2)

        Raises:
            ValueError: If neither ``overriding_rate`` nor ``fixed_rate`` is set.
        """
        assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (*, complex=2)"

        # A rate given at call time wins over the one configured at construction.
        if overriding_rate is not None:
            rate = overriding_rate
        else:
            rate = self.fixed_rate
            if rate is None:
                raise ValueError("If no fixed_rate is specified, must pass a valid rate to the forward method.")

        # A rate of exactly 1 leaves the input untouched, so skip the vocoder.
        if rate == 1.0:
            return complex_specgrams

        # Collapse every leading dimension into one so the vocoder receives a
        # 4-D tensor, then restore the leading dimensions on the result.
        orig_size = complex_specgrams.size()
        flat = complex_specgrams.reshape([-1] + list(orig_size[-3:]))
        stretched = F.phase_vocoder(flat, rate, self.phase_advance)

        return stretched.reshape(orig_size[:-3] + stretched.shape[-3:])


class _AxisMasking(torch.nn.Module):
    r"""Apply masking to a spectrogram along one axis.

    Args:
        mask_param (int): Maximum possible length of the mask
        axis (int): What dimension the mask is applied on
        iid_masks (bool): Applies iid masks to each of the examples in the batch dimension
    """
    __constants__ = ['mask_param', 'axis', 'iid_masks']

    def __init__(self, mask_param, axis, iid_masks):
        super(_AxisMasking, self).__init__()

        self.iid_masks = iid_masks
        self.axis = axis
        self.mask_param = mask_param

    def forward(self, specgram, mask_value=0.):
        # type: (Tensor, float) -> Tensor
        r"""
        Args:
            specgram (torch.Tensor): Tensor of dimension (*, channel, freq, time)
            mask_value (float): Value handed to the masking function. (Default: ``0.``)

        Returns:
            torch.Tensor: Masked spectrogram of dimensions (*, channel, freq, time)
        """
        # The iid path is only taken for 4-D input, i.e. when a batch dimension
        # is present. The tensor is passed through unflattened, so the axis
        # index is shifted by one to account for the extra leading dimension.
        if self.iid_masks and specgram.dim() == 4:
            return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1)

        # Otherwise fold all leading dimensions into one, mask, and restore
        # the original leading dimensions afterwards.
        orig_size = specgram.size()
        flat = specgram.reshape([-1] + list(orig_size[-2:]))
        masked = F.mask_along_axis(flat, self.mask_param, mask_value, self.axis)

        return masked.reshape(orig_size[:-2] + masked.shape[-2:])


class FrequencyMasking(_AxisMasking):
    r"""Apply masking to a spectrogram in the frequency domain.

    Args:
        freq_mask_param (int): maximum possible length of the mask.
            Indices uniformly sampled from [0, freq_mask_param).
        iid_masks (bool): whether to apply iid masks to each of the examples
            in the batch instead of a single shared mask. Only takes effect
            for 4-D input. (Default: ``False``)
    """

    def __init__(self, freq_mask_param, iid_masks=False):
        # Axis 1 selects the frequency dimension for the base class.
        super(FrequencyMasking, self).__init__(
            mask_param=freq_mask_param,
            axis=1,
            iid_masks=iid_masks,
        )


class TimeMasking(_AxisMasking):
    r"""Apply masking to a spectrogram in the time domain.

    Args:
        time_mask_param (int): maximum possible length of the mask.
            Indices uniformly sampled from [0, time_mask_param).
        iid_masks (bool): whether to apply iid masks to each of the examples
            in the batch instead of a single shared mask. Only takes effect
            for 4-D input. (Default: ``False``)
    """

    def __init__(self, time_mask_param, iid_masks=False):
        # Axis 2 selects the time dimension for the base class.
        super(TimeMasking, self).__init__(
            mask_param=time_mask_param,
            axis=2,
            iid_masks=iid_masks,
        )