43 changes: 41 additions & 2 deletions test/test_functional.py
@@ -223,8 +223,7 @@ def _num_stft_bins(signal_len, fft_len, hop_length, pad):


@pytest.mark.parametrize('complex_specgrams', [
-    torch.randn(1, 2, 1025, 400, 2),
-    torch.randn(1, 1025, 400, 2)
+    torch.randn(2, 1025, 400, 2)
])
@pytest.mark.parametrize('rate', [0.5, 1.01, 1.3])
@pytest.mark.parametrize('hop_length', [256])
@@ -277,5 +276,45 @@ def test_complex_norm(complex_tensor, power):
assert torch.allclose(expected_norm_tensor, norm_tensor, atol=1e-5)


@pytest.mark.parametrize('specgram', [
torch.randn(2, 1025, 400),
torch.randn(1, 201, 100)
])
@pytest.mark.parametrize('mask_param', [100])
@pytest.mark.parametrize('mask_value', [0., 30.])
@pytest.mark.parametrize('axis', [1, 2])
def test_mask_along_axis(specgram, mask_param, mask_value, axis):

mask_specgram = F.mask_along_axis(specgram, mask_param, mask_value, axis)

other_axis = 1 if axis == 2 else 2

masked_columns = (mask_specgram == mask_value).sum(other_axis)
num_masked_columns = (masked_columns == mask_specgram.size(other_axis)).sum()
num_masked_columns /= mask_specgram.size(0)

assert mask_specgram.size() == specgram.size()
assert num_masked_columns < mask_param


@pytest.mark.parametrize('specgrams', [
torch.randn(4, 2, 1025, 400),
])
@pytest.mark.parametrize('mask_param', [100])
@pytest.mark.parametrize('mask_value', [0., 30.])
@pytest.mark.parametrize('axis', [2, 3])
def test_mask_along_axis_iid(specgrams, mask_param, mask_value, axis):

mask_specgrams = F.mask_along_axis_iid(specgrams, mask_param, mask_value, axis)

other_axis = 2 if axis == 3 else 3

masked_columns = (mask_specgrams == mask_value).sum(other_axis)
num_masked_columns = (masked_columns == mask_specgrams.size(other_axis)).sum(-1)

assert mask_specgrams.size() == specgrams.size()
assert (num_masked_columns < mask_param).sum() == num_masked_columns.numel()


if __name__ == '__main__':
unittest.main()
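To see what the new mask tests assert, here is a minimal standalone sketch (not part of the diff; it uses this PR's `F.mask_along_axis` with shapes from the parametrized cases above):

```python
import torch
import torchaudio.functional as F

specgram = torch.randn(2, 1025, 400)  # (channel, n_freq, time)
masked = F.mask_along_axis(specgram, mask_param=100, mask_value=0.0, axis=2)

# A time column counts as masked only if every frequency bin equals mask_value;
# with randn inputs, accidental exact zeros are vanishingly unlikely.
masked_columns = (masked == 0.0).sum(1)                        # (channel, time)
num_masked = (masked_columns == masked.size(1)).sum() / masked.size(0)
assert num_masked < 100  # mask width is drawn from [0, mask_param)
```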
130 changes: 130 additions & 0 deletions torchaudio/augmentations.py
@@ -0,0 +1,130 @@
import math
import torch

from . import functional as F

__all__ = [
'TimeStretch',
'FrequencyMasking',
'TimeMasking'
]


class TimeStretch(torch.jit.ScriptModule):
r"""Stretch stft in time without modifying pitch for a given rate.

Args:
        hop_length (int, optional): Number of audio frames between STFT columns. (Default: ``n_fft // 2``)
        n_freq (int, optional): Number of frequency bins from the STFT. (Default: ``201``)
        fixed_rate (float, optional): Rate to speed up or slow down by.
            If ``None``, a rate must be passed to the forward method. (Default: ``None``)
"""
__constants__ = ['fixed_rate']

def __init__(self, hop_length=None, n_freq=201, fixed_rate=None):
super(TimeStretch, self).__init__()

n_fft = (n_freq - 1) * 2
hop_length = hop_length if hop_length is not None else n_fft // 2
self.fixed_rate = fixed_rate
phase_advance = torch.linspace(0, math.pi * hop_length, n_freq)[..., None]

self.phase_advance = torch.jit.Attribute(phase_advance, torch.Tensor)

@torch.jit.script_method
def forward(self, complex_specgrams, overriding_rate=None):
# type: (Tensor, Optional[float]) -> Tensor
r"""
Args:
complex_specgrams (Tensor): complex spectrogram (*, channel, freq, time, complex=2)
overriding_rate (float or None): speed up to apply to this batch.
If no rate is passed, use ``self.fixed_rate``

Returns:
(Tensor): Stretched complex spectrogram of dimension (*, channel, n_freq, ceil(time/rate), complex=2)
"""
assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (*, complex=2)"

if overriding_rate is None:
rate = self.fixed_rate
if rate is None:
raise ValueError("If no fixed_rate is specified"
", must pass a valid rate to the forward method.")
else:
rate = overriding_rate

if rate == 1.0:
return complex_specgrams

shape = complex_specgrams.size()
complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-3:]))
complex_specgrams = F.phase_vocoder(complex_specgrams, rate, self.phase_advance)

return complex_specgrams.reshape(shape[:-3] + complex_specgrams.shape[-3:])
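
The reshape pattern above is doing the heavy lifting: ``phase_vocoder`` now expects a fixed ``(batch, freq, time, complex=2)`` layout, so any leading dimensions are collapsed first and restored afterwards. A standalone sketch of the trick (illustrative shapes, not part of the diff):

```python
import torch

x = torch.randn(3, 2, 201, 400, 2)         # (*, channel, freq, time, complex=2)
shape = x.size()
flat = x.reshape([-1] + list(shape[-3:]))   # (3 * 2, 201, 400, 2)
# ... apply an op that expects (batch, freq, time, complex=2) here ...
out = flat.reshape(shape[:-3] + flat.shape[-3:])
assert out.shape == x.shape
```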


class _AxisMasking(torch.jit.ScriptModule):
r"""
Apply masking to a spectrogram.
Args:
mask_param (int): Maximum possible length of the mask
        axis (int): What dimension the mask is applied on
        iid_masks (bool): whether to apply different masks to each example in the batch dimension
"""
__constants__ = ['mask_param', 'axis', 'iid_masks']

def __init__(self, mask_param, axis, iid_masks):

super(_AxisMasking, self).__init__()
self.mask_param = mask_param
self.axis = axis
self.iid_masks = iid_masks

@torch.jit.script_method
def forward(self, specgram, mask_value=0.):
# type: (Tensor, float) -> Tensor
r"""
Args:
specgram (torch.Tensor): Tensor of dimension (*, channel, freq, time)

Returns:
torch.Tensor: Masked spectrogram of dimensions (*, channel, freq, time)
"""

        # if the iid_masks flag is set and specgram has a batch dimension
if self.iid_masks and specgram.dim() == 4:
return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1)
else:
shape = specgram.size()
specgram = specgram.reshape([-1] + list(shape[-2:]))
specgram = F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis)

return specgram.reshape(shape[:-2] + specgram.shape[-2:])


class FrequencyMasking(_AxisMasking):
r"""
Apply masking to a spectrogram in the frequency domain.
Args:
freq_mask_param (int): maximum possible length of the mask.
Indices uniformly sampled from [0, freq_mask_param).
        iid_masks (bool): whether to apply different masks to each
            example/channel in the batch. (Default: ``False``)
"""

def __init__(self, freq_mask_param, iid_masks=False):
super(FrequencyMasking, self).__init__(freq_mask_param, 1, iid_masks)


class TimeMasking(_AxisMasking):
r"""
Apply masking to a spectrogram in the time domain.
Args:
time_mask_param (int): maximum possible length of the mask.
Indices uniformly sampled from [0, time_mask_param).
        iid_masks (bool): whether to apply different masks to each
            example/channel in the batch. (Default: ``False``)
"""

def __init__(self, time_mask_param, iid_masks=False):
super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks)
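
For orientation, a usage sketch of the new module-level API (not part of the diff; shapes follow the docstrings above and the values are illustrative):

```python
import torch
from torchaudio.augmentations import TimeStretch, FrequencyMasking, TimeMasking

# Complex spectrogram: (channel, freq, time, complex=2).
complex_specgrams = torch.randn(2, 201, 400, 2)
stretch = TimeStretch(fixed_rate=1.3)         # n_freq defaults to 201
stretched = stretch(complex_specgrams)        # time becomes ceil(400 / 1.3) = 308

# Real spectrogram: (channel, freq, time).
specgram = torch.randn(2, 201, 400)
specgram = FrequencyMasking(freq_mask_param=30)(specgram)
specgram = TimeMasking(time_mask_param=100)(specgram)
```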
132 changes: 102 additions & 30 deletions torchaudio/functional.py
@@ -18,6 +18,8 @@
"lowpass_biquad",
"highpass_biquad",
"biquad",
+    'mask_along_axis',
+    'mask_along_axis_iid'
]


@@ -228,8 +230,8 @@ def spectrogram(
normalized (bool): Whether to normalize by magnitude after stft

Returns:
-        torch.Tensor: Dimension (channel, freq, time), where channel
-        is unchanged, freq is ``n_fft // 2 + 1`` where ``n_fft`` is the number of
+        torch.Tensor: Dimension (channel, n_freq, time), where channel
+        is unchanged, n_freq is ``n_fft // 2 + 1`` where ``n_fft`` is the number of
Fourier bins, and time is the number of window hops (n_frames).
"""
assert waveform.dim() == 2
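For concreteness (values are illustrative, not from this diff):

```python
n_fft = 400              # hypothetical FFT size
n_freq = n_fft // 2 + 1  # 201 one-sided Fourier bins; cf. TimeStretch's n_freq=201 default
```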
@@ -397,7 +399,9 @@ def mu_law_decoding(x_mu, quantization_channels):
return x


+@torch.jit.script
def complex_norm(complex_tensor, power=1.0):
+    # type: (Tensor, float) -> Tensor
r"""Compute the norm of complex tensor input.

Args:
@@ -439,64 +443,59 @@ def magphase(complex_tensor, power=1.0):
return mag, phase


+@torch.jit.script
def phase_vocoder(complex_specgrams, rate, phase_advance):
+    # type: (Tensor, float, Tensor) -> Tensor
r"""Given a STFT tensor, speed up in time without modifying pitch by a
factor of ``rate``.

Args:
-        complex_specgrams (torch.Tensor): Dimension of `(*, channel, freq, time, complex=2)`
+        complex_specgrams (torch.Tensor): Dimension of `(channel, freq, time, complex=2)`
rate (float): Speed-up factor
phase_advance (torch.Tensor): Expected phase advance in each bin. Dimension
of (freq, 1)

Returns:
-        complex_specgrams_stretch (torch.Tensor): Dimension of `(*, channel,
+        complex_specgrams_stretch (torch.Tensor): Dimension of `(channel,
freq, ceil(time/rate), complex=2)`

Example
-        >>> num_freqs, hop_length = 1025, 512
-        >>> # (batch, channel, num_freqs, time, complex=2)
-        >>> complex_specgrams = torch.randn(16, 1, num_freqs, 300, 2)
-        >>> rate = 1.3 # Slow down by 30%
+        >>> freq, hop_length = 1025, 512
+        >>> # (channel, freq, time, complex=2)
+        >>> complex_specgrams = torch.randn(2, freq, 300, 2)
+        >>> rate = 1.3 # Speed up by 30%
        >>> phase_advance = torch.linspace(
-        >>>     0, math.pi * hop_length, num_freqs)[..., None]
+        >>>     0, math.pi * hop_length, freq)[..., None]
        >>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
        >>> x.shape # with 231 == ceil(300 / 1.3)
-        torch.Size([16, 1, 1025, 231, 2])
+        torch.Size([2, 1025, 231, 2])
"""
-    ndim = complex_specgrams.dim()
-    time_slice = [slice(None)] * (ndim - 2)
-
-    time_steps = torch.arange(
-        0,
-        complex_specgrams.size(-2),
-        rate,
-        device=complex_specgrams.device,
-        dtype=complex_specgrams.dtype,
-    )
+    time_steps = torch.arange(0,
+                              complex_specgrams.size(-2),
+                              rate,
+                              device=complex_specgrams.device,
+                              dtype=complex_specgrams.dtype)

alphas = time_steps % 1.0
-    phase_0 = angle(complex_specgrams[time_slice + [slice(1)]])
+    phase_0 = angle(complex_specgrams[:, :, :1])

# Time Padding
complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 0, 0, 2])

-    # (new_bins, num_freqs, 2)
-    complex_specgrams_0 = complex_specgrams[time_slice + [time_steps.long()]]
-    complex_specgrams_1 = complex_specgrams[time_slice + [(time_steps + 1).long()]]
+    # (new_bins, freq, 2)
+    complex_specgrams_0 = complex_specgrams[:, :, time_steps.long()]
+    complex_specgrams_1 = complex_specgrams[:, :, (time_steps + 1).long()]

angle_0 = angle(complex_specgrams_0)
angle_1 = angle(complex_specgrams_1)

-    norm_0 = torch.norm(complex_specgrams_0, dim=-1)
-    norm_1 = torch.norm(complex_specgrams_1, dim=-1)
+    norm_0 = torch.norm(complex_specgrams_0, p=2, dim=-1)
+    norm_1 = torch.norm(complex_specgrams_1, p=2, dim=-1)

phase = angle_1 - angle_0 - phase_advance
phase = phase - 2 * math.pi * torch.round(phase / (2 * math.pi))

# Compute Phase Accum
phase = phase + phase_advance
-    phase = torch.cat([phase_0, phase[time_slice + [slice(-1)]]], dim=-1)
+    phase = torch.cat([phase_0, phase[:, :, :-1]], dim=-1)
phase_acc = torch.cumsum(phase, -1)

mag = alphas * norm_1 + (1 - alphas) * norm_0
@@ -655,6 +654,79 @@ def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
return biquad(waveform, b0, b1, b2, a0, a1, a2)


@torch.jit.script
def mask_along_axis_iid(specgrams, mask_param, mask_value, axis):
# type: (Tensor, int, float, int) -> Tensor
r"""
Apply a mask along ``axis``. Mask will be applied from indices ``[v_0, v_0 + v)``, where
    ``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, specgrams.size(axis) - v)``.
    Each example and channel in the batch gets its own mask.

Args:
specgrams (Tensor): Real spectrograms (batch, channel, n_freq, time)
        mask_param (int): Maximum mask width; the number of masked columns is uniformly sampled from [0, mask_param)
mask_value (float): Value to assign to the masked columns
axis (int): Axis to apply masking on (2 -> frequency, 3 -> time)

Returns:
torch.Tensor: Masked spectrograms of dimensions (batch, channel, n_freq, time)
"""

if axis != 2 and axis != 3:
raise ValueError('Only Frequency and Time masking are supported')

value = torch.rand(specgrams.shape[:2]) * mask_param
min_value = torch.rand(specgrams.shape[:2]) * (specgrams.size(axis) - value)

# Create broadcastable mask
mask_start = (min_value.long())[..., None, None].float()
mask_end = (min_value.long() + value.long())[..., None, None].float()
mask = torch.arange(0, specgrams.size(axis)).float()

# Per batch example masking
specgrams = specgrams.transpose(axis, -1)
specgrams.masked_fill_((mask >= mask_start) & (mask < mask_end), mask_value)
specgrams = specgrams.transpose(axis, -1)

return specgrams
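
A minimal call sketch (not part of the diff; shapes follow the docstring, and the assertion mirrors the new test above):

```python
import torch
import torchaudio.functional as F

specgrams = torch.randn(4, 2, 1025, 400)  # (batch, channel, n_freq, time)
masked = F.mask_along_axis_iid(specgrams, mask_param=100, mask_value=0.0, axis=3)
assert masked.shape == specgrams.shape  # each (batch, channel) slice gets its own time mask
```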


@torch.jit.script
def mask_along_axis(specgram, mask_param, mask_value, axis):
# type: (Tensor, int, float, int) -> Tensor
r"""
Apply a mask along ``axis``. Mask will be applied from indices ``[v_0, v_0 + v)``, where
    ``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, specgram.size(axis) - v)``.
All examples will have the same mask interval.

Args:
specgram (Tensor): Real spectrogram (channel, n_freq, time)
        mask_param (int): Maximum mask width; the number of masked columns is uniformly sampled from [0, mask_param)
mask_value (float): Value to assign to the masked columns
axis (int): Axis to apply masking on (1 -> frequency, 2 -> time)

Returns:
torch.Tensor: Masked spectrogram of dimensions (channel, n_freq, time)
"""

value = torch.rand(1) * mask_param
min_value = torch.rand(1) * (specgram.size(axis) - value)

mask_start = (min_value.long()).squeeze()
mask_end = (min_value.long() + value.long()).squeeze()

assert mask_end - mask_start < mask_param
if axis == 1:
specgram[:, mask_start:mask_end] = mask_value
elif axis == 2:
specgram[:, :, mask_start:mask_end] = mask_value
else:
raise ValueError('Only Frequency and Time masking are supported')

return specgram
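
The sampling scheme above, spelled out with illustrative numbers (a sketch, not part of the diff):

```python
import torch

mask_param, size = 100, 400                  # illustrative values
value = torch.rand(1) * mask_param           # mask width v, in [0, mask_param)
min_value = torch.rand(1) * (size - value)   # start v_0, in [0, size - v)
mask_start = int(min_value.long())
mask_end = int(min_value.long() + value.long())
# columns [mask_start, mask_end) get mask_value; the width stays below mask_param
```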


@torch.jit.script
def compute_deltas(specgram, win_length=5, mode="replicate"):
# type: (Tensor, int, str) -> Tensor
r"""Compute delta coefficients of a tensor, usually a spectrogram: