diff --git a/test/test_functional.py b/test/test_functional.py index 5a5a169f86..811b505e73 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -223,8 +223,7 @@ def _num_stft_bins(signal_len, fft_len, hop_length, pad): @pytest.mark.parametrize('complex_specgrams', [ - torch.randn(1, 2, 1025, 400, 2), - torch.randn(1, 1025, 400, 2) + torch.randn(2, 1025, 400, 2) ]) @pytest.mark.parametrize('rate', [0.5, 1.01, 1.3]) @pytest.mark.parametrize('hop_length', [256]) @@ -277,5 +276,45 @@ def test_complex_norm(complex_tensor, power): assert torch.allclose(expected_norm_tensor, norm_tensor, atol=1e-5) +@pytest.mark.parametrize('specgram', [ + torch.randn(2, 1025, 400), + torch.randn(1, 201, 100) +]) +@pytest.mark.parametrize('mask_param', [100]) +@pytest.mark.parametrize('mask_value', [0., 30.]) +@pytest.mark.parametrize('axis', [1, 2]) +def test_mask_along_axis(specgram, mask_param, mask_value, axis): + + mask_specgram = F.mask_along_axis(specgram, mask_param, mask_value, axis) + + other_axis = 1 if axis == 2 else 2 + + masked_columns = (mask_specgram == mask_value).sum(other_axis) + num_masked_columns = (masked_columns == mask_specgram.size(other_axis)).sum() + num_masked_columns /= mask_specgram.size(0) + + assert mask_specgram.size() == specgram.size() + assert num_masked_columns < mask_param + + +@pytest.mark.parametrize('specgrams', [ + torch.randn(4, 2, 1025, 400), +]) +@pytest.mark.parametrize('mask_param', [100]) +@pytest.mark.parametrize('mask_value', [0., 30.]) +@pytest.mark.parametrize('axis', [2, 3]) +def test_mask_along_axis_iid(specgrams, mask_param, mask_value, axis): + + mask_specgrams = F.mask_along_axis_iid(specgrams, mask_param, mask_value, axis) + + other_axis = 2 if axis == 3 else 3 + + masked_columns = (mask_specgrams == mask_value).sum(other_axis) + num_masked_columns = (masked_columns == mask_specgrams.size(other_axis)).sum(-1) + + assert mask_specgrams.size() == specgrams.size() + assert (num_masked_columns < mask_param).sum() == num_masked_columns.numel() + + if __name__ == '__main__': unittest.main() diff --git a/torchaudio/augmentations.py b/torchaudio/augmentations.py new file mode 100644 index 0000000000..52c6e2091b --- /dev/null +++ b/torchaudio/augmentations.py @@ -0,0 +1,130 @@ +import math +import torch + +from . import functional as F + +__all__ = [ + 'TimeStretch', + 'FrequencyMasking', + 'TimeMasking' +] + + +class TimeStretch(torch.jit.ScriptModule): + r"""Stretch stft in time without modifying pitch for a given rate. + + Args: + hop_length (int): Number audio of frames between STFT columns. (Default: ``n_fft // 2``) + n_freq (int, optional): number of filter banks from stft. (Default: ``201``) + fixed_rate (float): rate to speed up or slow down by. + If None is provided, rate must be passed to the forward method. (Default: ``None``) + """ + __constants__ = ['fixed_rate'] + + def __init__(self, hop_length=None, n_freq=201, fixed_rate=None): + super(TimeStretch, self).__init__() + + n_fft = (n_freq - 1) * 2 + hop_length = hop_length if hop_length is not None else n_fft // 2 + self.fixed_rate = fixed_rate + phase_advance = torch.linspace(0, math.pi * hop_length, n_freq)[..., None] + + self.phase_advance = torch.jit.Attribute(phase_advance, torch.Tensor) + + @torch.jit.script_method + def forward(self, complex_specgrams, overriding_rate=None): + # type: (Tensor, Optional[float]) -> Tensor + r""" + Args: + complex_specgrams (Tensor): complex spectrogram (*, channel, freq, time, complex=2) + overriding_rate (float or None): speed up to apply to this batch. + If no rate is passed, use ``self.fixed_rate`` + + Returns: + (Tensor): Stretched complex spectrogram of dimension (*, channel, n_freq, ceil(time/rate), complex=2) + """ + assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (*, complex=2)" + + if overriding_rate is None: + rate = self.fixed_rate + if rate is None: + raise ValueError("If no fixed_rate is specified" + ", must pass a valid rate to the forward method.") + else: + rate = overriding_rate + + if rate == 1.0: + return complex_specgrams + + shape = complex_specgrams.size() + complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-3:])) + complex_specgrams = F.phase_vocoder(complex_specgrams, rate, self.phase_advance) + + return complex_specgrams.reshape(shape[:-3] + complex_specgrams.shape[-3:]) + + +class _AxisMasking(torch.jit.ScriptModule): + r""" + Apply masking to a spectrogram. + Args: + mask_param (int): Maximum possible length of the mask + axis: What dimension the mask is applied on + iid_masks (bool): Applies iid masks to each of the examples in the batch dimension + """ + __constants__ = ['mask_param', 'axis', 'iid_masks'] + + def __init__(self, mask_param, axis, iid_masks): + + super(_AxisMasking, self).__init__() + self.mask_param = mask_param + self.axis = axis + self.iid_masks = iid_masks + + @torch.jit.script_method + def forward(self, specgram, mask_value=0.): + # type: (Tensor, float) -> Tensor + r""" + Args: + specgram (torch.Tensor): Tensor of dimension (*, channel, freq, time) + + Returns: + torch.Tensor: Masked spectrogram of dimensions (*, channel, freq, time) + """ + + # if iid_masks flag marked and specgram has a batch dimension + if self.iid_masks and specgram.dim() == 4: + return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1) + else: + shape = specgram.size() + specgram = specgram.reshape([-1] + list(shape[-2:])) + specgram = F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis) + + return specgram.reshape(shape[:-2] + specgram.shape[-2:]) + + +class FrequencyMasking(_AxisMasking): + r""" + Apply masking to a spectrogram in the frequency domain. + Args: + freq_mask_param (int): maximum possible length of the mask. + Indices uniformly sampled from [0, freq_mask_param). + iid_masks (bool): weather to apply the same mask to all + the examples/channels in the batch. (Default: False) + """ + + def __init__(self, freq_mask_param, iid_masks=False): + super(FrequencyMasking, self).__init__(freq_mask_param, 1, iid_masks) + + +class TimeMasking(_AxisMasking): + r""" + Apply masking to a spectrogram in the time domain. + Args: + time_mask_param (int): maximum possible length of the mask. + Indices uniformly sampled from [0, time_mask_param). + iid_masks (bool): weather to apply the same mask to all + the examples/channels in the batch. Defaults to False. + """ + + def __init__(self, time_mask_param, iid_masks=False): + super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 547738cf9b..c21690fb65 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -18,6 +18,8 @@ "lowpass_biquad", "highpass_biquad", "biquad", + 'mask_along_axis', + 'mask_along_axis_iid' ] @@ -228,8 +230,8 @@ def spectrogram( normalized (bool): Whether to normalize by magnitude after stft Returns: - torch.Tensor: Dimension (channel, freq, time), where channel - is unchanged, freq is ``n_fft // 2 + 1`` where ``n_fft`` is the number of + torch.Tensor: Dimension (channel, n_freq, time), where channel + is unchanged, n_freq is ``n_fft // 2 + 1`` where ``n_fft`` is the number of Fourier bins, and time is the number of window hops (n_frames). """ assert waveform.dim() == 2 @@ -397,7 +399,9 @@ def mu_law_decoding(x_mu, quantization_channels): return x +@torch.jit.script def complex_norm(complex_tensor, power=1.0): + # type: (Tensor, float) -> Tensor r"""Compute the norm of complex tensor input. Args: @@ -439,64 +443,59 @@ def magphase(complex_tensor, power=1.0): return mag, phase +@torch.jit.script def phase_vocoder(complex_specgrams, rate, phase_advance): + # type: (Tensor, float, Tensor) -> Tensor r"""Given a STFT tensor, speed up in time without modifying pitch by a factor of ``rate``. - Args: - complex_specgrams (torch.Tensor): Dimension of `(*, channel, freq, time, complex=2)` + complex_specgrams (torch.Tensor): Dimension of `(channel, freq, time, complex=2)` rate (float): Speed-up factor phase_advance (torch.Tensor): Expected phase advance in each bin. Dimension of (freq, 1) - Returns: - complex_specgrams_stretch (torch.Tensor): Dimension of `(*, channel, + complex_specgrams_stretch (torch.Tensor): Dimension of `(channel, freq, ceil(time/rate), complex=2)` - Example - >>> num_freqs, hop_length = 1025, 512 - >>> # (batch, channel, num_freqs, time, complex=2) - >>> complex_specgrams = torch.randn(16, 1, num_freqs, 300, 2) - >>> rate = 1.3 # Slow down by 30% + >>> freq, hop_length = 1025, 512 + >>> # (channel, freq, time, complex=2) + >>> complex_specgrams = torch.randn(2, freq, 300, 2) + >>> rate = 1.3 # Speed up by 30% >>> phase_advance = torch.linspace( - >>> 0, math.pi * hop_length, num_freqs)[..., None] + >>> 0, math.pi * hop_length, freq)[..., None] >>> x = phase_vocoder(complex_specgrams, rate, phase_advance) >>> x.shape # with 231 == ceil(300 / 1.3) - torch.Size([16, 1, 1025, 231, 2]) + torch.Size([2, 1025, 231, 2]) """ - ndim = complex_specgrams.dim() - time_slice = [slice(None)] * (ndim - 2) - - time_steps = torch.arange( - 0, - complex_specgrams.size(-2), - rate, - device=complex_specgrams.device, - dtype=complex_specgrams.dtype, - ) + + time_steps = torch.arange(0, + complex_specgrams.size(-2), + rate, + device=complex_specgrams.device, + dtype=complex_specgrams.dtype) alphas = time_steps % 1.0 - phase_0 = angle(complex_specgrams[time_slice + [slice(1)]]) + phase_0 = angle(complex_specgrams[:, :, :1]) # Time Padding complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 0, 0, 2]) - # (new_bins, num_freqs, 2) - complex_specgrams_0 = complex_specgrams[time_slice + [time_steps.long()]] - complex_specgrams_1 = complex_specgrams[time_slice + [(time_steps + 1).long()]] + # (new_bins, freq, 2) + complex_specgrams_0 = complex_specgrams[:, :, time_steps.long()] + complex_specgrams_1 = complex_specgrams[:, :, (time_steps + 1).long()] angle_0 = angle(complex_specgrams_0) angle_1 = angle(complex_specgrams_1) - norm_0 = torch.norm(complex_specgrams_0, dim=-1) - norm_1 = torch.norm(complex_specgrams_1, dim=-1) + norm_0 = torch.norm(complex_specgrams_0, p=2, dim=-1) + norm_1 = torch.norm(complex_specgrams_1, p=2, dim=-1) phase = angle_1 - angle_0 - phase_advance phase = phase - 2 * math.pi * torch.round(phase / (2 * math.pi)) # Compute Phase Accum phase = phase + phase_advance - phase = torch.cat([phase_0, phase[time_slice + [slice(-1)]]], dim=-1) + phase = torch.cat([phase_0, phase[:, :, :-1]], dim=-1) phase_acc = torch.cumsum(phase, -1) mag = alphas * norm_1 + (1 - alphas) * norm_0 @@ -655,6 +654,79 @@ def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): return biquad(waveform, b0, b1, b2, a0, a1, a2) +@torch.jit.script +def mask_along_axis_iid(specgrams, mask_param, mask_value, axis): + # type: (Tensor, int, float, int) -> Tensor + r""" + Apply a mask along ``axis``. Mask will be applied from indices ``[v_0, v_0 + v)``, where + ``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``. + All examples will have the same mask interval. + + Args: + specgrams (Tensor): Real spectrograms (batch, channel, n_freq, time) + mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param] + mask_value (float): Value to assign to the masked columns + axis (int): Axis to apply masking on (2 -> frequency, 3 -> time) + + Returns: + torch.Tensor: Masked spectrograms of dimensions (batch, channel, n_freq, time) + """ + + if axis != 2 and axis != 3: + raise ValueError('Only Frequency and Time masking are supported') + + value = torch.rand(specgrams.shape[:2]) * mask_param + min_value = torch.rand(specgrams.shape[:2]) * (specgrams.size(axis) - value) + + # Create broadcastable mask + mask_start = (min_value.long())[..., None, None].float() + mask_end = (min_value.long() + value.long())[..., None, None].float() + mask = torch.arange(0, specgrams.size(axis)).float() + + # Per batch example masking + specgrams = specgrams.transpose(axis, -1) + specgrams.masked_fill_((mask >= mask_start) & (mask < mask_end), mask_value) + specgrams = specgrams.transpose(axis, -1) + + return specgrams + + +@torch.jit.script +def mask_along_axis(specgram, mask_param, mask_value, axis): + # type: (Tensor, int, float, int) -> Tensor + r""" + Apply a mask along ``axis``. Mask will be applied from indices ``[v_0, v_0 + v)``, where + ``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``. + All examples will have the same mask interval. + + Args: + specgram (Tensor): Real spectrogram (channel, n_freq, time) + mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param] + mask_value (float): Value to assign to the masked columns + axis (int): Axis to apply masking on (1 -> frequency, 2 -> time) + + Returns: + torch.Tensor: Masked spectrogram of dimensions (channel, n_freq, time) + """ + + value = torch.rand(1) * mask_param + min_value = torch.rand(1) * (specgram.size(axis) - value) + + mask_start = (min_value.long()).squeeze() + mask_end = (min_value.long() + value.long()).squeeze() + + assert mask_end - mask_start < mask_param + if axis == 1: + specgram[:, mask_start:mask_end] = mask_value + elif axis == 2: + specgram[:, :, mask_start:mask_end] = mask_value + else: + raise ValueError('Only Frequency and Time masking are supported') + + return specgram + + +@torch.jit.script def compute_deltas(specgram, win_length=5, mode="replicate"): # type: (Tensor, int, str) -> Tensor r"""Compute delta coefficients of a tensor, usually a spectrogram: diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 7362e7268d..4a61dc60d9 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -16,6 +16,7 @@ 'MuLawEncoding', 'MuLawDecoding', 'Resample', + 'ComplexNorm' ] @@ -367,6 +368,28 @@ def forward(self, waveform): raise ValueError('Invalid resampling method: %s' % (self.resampling_method)) +class ComplexNorm(torch.jit.ScriptModule): + r"""Compute the norm of complex tensor input + Args: + power (float): Power of the norm. Defaults to `1.0`. + """ + __constants__ = ['power'] + + def __init__(self, power=1.0): + super(ComplexNorm, self).__init__() + self.power = power + + @torch.jit.script_method + def forward(self, complex_tensor): + r""" + Args: + complex_tensor (Tensor): Tensor shape of `(*, complex=2)` + Returns: + Tensor: norm of the input tensor, shape of `(*, )` + """ + return F.complex_norm(complex_tensor, self.power) + + class ComputeDeltas(torch.jit.ScriptModule): r"""Compute delta coefficients of a tensor, usually a spectrogram.