From 8db50c72313ead88e57c11fc7629342a5a26a859 Mon Sep 17 00:00:00 2001 From: Kiran Sanjeevan Date: Mon, 16 Sep 2019 15:28:46 -0700 Subject: [PATCH 1/9] TimeStretch and Masking --- test/test_functional.py | 3 +- torchaudio/augmentations.py | 177 ++++++++++++++++++++++++++++++++++++ torchaudio/functional.py | 56 +++++++++--- torchaudio/transforms.py | 67 ++++++++++++++ 4 files changed, 289 insertions(+), 14 deletions(-) create mode 100644 torchaudio/augmentations.py diff --git a/test/test_functional.py b/test/test_functional.py index 8f4f84942d..087902a751 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -197,8 +197,7 @@ def _num_stft_bins(signal_len, fft_len, hop_length, pad): @pytest.mark.parametrize('complex_specgrams', [ - torch.randn(1, 2, 1025, 400, 2), - torch.randn(1, 1025, 400, 2) + torch.randn(2, 1025, 400, 2) ]) @pytest.mark.parametrize('rate', [0.5, 1.01, 1.3]) @pytest.mark.parametrize('hop_length', [256]) diff --git a/torchaudio/augmentations.py b/torchaudio/augmentations.py new file mode 100644 index 0000000000..4897607b66 --- /dev/null +++ b/torchaudio/augmentations.py @@ -0,0 +1,177 @@ +import math +import torch + +from . import functional as F + + +class TimeStretch(torch.jit.ScriptModule): + r"""Stretch stft in time without modifying pitch for a given rate. + + Args: + hop_length (int): Number audio of frames between STFT columns. + num_freqs (int, optional): number of filter banks from stft. + fixed_rate (float): rate to speed up or slow down by. + Defaults to None (in which case a rate must be + passed to the forward method per batch). + """ + __constants__ = ['fixed_rate'] + + def __init__(self, hop_length=200, num_freqs=201, fixed_rate=None): + super(TimeStretch, self).__init__() + + self.fixed_rate = fixed_rate + phase_advance = torch.linspace(0, math.pi * hop_length, num_freqs)[..., None] + + self.phase_advance = torch.jit.Attribute(phase_advance, torch.Tensor) + + @torch.jit.script_method + def forward(self, complex_specgrams, overriding_rate=None): + # type: (Tensor, Optional[float]) -> Tensor + r""" + Args: + complex_specgrams (Tensor): complex spectrogram + (*, channel, freq, time, complex=2) + overriding_rate (float or None): speed up to apply to this batch. + If no rate is passed, use self.fixed_rate + Returns: + (Tensor): (*, channel, num_freqs, ceil(time/rate), complex=2) + """ + if overriding_rate is None: + rate = self.fixed_rate + if rate is None: + raise ValueError("If no fixed_rate is specified" + ", must pass a valid rate to the forward method.") + else: + rate = overriding_rate + + if rate == 1.0: + return complex_specgrams + + shape = complex_specgrams.size() + complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-3:])) + complex_specgrams = F.phase_vocoder(complex_specgrams, rate, self.phase_advance) + + return complex_specgrams.reshape(shape[:-3] + complex_specgrams.shape[-3:]) + + +@torch.jit.script +def mask_along_axis_iid(specgram, mask_param, mask_value, axis): + # type: (Tensor, int, float, int) -> Tensor + r""" + Apply a mask along ``axis``. Mask will be applied from ``[v_0, v_0 + v)``, where + ``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``. + All examples will have the same mask interval. 
+ + Args: + specgram (Tensor): Real spectogram (batch, channel, num_freqs, time) + mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param] + mask_value (float): Value to assign to the masked columns + axis (int): Axis to apply masking on (2 -> frequency, 3 -> time) + """ + + if axis != 2 and axis != 3: + raise ValueError('Only Frequency and Time masking are supported') + + value = torch.rand(specgram.shape[:2]) * mask_param + min_value = torch.rand(specgram.shape[:2]) * (specgram.size(axis) - value) + + mask_start = (min_value.long()).unsqueeze(-1).float() + mask_end = (min_value.long() + value.long()).unsqueeze(-1).float() + + mask = torch.arange(0, specgram.size(axis)).repeat(specgram.size(0), specgram.size(1), 1).float() + + specgram = specgram.transpose(2, axis) + specgram[(mask >= mask_start) & (mask < mask_end)] = torch.tensor(mask_value) + specgram = specgram.transpose(2, axis) + + return specgram + + +@torch.jit.script +def mask_along_axis(specgram, mask_param, mask_value, axis): + # type: (Tensor, int, float, int) -> Tensor + r""" + Apply a mask along ``axis``. Mask will be applied from ``[v_0, v_0 + v)``, where + ``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``. + All examples will have the same mask interval. + + Args: + specgram (Tensor): Real spectogram (batch, channel, num_freqs, time) + mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param] + mask_value (float): Value to assign to the masked columns + axis (int): Axis to apply masking on (1 -> frequency, 2 -> time) + """ + + value = torch.rand(1) * mask_param + min_value = torch.rand(1) * (specgram.size(axis) - value) + + mask_start = (min_value.long()).squeeze() + mask_end = (min_value.long() + value.long()).squeeze() + + if axis == 1: + specgram[:, mask_start:mask_end] = mask_value + elif axis == 2: + specgram[:, :, mask_start:mask_end] = mask_value + else: + raise ValueError('Only Frequency and Time masking is supported') + + return specgram + + +class _AxisMasking(torch.jit.ScriptModule): + r""" + Apply masking to a spectrogram. + Args: + mask_param (int): Maximum possible length of the mask + axis: What dimension the mask is applied on + iid_masks (bool): Applies iid masks to each of the examples in the batch dimension + """ + __constants__ = ['mask_param', 'axis', 'iid_masks'] + + def __init__(self, mask_param, axis, iid_masks): + + super(_AxisMasking, self).__init__() + self.mask_param = mask_param + self.axis = axis + self.iid_masks = iid_masks + + @torch.jit.script_method + def forward(self, specgram, mask_value=0.): + # type: (Tensor, float) -> Tensor + # if iid_masks flag marked and specgram has a batch dimension + if self.iid_masks and specgram.dim() == 4: + return mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1) + else: + shape = specgram.size() + specgram = specgram.reshape([-1] + list(shape[-2:])) + specgram = mask_along_axis(specgram, self.mask_param, mask_value, self.axis) + + return specgram.reshape(shape[:-2] + specgram.shape[-2:]) + + +class FrequencyMasking(_AxisMasking): + r""" + Apply masking to a spectrogram in the frequency domain. + Args: + freq_mask_param (int): maximum possible length of the mask. + Uniformly sampled from [0, freq_mask_param). + iid_masks (bool): weather to apply the same mask to all + the examples/channels in the batch. Defaults to False. 
+ """ + + def __init__(self, freq_mask_param, iid_masks=False): + super(FrequencyMasking, self).__init__(freq_mask_param, 1, iid_masks) + + +class TimeMasking(_AxisMasking): + """ + Apply masking to a spectrogram in the time domain. + Args: + time_mask_param (int): maximum possible length of the mask. + Uniformly sampled from [0, time_mask_param). + iid_masks (bool): weather to apply the same mask to all + the examples/channels in the batch. Defaults to False. + """ + + def __init__(self, time_mask_param, iid_masks=False): + super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index e33cc702c7..e650f92fa8 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -358,7 +358,9 @@ def mu_law_decoding(x_mu, quantization_channels): return x +@torch.jit.script def complex_norm(complex_tensor, power=1.0): + # type: (Tensor, float) -> Tensor r"""Compute the norm of complex tensor input. Args: @@ -400,24 +402,26 @@ def magphase(complex_tensor, power=1.): return mag, phase +@torch.jit.script def phase_vocoder(complex_specgrams, rate, phase_advance): + # type: (Tensor, float, Tensor) -> Tensor r"""Given a STFT tensor, speed up in time without modifying pitch by a factor of ``rate``. Args: - complex_specgrams (torch.Tensor): Dimension of `(*, channel, freq, time, complex=2)` + complex_specgrams (torch.Tensor): Dimension of `(channel, freq, time, complex=2)` rate (float): Speed-up factor phase_advance (torch.Tensor): Expected phase advance in each bin. Dimension of (freq, 1) Returns: - complex_specgrams_stretch (torch.Tensor): Dimension of `(*, channel, + complex_specgrams_stretch (torch.Tensor): Dimension of `(channel, freq, ceil(time/rate), complex=2)` Example >>> num_freqs, hop_length = 1025, 512 - >>> # (batch, channel, num_freqs, time, complex=2) - >>> complex_specgrams = torch.randn(16, 1, num_freqs, 300, 2) + >>> # (channel, num_freqs, time, complex=2) + >>> complex_specgrams = torch.randn(1, num_freqs, 300, 2) >>> rate = 1.3 # Slow down by 30% >>> phase_advance = torch.linspace( >>> 0, math.pi * hop_length, num_freqs)[..., None] @@ -425,8 +429,6 @@ def phase_vocoder(complex_specgrams, rate, phase_advance): >>> x.shape # with 231 == ceil(300 / 1.3) torch.Size([16, 1, 1025, 231, 2]) """ - ndim = complex_specgrams.dim() - time_slice = [slice(None)] * (ndim - 2) time_steps = torch.arange(0, complex_specgrams.size(-2), @@ -435,27 +437,27 @@ def phase_vocoder(complex_specgrams, rate, phase_advance): dtype=complex_specgrams.dtype) alphas = time_steps % 1. 
- phase_0 = angle(complex_specgrams[time_slice + [slice(1)]]) + phase_0 = angle(complex_specgrams[:, :, :1]) # Time Padding complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 0, 0, 2]) # (new_bins, num_freqs, 2) - complex_specgrams_0 = complex_specgrams[time_slice + [time_steps.long()]] - complex_specgrams_1 = complex_specgrams[time_slice + [(time_steps + 1).long()]] + complex_specgrams_0 = complex_specgrams[:, :, time_steps.long()] + complex_specgrams_1 = complex_specgrams[:, :, (time_steps + 1).long()] angle_0 = angle(complex_specgrams_0) angle_1 = angle(complex_specgrams_1) - norm_0 = torch.norm(complex_specgrams_0, dim=-1) - norm_1 = torch.norm(complex_specgrams_1, dim=-1) + norm_0 = torch.norm(complex_specgrams_0, p=2, dim=-1) + norm_1 = torch.norm(complex_specgrams_1, p=2, dim=-1) phase = angle_1 - angle_0 - phase_advance phase = phase - 2 * math.pi * torch.round(phase / (2 * math.pi)) # Compute Phase Accum phase = phase + phase_advance - phase = torch.cat([phase_0, phase[time_slice + [slice(-1)]]], dim=-1) + phase = torch.cat([phase_0, phase[:, :, :-1]], dim=-1) phase_acc = torch.cumsum(phase, -1) mag = alphas * norm_1 + (1 - alphas) * norm_0 @@ -466,3 +468,33 @@ def phase_vocoder(complex_specgrams, rate, phase_advance): complex_specgrams_stretch = torch.stack([real_stretch, imag_stretch], dim=-1) return complex_specgrams_stretch + + +@torch.jit.script +def stft(waveform, pad, window, n_fft, hop_length, win_length): + # type: (Tensor, int, Tensor, int, int, int) -> Tensor + r"""Create a spectrogram from a raw audio signal. + + Args: + waveform (torch.Tensor): Tensor of audio of dimension (channel, time) + pad (int): Two sided padding of signal + window (torch.Tensor): Window tensor that is applied/multiplied to each frame/window + n_fft (int): Size of FFT + hop_length (int): Length of hop between STFT windows + win_length (int): Window size + + Returns: + torch.Tensor: Dimension (channel, freq, time), where channel + is unchanged, freq is ``n_fft // 2 + 1`` where ``n_fft`` is the number of + Fourier bins, and time is the number of window hops (n_frames). + """ + assert waveform.dim() == 2 + + if pad > 0: + # TODO add "with torch.no_grad():" back when JIT supports it + waveform = torch.nn.functional.pad(waveform, (pad, pad), "constant") + + # default values are consistent with librosa.core.spectrum._spectrogram + spec_f = _stft(waveform, n_fft, hop_length, win_length, window, + True, 'reflect', False, True) + return spec_f diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 38e703b2b6..a00453421b 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -365,3 +365,70 @@ def forward(self, waveform): return kaldi.resample_waveform(waveform, self.orig_freq, self.new_freq) raise ValueError('Invalid resampling method: %s' % (self.resampling_method)) + + +class STFT(torch.jit.ScriptModule): + r"""Create a complex stft from a audio signal + + Args: + n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins + win_length (int): Window size. (Default: ``n_fft``) + hop_length (int, optional): Length of hop between STFT windows. ( + Default: ``win_length // 2``) + pad (int): Two sided padding of signal. (Default: ``0``) + window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor + that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``) + wkwargs (Dict[..., ...]): Arguments for window function. 
(Default: ``None``) + """ + __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad'] + + def __init__(self, n_fft=400, win_length=None, hop_length=None, + pad=0, window_fn=torch.hann_window, wkwargs=None): + super(STFT, self).__init__() + self.n_fft = n_fft + # number of FFT bins. the returned STFT result will have n_fft // 2 + 1 + # number of frequecies due to onesided=True in torch.stft + self.win_length = win_length if win_length is not None else n_fft + self.hop_length = hop_length if hop_length is not None else self.win_length // 2 + window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs) + self.window = torch.jit.Attribute(window, torch.Tensor) + self.pad = pad + + @torch.jit.script_method + def forward(self, waveform): + r""" + Args: + waveform (torch.Tensor): Tensor of audio of dimension (*, channel, time) + + Returns: + torch.Tensor: Dimension (*, channel, freq, time, complex=2), where channel + is unchanged, freq is ``n_fft // 2 + 1`` where ``n_fft`` is the number of + Fourier bins, and time is the number of window hops (n_frames). + """ + shape = waveform.size() + waveform = waveform.reshape(-1, shape[-1]) + complex_specgrams = F.stft(waveform, self.pad, self.window, self.n_fft, self.hop_length, self.win_length) + + return complex_specgrams.reshape(shape[:-1] + complex_specgrams.shape[-3:]) + + +class ComplexNorm(torch.jit.ScriptModule): + """Compute the norm of complex tensor input + Args: + power (float): Power of the norm. Defaults to `1.0`. + """ + __constants__ = ['power'] + + def __init__(self, power=1.0): + super(ComplexNorm, self).__init__() + self.power = power + + @torch.jit.script_method + def forward(self, complex_tensor): + """ + Args: + complex_tensor (Tensor): Tensor shape of `(*, complex=2)` + Returns: + Tensor: norm of the input tensor, shape of `(*, )` + """ + return F.complex_norm(complex_tensor, self.power) From a47dc283538af90fa9fa5ca27ce8be9d86d74779 Mon Sep 17 00:00:00 2001 From: Kiran Sanjeevan Date: Mon, 16 Sep 2019 15:50:53 -0700 Subject: [PATCH 2/9] Some refactoring --- torchaudio/augmentations.py | 74 ++++--------------------------------- torchaudio/functional.py | 64 ++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 66 deletions(-) diff --git a/torchaudio/augmentations.py b/torchaudio/augmentations.py index 4897607b66..e4b7ba8fc7 100644 --- a/torchaudio/augmentations.py +++ b/torchaudio/augmentations.py @@ -3,6 +3,12 @@ from . import functional as F +__all__ = [ + 'TimeStretch', + 'FrequencyMasking', + 'TimeMasking' +] + class TimeStretch(torch.jit.ScriptModule): r"""Stretch stft in time without modifying pitch for a given rate. @@ -54,70 +60,6 @@ def forward(self, complex_specgrams, overriding_rate=None): return complex_specgrams.reshape(shape[:-3] + complex_specgrams.shape[-3:]) -@torch.jit.script -def mask_along_axis_iid(specgram, mask_param, mask_value, axis): - # type: (Tensor, int, float, int) -> Tensor - r""" - Apply a mask along ``axis``. Mask will be applied from ``[v_0, v_0 + v)``, where - ``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``. - All examples will have the same mask interval. 
- - Args: - specgram (Tensor): Real spectogram (batch, channel, num_freqs, time) - mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param] - mask_value (float): Value to assign to the masked columns - axis (int): Axis to apply masking on (2 -> frequency, 3 -> time) - """ - - if axis != 2 and axis != 3: - raise ValueError('Only Frequency and Time masking are supported') - - value = torch.rand(specgram.shape[:2]) * mask_param - min_value = torch.rand(specgram.shape[:2]) * (specgram.size(axis) - value) - - mask_start = (min_value.long()).unsqueeze(-1).float() - mask_end = (min_value.long() + value.long()).unsqueeze(-1).float() - - mask = torch.arange(0, specgram.size(axis)).repeat(specgram.size(0), specgram.size(1), 1).float() - - specgram = specgram.transpose(2, axis) - specgram[(mask >= mask_start) & (mask < mask_end)] = torch.tensor(mask_value) - specgram = specgram.transpose(2, axis) - - return specgram - - -@torch.jit.script -def mask_along_axis(specgram, mask_param, mask_value, axis): - # type: (Tensor, int, float, int) -> Tensor - r""" - Apply a mask along ``axis``. Mask will be applied from ``[v_0, v_0 + v)``, where - ``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``. - All examples will have the same mask interval. - - Args: - specgram (Tensor): Real spectogram (batch, channel, num_freqs, time) - mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param] - mask_value (float): Value to assign to the masked columns - axis (int): Axis to apply masking on (1 -> frequency, 2 -> time) - """ - - value = torch.rand(1) * mask_param - min_value = torch.rand(1) * (specgram.size(axis) - value) - - mask_start = (min_value.long()).squeeze() - mask_end = (min_value.long() + value.long()).squeeze() - - if axis == 1: - specgram[:, mask_start:mask_end] = mask_value - elif axis == 2: - specgram[:, :, mask_start:mask_end] = mask_value - else: - raise ValueError('Only Frequency and Time masking is supported') - - return specgram - - class _AxisMasking(torch.jit.ScriptModule): r""" Apply masking to a spectrogram. @@ -140,11 +82,11 @@ def forward(self, specgram, mask_value=0.): # type: (Tensor, float) -> Tensor # if iid_masks flag marked and specgram has a batch dimension if self.iid_masks and specgram.dim() == 4: - return mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1) + return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1) else: shape = specgram.size() specgram = specgram.reshape([-1] + list(shape[-2:])) - specgram = mask_along_axis(specgram, self.mask_param, mask_value, self.axis) + specgram = F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis) return specgram.reshape(shape[:-2] + specgram.shape[-2:]) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index e650f92fa8..0d5493db88 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -498,3 +498,67 @@ def stft(waveform, pad, window, n_fft, hop_length, win_length): spec_f = _stft(waveform, n_fft, hop_length, win_length, window, True, 'reflect', False, True) return spec_f + + +@torch.jit.script +def mask_along_axis_iid(specgram, mask_param, mask_value, axis): + # type: (Tensor, int, float, int) -> Tensor + r""" + Apply a mask along ``axis``. Mask will be applied from ``[v_0, v_0 + v)``, where + ``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``. + All examples will have the same mask interval. 
+ + Args: + specgram (Tensor): Real spectogram (batch, channel, num_freqs, time) + mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param] + mask_value (float): Value to assign to the masked columns + axis (int): Axis to apply masking on (2 -> frequency, 3 -> time) + """ + + if axis != 2 and axis != 3: + raise ValueError('Only Frequency and Time masking are supported') + + value = torch.rand(specgram.shape[:2]) * mask_param + min_value = torch.rand(specgram.shape[:2]) * (specgram.size(axis) - value) + + mask_start = (min_value.long()).unsqueeze(-1).float() + mask_end = (min_value.long() + value.long()).unsqueeze(-1).float() + + mask = torch.arange(0, specgram.size(axis)).repeat(specgram.size(0), specgram.size(1), 1).float() + + specgram = specgram.transpose(2, axis) + specgram[(mask >= mask_start) & (mask < mask_end)] = torch.tensor(mask_value) + specgram = specgram.transpose(2, axis) + + return specgram + + +@torch.jit.script +def mask_along_axis(specgram, mask_param, mask_value, axis): + # type: (Tensor, int, float, int) -> Tensor + r""" + Apply a mask along ``axis``. Mask will be applied from ``[v_0, v_0 + v)``, where + ``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``. + All examples will have the same mask interval. + + Args: + specgram (Tensor): Real spectogram (batch, channel, num_freqs, time) + mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param] + mask_value (float): Value to assign to the masked columns + axis (int): Axis to apply masking on (1 -> frequency, 2 -> time) + """ + + value = torch.rand(1) * mask_param + min_value = torch.rand(1) * (specgram.size(axis) - value) + + mask_start = (min_value.long()).squeeze() + mask_end = (min_value.long() + value.long()).squeeze() + + if axis == 1: + specgram[:, mask_start:mask_end] = mask_value + elif axis == 2: + specgram[:, :, mask_start:mask_end] = mask_value + else: + raise ValueError('Only Frequency and Time masking is supported') + + return specgram From 0a5b106b77e03a36816b24f2fa1cb26ea10fcc5b Mon Sep 17 00:00:00 2001 From: Kiran Sanjeevan Date: Tue, 17 Sep 2019 08:05:33 -0700 Subject: [PATCH 3/9] Fixed masking value for iid --- torchaudio/augmentations.py | 10 ++++++++++ torchaudio/functional.py | 21 ++++++++++++++------- torchaudio/transforms.py | 2 ++ 3 files changed, 26 insertions(+), 7 deletions(-) diff --git a/torchaudio/augmentations.py b/torchaudio/augmentations.py index e4b7ba8fc7..b87d1c5061 100644 --- a/torchaudio/augmentations.py +++ b/torchaudio/augmentations.py @@ -80,6 +80,16 @@ def __init__(self, mask_param, axis, iid_masks): @torch.jit.script_method def forward(self, specgram, mask_value=0.): # type: (Tensor, float) -> Tensor + r""" + Args: + specgram (torch.Tensor): Tensor of dimension (*, channel, freq, time) + + Returns: + torch.Tensor: Dimension (channel, freq, time), where channel + is unchanged, freq is ``n_fft // 2 + 1`` where ``n_fft`` is the number of + Fourier bins, and time is the number of window hops (n_frames). 
+ """ + # if iid_masks flag marked and specgram has a batch dimension if self.iid_masks and specgram.dim() == 4: return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 0d5493db88..46e309ce65 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -513,6 +513,9 @@ def mask_along_axis_iid(specgram, mask_param, mask_value, axis): mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param] mask_value (float): Value to assign to the masked columns axis (int): Axis to apply masking on (2 -> frequency, 3 -> time) + + Returns: + torch.Tensor: Masked scpectograms of dimensions (batch, channel, num_freqs, time) """ if axis != 2 and axis != 3: @@ -521,14 +524,15 @@ def mask_along_axis_iid(specgram, mask_param, mask_value, axis): value = torch.rand(specgram.shape[:2]) * mask_param min_value = torch.rand(specgram.shape[:2]) * (specgram.size(axis) - value) - mask_start = (min_value.long()).unsqueeze(-1).float() - mask_end = (min_value.long() + value.long()).unsqueeze(-1).float() - - mask = torch.arange(0, specgram.size(axis)).repeat(specgram.size(0), specgram.size(1), 1).float() + # Create broadcastable mask + mask_start = (min_value.long())[..., None, None].float() + mask_end = (min_value.long() + value.long())[..., None, None].float() + mask = torch.arange(0, specgram.size(axis)).float() - specgram = specgram.transpose(2, axis) - specgram[(mask >= mask_start) & (mask < mask_end)] = torch.tensor(mask_value) - specgram = specgram.transpose(2, axis) + # Per batch example masking + specgram = specgram.transpose(axis, -1) + specgram.masked_fill_((mask >= mask_start) & (mask < mask_end), mask_value) + specgram = specgram.transpose(axis, -1) return specgram @@ -546,6 +550,9 @@ def mask_along_axis(specgram, mask_param, mask_value, axis): mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param] mask_value (float): Value to assign to the masked columns axis (int): Axis to apply masking on (1 -> frequency, 2 -> time) + + Returns: + torch.Tensor: Masked scpectogram of dimensions (channel, num_freqs, time) """ value = torch.rand(1) * mask_param diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index a00453421b..0a4938d2e0 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -16,6 +16,8 @@ 'MuLawEncoding', 'MuLawDecoding', 'Resample', + 'STFT', + 'ComplexNorm' ] From d57822576eeadb547b3448227b716da844bc2734 Mon Sep 17 00:00:00 2001 From: Kiran Sanjeevan Date: Tue, 17 Sep 2019 08:41:08 -0700 Subject: [PATCH 4/9] Doc stuff and naming --- torchaudio/augmentations.py | 29 +++++++++++++-------------- torchaudio/functional.py | 40 ++++++++++++++++++------------------- 2 files changed, 34 insertions(+), 35 deletions(-) diff --git a/torchaudio/augmentations.py b/torchaudio/augmentations.py index b87d1c5061..c7663d39a3 100644 --- a/torchaudio/augmentations.py +++ b/torchaudio/augmentations.py @@ -14,17 +14,18 @@ class TimeStretch(torch.jit.ScriptModule): r"""Stretch stft in time without modifying pitch for a given rate. Args: - hop_length (int): Number audio of frames between STFT columns. - num_freqs (int, optional): number of filter banks from stft. + hop_length (int): Number audio of frames between STFT columns. (Default: ``n_fft // 2``) + num_freqs (int, optional): number of filter banks from stft. (Default: ``201``) fixed_rate (float): rate to speed up or slow down by. 
- Defaults to None (in which case a rate must be - passed to the forward method per batch). + If None is provided, rate must be passed to the forward method. (Default: ``None``) """ __constants__ = ['fixed_rate'] - def __init__(self, hop_length=200, num_freqs=201, fixed_rate=None): + def __init__(self, hop_length=None, num_freqs=201, fixed_rate=None): super(TimeStretch, self).__init__() + n_fft = (num_freqs - 1) * 2 + hop_length = hop_length if hop_length is not None else n_fft // 2 self.fixed_rate = fixed_rate phase_advance = torch.linspace(0, math.pi * hop_length, num_freqs)[..., None] @@ -35,12 +36,12 @@ def forward(self, complex_specgrams, overriding_rate=None): # type: (Tensor, Optional[float]) -> Tensor r""" Args: - complex_specgrams (Tensor): complex spectrogram - (*, channel, freq, time, complex=2) + complex_specgrams (Tensor): complex spectrogram (*, channel, freq, time, complex=2) overriding_rate (float or None): speed up to apply to this batch. - If no rate is passed, use self.fixed_rate + If no rate is passed, use ``self.fixed_rate`` + Returns: - (Tensor): (*, channel, num_freqs, ceil(time/rate), complex=2) + (Tensor): Stretched complex spectrogram of dimension (*, channel, num_freqs, ceil(time/rate), complex=2) """ if overriding_rate is None: rate = self.fixed_rate @@ -85,9 +86,7 @@ def forward(self, specgram, mask_value=0.): specgram (torch.Tensor): Tensor of dimension (*, channel, freq, time) Returns: - torch.Tensor: Dimension (channel, freq, time), where channel - is unchanged, freq is ``n_fft // 2 + 1`` where ``n_fft`` is the number of - Fourier bins, and time is the number of window hops (n_frames). + torch.Tensor: Masked scpectogram of dimensions (*, channel, freq, time) """ # if iid_masks flag marked and specgram has a batch dimension @@ -106,9 +105,9 @@ class FrequencyMasking(_AxisMasking): Apply masking to a spectrogram in the frequency domain. Args: freq_mask_param (int): maximum possible length of the mask. - Uniformly sampled from [0, freq_mask_param). + Indices uniformly sampled from [0, freq_mask_param). iid_masks (bool): weather to apply the same mask to all - the examples/channels in the batch. Defaults to False. + the examples/channels in the batch. (Default: False) """ def __init__(self, freq_mask_param, iid_masks=False): @@ -120,7 +119,7 @@ class TimeMasking(_AxisMasking): Apply masking to a spectrogram in the time domain. Args: time_mask_param (int): maximum possible length of the mask. - Uniformly sampled from [0, time_mask_param). + Indices uniformly sampled from [0, time_mask_param). iid_masks (bool): weather to apply the same mask to all the examples/channels in the batch. Defaults to False. 
""" diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 46e309ce65..9f559ddc51 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -419,12 +419,12 @@ def phase_vocoder(complex_specgrams, rate, phase_advance): freq, ceil(time/rate), complex=2)` Example - >>> num_freqs, hop_length = 1025, 512 - >>> # (channel, num_freqs, time, complex=2) - >>> complex_specgrams = torch.randn(1, num_freqs, 300, 2) + >>> freq, hop_length = 1025, 512 + >>> # (channel, freq, time, complex=2) + >>> complex_specgrams = torch.randn(1, freq, 300, 2) >>> rate = 1.3 # Slow down by 30% >>> phase_advance = torch.linspace( - >>> 0, math.pi * hop_length, num_freqs)[..., None] + >>> 0, math.pi * hop_length, freq)[..., None] >>> x = phase_vocoder(complex_specgrams, rate, phase_advance) >>> x.shape # with 231 == ceil(300 / 1.3) torch.Size([16, 1, 1025, 231, 2]) @@ -442,7 +442,7 @@ def phase_vocoder(complex_specgrams, rate, phase_advance): # Time Padding complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 0, 0, 2]) - # (new_bins, num_freqs, 2) + # (new_bins, freq, 2) complex_specgrams_0 = complex_specgrams[:, :, time_steps.long()] complex_specgrams_1 = complex_specgrams[:, :, (time_steps + 1).long()] @@ -501,58 +501,58 @@ def stft(waveform, pad, window, n_fft, hop_length, win_length): @torch.jit.script -def mask_along_axis_iid(specgram, mask_param, mask_value, axis): +def mask_along_axis_iid(specgrams, mask_param, mask_value, axis): # type: (Tensor, int, float, int) -> Tensor r""" - Apply a mask along ``axis``. Mask will be applied from ``[v_0, v_0 + v)``, where + Apply a mask along ``axis``. Mask will be applied from indices ``[v_0, v_0 + v)``, where ``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``. All examples will have the same mask interval. Args: - specgram (Tensor): Real spectogram (batch, channel, num_freqs, time) + specgrams (Tensor): Real spectograms (batch, channel, freq, time) mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param] mask_value (float): Value to assign to the masked columns axis (int): Axis to apply masking on (2 -> frequency, 3 -> time) Returns: - torch.Tensor: Masked scpectograms of dimensions (batch, channel, num_freqs, time) + torch.Tensor: Masked scpectograms of dimensions (batch, channel, freq, time) """ if axis != 2 and axis != 3: raise ValueError('Only Frequency and Time masking are supported') - value = torch.rand(specgram.shape[:2]) * mask_param - min_value = torch.rand(specgram.shape[:2]) * (specgram.size(axis) - value) + value = torch.rand(specgrams.shape[:2]) * mask_param + min_value = torch.rand(specgrams.shape[:2]) * (specgrams.size(axis) - value) # Create broadcastable mask mask_start = (min_value.long())[..., None, None].float() mask_end = (min_value.long() + value.long())[..., None, None].float() - mask = torch.arange(0, specgram.size(axis)).float() + mask = torch.arange(0, specgrams.size(axis)).float() # Per batch example masking - specgram = specgram.transpose(axis, -1) - specgram.masked_fill_((mask >= mask_start) & (mask < mask_end), mask_value) - specgram = specgram.transpose(axis, -1) + specgrams = specgrams.transpose(axis, -1) + specgrams.masked_fill_((mask >= mask_start) & (mask < mask_end), mask_value) + specgrams = specgrams.transpose(axis, -1) - return specgram + return specgrams @torch.jit.script def mask_along_axis(specgram, mask_param, mask_value, axis): # type: (Tensor, int, float, int) -> Tensor r""" - Apply a mask along ``axis``. 
Mask will be applied from ``[v_0, v_0 + v)``, where + Apply a mask along ``axis``. Mask will be applied from indices ``[v_0, v_0 + v)``, where ``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``. All examples will have the same mask interval. Args: - specgram (Tensor): Real spectogram (batch, channel, num_freqs, time) + specgram (Tensor): Real spectogram (channel, freq, time) mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param] mask_value (float): Value to assign to the masked columns axis (int): Axis to apply masking on (1 -> frequency, 2 -> time) Returns: - torch.Tensor: Masked scpectogram of dimensions (channel, num_freqs, time) + torch.Tensor: Masked scpectogram of dimensions (channel, freq, time) """ value = torch.rand(1) * mask_param @@ -566,6 +566,6 @@ def mask_along_axis(specgram, mask_param, mask_value, axis): elif axis == 2: specgram[:, :, mask_start:mask_end] = mask_value else: - raise ValueError('Only Frequency and Time masking is supported') + raise ValueError('Only Frequency and Time masking are supported') return specgram From 4da27718c64a325c7085ac132930d5a81f69cdff Mon Sep 17 00:00:00 2001 From: Kiran Sanjeevan Date: Tue, 17 Sep 2019 09:49:59 -0700 Subject: [PATCH 5/9] Typos --- torchaudio/augmentations.py | 2 +- torchaudio/functional.py | 6 +++--- torchaudio/transforms.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/torchaudio/augmentations.py b/torchaudio/augmentations.py index c7663d39a3..e9f6a01c66 100644 --- a/torchaudio/augmentations.py +++ b/torchaudio/augmentations.py @@ -115,7 +115,7 @@ def __init__(self, freq_mask_param, iid_masks=False): class TimeMasking(_AxisMasking): - """ + r""" Apply masking to a spectrogram in the time domain. Args: time_mask_param (int): maximum possible length of the mask. diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 9f559ddc51..26ee23ff81 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -421,13 +421,13 @@ def phase_vocoder(complex_specgrams, rate, phase_advance): Example >>> freq, hop_length = 1025, 512 >>> # (channel, freq, time, complex=2) - >>> complex_specgrams = torch.randn(1, freq, 300, 2) - >>> rate = 1.3 # Slow down by 30% + >>> complex_specgrams = torch.randn(2, freq, 300, 2) + >>> rate = 1.3 # Speed up by 30% >>> phase_advance = torch.linspace( >>> 0, math.pi * hop_length, freq)[..., None] >>> x = phase_vocoder(complex_specgrams, rate, phase_advance) >>> x.shape # with 231 == ceil(300 / 1.3) - torch.Size([16, 1, 1025, 231, 2]) + torch.Size([2, 1025, 231, 2]) """ time_steps = torch.arange(0, diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 0a4938d2e0..f60cb69807 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -415,7 +415,7 @@ def forward(self, waveform): class ComplexNorm(torch.jit.ScriptModule): - """Compute the norm of complex tensor input + r"""Compute the norm of complex tensor input Args: power (float): Power of the norm. Defaults to `1.0`. 
""" @@ -427,7 +427,7 @@ def __init__(self, power=1.0): @torch.jit.script_method def forward(self, complex_tensor): - """ + r""" Args: complex_tensor (Tensor): Tensor shape of `(*, complex=2)` Returns: From 45fb34dd7a738cf25c737bd7e47a9d482d385e1d Mon Sep 17 00:00:00 2001 From: Kiran Sanjeevan Date: Wed, 18 Sep 2019 14:23:31 -0700 Subject: [PATCH 6/9] + mask functional tests, - complex stft stuff --- test/test_functional.py | 42 +++++++++++++++++++++++++ torchaudio/augmentations.py | 12 ++++--- torchaudio/functional.py | 63 +++++++++++-------------------------- torchaudio/transforms.py | 46 --------------------------- 4 files changed, 67 insertions(+), 96 deletions(-) diff --git a/test/test_functional.py b/test/test_functional.py index 087902a751..a9c2d22a15 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -250,5 +250,47 @@ def test_complex_norm(complex_tensor, power): assert torch.allclose(expected_norm_tensor, norm_tensor, atol=1e-5) +@pytest.mark.parametrize('specgram', [ + torch.randn(2, 1025, 400), + torch.randn(1, 201, 100) +]) +@pytest.mark.parametrize('mask_param', [100]) +@pytest.mark.parametrize('mask_value', [0., 30.]) +@pytest.mark.parametrize('axis', [1, 2]) +def test_mask_along_axis(specgram, mask_param, mask_value, axis): + + mask_specgram = F.mask_along_axis(specgram, mask_param, mask_value, axis) + + other_axis = 1 if axis == 2 else 2 + + masked_columns = (mask_specgram == mask_value).sum(other_axis) + num_masked_columns = (masked_columns == mask_specgram.size(other_axis)).sum() + num_masked_columns /= mask_specgram.size(0) + + assert mask_specgram.size() == specgram.size() + assert num_masked_columns < mask_param + + +@pytest.mark.parametrize('specgrams', [ + torch.randn(4, 2, 1025, 400), +]) +@pytest.mark.parametrize('mask_param', [100]) +@pytest.mark.parametrize('mask_value', [0., 30.]) +@pytest.mark.parametrize('axis', [2, 3]) +def test_mask_along_axis_iid(specgrams, mask_param, mask_value, axis): + + mask_specgrams = F.mask_along_axis_iid(specgrams, mask_param, mask_value, axis) + + torch.save(mask_specgrams, 'ex.pth') + + other_axis = 2 if axis == 3 else 3 + + masked_columns = (mask_specgrams == mask_value).sum(other_axis) + num_masked_columns = (masked_columns == mask_specgrams.size(other_axis)).sum(-1) + + assert mask_specgrams.size() == specgrams.size() + assert (num_masked_columns < mask_param).sum() == num_masked_columns.numel() + + if __name__ == '__main__': unittest.main() diff --git a/torchaudio/augmentations.py b/torchaudio/augmentations.py index e9f6a01c66..346c4130f2 100644 --- a/torchaudio/augmentations.py +++ b/torchaudio/augmentations.py @@ -15,19 +15,19 @@ class TimeStretch(torch.jit.ScriptModule): Args: hop_length (int): Number audio of frames between STFT columns. (Default: ``n_fft // 2``) - num_freqs (int, optional): number of filter banks from stft. (Default: ``201``) + n_freq (int, optional): number of filter banks from stft. (Default: ``201``) fixed_rate (float): rate to speed up or slow down by. If None is provided, rate must be passed to the forward method. 
(Default: ``None``) """ __constants__ = ['fixed_rate'] - def __init__(self, hop_length=None, num_freqs=201, fixed_rate=None): + def __init__(self, hop_length=None, n_freq=201, fixed_rate=None): super(TimeStretch, self).__init__() - n_fft = (num_freqs - 1) * 2 + n_fft = (n_freq - 1) * 2 hop_length = hop_length if hop_length is not None else n_fft // 2 self.fixed_rate = fixed_rate - phase_advance = torch.linspace(0, math.pi * hop_length, num_freqs)[..., None] + phase_advance = torch.linspace(0, math.pi * hop_length, n_freq)[..., None] self.phase_advance = torch.jit.Attribute(phase_advance, torch.Tensor) @@ -41,8 +41,10 @@ def forward(self, complex_specgrams, overriding_rate=None): If no rate is passed, use ``self.fixed_rate`` Returns: - (Tensor): Stretched complex spectrogram of dimension (*, channel, num_freqs, ceil(time/rate), complex=2) + (Tensor): Stretched complex spectrogram of dimension (*, channel, n_freq, ceil(time/rate), complex=2) """ + assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (*, complex=2)" + if overriding_rate is None: rate = self.fixed_rate if rate is None: diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 26ee23ff81..1a95f69066 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -15,6 +15,8 @@ 'angle', 'magphase', 'phase_vocoder', + 'mask_along_axis', + 'mask_along_axis_iid' ] @@ -190,8 +192,8 @@ def spectrogram(waveform, pad, window, n_fft, hop_length, win_length, power, nor normalized (bool): Whether to normalize by magnitude after stft Returns: - torch.Tensor: Dimension (channel, freq, time), where channel - is unchanged, freq is ``n_fft // 2 + 1`` where ``n_fft`` is the number of + torch.Tensor: Dimension (channel, n_freq, time), where channel + is unchanged, n_freq is ``n_fft // 2 + 1`` where ``n_fft`` is the number of Fourier bins, and time is the number of window hops (n_frames). """ assert waveform.dim() == 2 @@ -409,22 +411,22 @@ def phase_vocoder(complex_specgrams, rate, phase_advance): factor of ``rate``. Args: - complex_specgrams (torch.Tensor): Dimension of `(channel, freq, time, complex=2)` + complex_specgrams (torch.Tensor): Dimension of `(channel, n_freq, time, complex=2)` rate (float): Speed-up factor phase_advance (torch.Tensor): Expected phase advance in each bin. 
Dimension - of (freq, 1) + of (n_freq, 1) Returns: complex_specgrams_stretch (torch.Tensor): Dimension of `(channel, - freq, ceil(time/rate), complex=2)` + n_freq, ceil(time/rate), complex=2)` - Example - >>> freq, hop_length = 1025, 512 - >>> # (channel, freq, time, complex=2) - >>> complex_specgrams = torch.randn(2, freq, 300, 2) + Example: + >>> n_freq, hop_length = 1025, 512 + >>> # (channel, n_freq, time, complex=2) + >>> complex_specgrams = torch.randn(2, n_freq, 300, 2) >>> rate = 1.3 # Speed up by 30% >>> phase_advance = torch.linspace( - >>> 0, math.pi * hop_length, freq)[..., None] + >>> 0, math.pi * hop_length, n_freq)[..., None] >>> x = phase_vocoder(complex_specgrams, rate, phase_advance) >>> x.shape # with 231 == ceil(300 / 1.3) torch.Size([2, 1025, 231, 2]) @@ -442,7 +444,7 @@ def phase_vocoder(complex_specgrams, rate, phase_advance): # Time Padding complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 0, 0, 2]) - # (new_bins, freq, 2) + # (new_bins, n_freq, 2) complex_specgrams_0 = complex_specgrams[:, :, time_steps.long()] complex_specgrams_1 = complex_specgrams[:, :, (time_steps + 1).long()] @@ -470,36 +472,6 @@ def phase_vocoder(complex_specgrams, rate, phase_advance): return complex_specgrams_stretch -@torch.jit.script -def stft(waveform, pad, window, n_fft, hop_length, win_length): - # type: (Tensor, int, Tensor, int, int, int) -> Tensor - r"""Create a spectrogram from a raw audio signal. - - Args: - waveform (torch.Tensor): Tensor of audio of dimension (channel, time) - pad (int): Two sided padding of signal - window (torch.Tensor): Window tensor that is applied/multiplied to each frame/window - n_fft (int): Size of FFT - hop_length (int): Length of hop between STFT windows - win_length (int): Window size - - Returns: - torch.Tensor: Dimension (channel, freq, time), where channel - is unchanged, freq is ``n_fft // 2 + 1`` where ``n_fft`` is the number of - Fourier bins, and time is the number of window hops (n_frames). - """ - assert waveform.dim() == 2 - - if pad > 0: - # TODO add "with torch.no_grad():" back when JIT supports it - waveform = torch.nn.functional.pad(waveform, (pad, pad), "constant") - - # default values are consistent with librosa.core.spectrum._spectrogram - spec_f = _stft(waveform, n_fft, hop_length, win_length, window, - True, 'reflect', False, True) - return spec_f - - @torch.jit.script def mask_along_axis_iid(specgrams, mask_param, mask_value, axis): # type: (Tensor, int, float, int) -> Tensor @@ -509,13 +481,13 @@ def mask_along_axis_iid(specgrams, mask_param, mask_value, axis): All examples will have the same mask interval. Args: - specgrams (Tensor): Real spectograms (batch, channel, freq, time) + specgrams (Tensor): Real spectograms (batch, channel, n_freq, time) mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param] mask_value (float): Value to assign to the masked columns axis (int): Axis to apply masking on (2 -> frequency, 3 -> time) Returns: - torch.Tensor: Masked scpectograms of dimensions (batch, channel, freq, time) + torch.Tensor: Masked scpectograms of dimensions (batch, channel, n_freq, time) """ if axis != 2 and axis != 3: @@ -546,13 +518,13 @@ def mask_along_axis(specgram, mask_param, mask_value, axis): All examples will have the same mask interval. 
Args: - specgram (Tensor): Real spectogram (channel, freq, time) + specgram (Tensor): Real spectogram (channel, n_freq, time) mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param] mask_value (float): Value to assign to the masked columns axis (int): Axis to apply masking on (1 -> frequency, 2 -> time) Returns: - torch.Tensor: Masked scpectogram of dimensions (channel, freq, time) + torch.Tensor: Masked scpectogram of dimensions (channel, n_freq, time) """ value = torch.rand(1) * mask_param @@ -561,6 +533,7 @@ def mask_along_axis(specgram, mask_param, mask_value, axis): mask_start = (min_value.long()).squeeze() mask_end = (min_value.long() + value.long()).squeeze() + assert mask_end - mask_start < mask_param if axis == 1: specgram[:, mask_start:mask_end] = mask_value elif axis == 2: diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index f60cb69807..18fd4e7a99 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -16,7 +16,6 @@ 'MuLawEncoding', 'MuLawDecoding', 'Resample', - 'STFT', 'ComplexNorm' ] @@ -369,51 +368,6 @@ def forward(self, waveform): raise ValueError('Invalid resampling method: %s' % (self.resampling_method)) -class STFT(torch.jit.ScriptModule): - r"""Create a complex stft from a audio signal - - Args: - n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins - win_length (int): Window size. (Default: ``n_fft``) - hop_length (int, optional): Length of hop between STFT windows. ( - Default: ``win_length // 2``) - pad (int): Two sided padding of signal. (Default: ``0``) - window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor - that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``) - wkwargs (Dict[..., ...]): Arguments for window function. (Default: ``None``) - """ - __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad'] - - def __init__(self, n_fft=400, win_length=None, hop_length=None, - pad=0, window_fn=torch.hann_window, wkwargs=None): - super(STFT, self).__init__() - self.n_fft = n_fft - # number of FFT bins. the returned STFT result will have n_fft // 2 + 1 - # number of frequecies due to onesided=True in torch.stft - self.win_length = win_length if win_length is not None else n_fft - self.hop_length = hop_length if hop_length is not None else self.win_length // 2 - window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs) - self.window = torch.jit.Attribute(window, torch.Tensor) - self.pad = pad - - @torch.jit.script_method - def forward(self, waveform): - r""" - Args: - waveform (torch.Tensor): Tensor of audio of dimension (*, channel, time) - - Returns: - torch.Tensor: Dimension (*, channel, freq, time, complex=2), where channel - is unchanged, freq is ``n_fft // 2 + 1`` where ``n_fft`` is the number of - Fourier bins, and time is the number of window hops (n_frames). 
- """ - shape = waveform.size() - waveform = waveform.reshape(-1, shape[-1]) - complex_specgrams = F.stft(waveform, self.pad, self.window, self.n_fft, self.hop_length, self.win_length) - - return complex_specgrams.reshape(shape[:-1] + complex_specgrams.shape[-3:]) - - class ComplexNorm(torch.jit.ScriptModule): r"""Compute the norm of complex tensor input Args: From 07c6259c656d70df674f284017e09a79a4d3fe0f Mon Sep 17 00:00:00 2001 From: Vincent QB Date: Thu, 19 Sep 2019 16:51:35 -0400 Subject: [PATCH 7/9] typo --- torchaudio/augmentations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchaudio/augmentations.py b/torchaudio/augmentations.py index 346c4130f2..52c6e2091b 100644 --- a/torchaudio/augmentations.py +++ b/torchaudio/augmentations.py @@ -88,7 +88,7 @@ def forward(self, specgram, mask_value=0.): specgram (torch.Tensor): Tensor of dimension (*, channel, freq, time) Returns: - torch.Tensor: Masked scpectogram of dimensions (*, channel, freq, time) + torch.Tensor: Masked spectrogram of dimensions (*, channel, freq, time) """ # if iid_masks flag marked and specgram has a batch dimension From d7b0a189cbc98f629bbe085d170f2e224b595388 Mon Sep 17 00:00:00 2001 From: Vincent QB Date: Thu, 19 Sep 2019 16:54:16 -0400 Subject: [PATCH 8/9] typos --- torchaudio/functional.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 7709b9d856..02ab2828c7 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -662,13 +662,13 @@ def mask_along_axis_iid(specgrams, mask_param, mask_value, axis): All examples will have the same mask interval. Args: - specgrams (Tensor): Real spectograms (batch, channel, n_freq, time) + specgrams (Tensor): Real spectrograms (batch, channel, n_freq, time) mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param] mask_value (float): Value to assign to the masked columns axis (int): Axis to apply masking on (2 -> frequency, 3 -> time) Returns: - torch.Tensor: Masked scpectograms of dimensions (batch, channel, n_freq, time) + torch.Tensor: Masked spectrograms of dimensions (batch, channel, n_freq, time) """ if axis != 2 and axis != 3: @@ -699,13 +699,13 @@ def mask_along_axis(specgram, mask_param, mask_value, axis): All examples will have the same mask interval. 
Args: - specgram (Tensor): Real spectogram (channel, n_freq, time) + specgram (Tensor): Real spectrogram (channel, n_freq, time) mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param] mask_value (float): Value to assign to the masked columns axis (int): Axis to apply masking on (1 -> frequency, 2 -> time) Returns: - torch.Tensor: Masked scpectogram of dimensions (channel, n_freq, time) + torch.Tensor: Masked spectrogram of dimensions (channel, n_freq, time) """ value = torch.rand(1) * mask_param From 6c72e8c123327a40ff5d9fd4cc61d427bc1f71d8 Mon Sep 17 00:00:00 2001 From: Kiran Sanjeevan Date: Thu, 19 Sep 2019 14:10:33 -0700 Subject: [PATCH 9/9] Typo --- test/test_functional.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/test_functional.py b/test/test_functional.py index b440a852f7..811b505e73 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -307,8 +307,6 @@ def test_mask_along_axis_iid(specgrams, mask_param, mask_value, axis): mask_specgrams = F.mask_along_axis_iid(specgrams, mask_param, mask_value, axis) - torch.save(mask_specgrams, 'ex.pth') - other_axis = 2 if axis == 3 else 3 masked_columns = (mask_specgrams == mask_value).sum(other_axis)
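
Usage sketch (illustrative only, not part of the diffs above): the lines below chain the TimeStretch, ComplexNorm, FrequencyMasking and TimeMasking modules introduced in this series, with module locations, argument names, defaults and shape conventions taken from the patches themselves. Importing the new torchaudio.augmentations submodule directly is an assumption; the series does not show whether it is re-exported from torchaudio/__init__.py, and the API may still change in later revisions of the PR.

    >>> import torch
    >>> from torchaudio import augmentations, transforms
    >>> # complex spectrogram of shape (channel, freq, time, complex=2), e.g. from torch.stft
    >>> complex_specgrams = torch.randn(1, 201, 400, 2)
    >>> stretch = augmentations.TimeStretch(fixed_rate=1.3)      # rate > 1 speeds up in time
    >>> stretched = stretch(complex_specgrams)                   # (1, 201, ceil(400 / 1.3), 2)
    >>> specgram = transforms.ComplexNorm(power=2.)(stretched)   # power spectrogram, (1, 201, 308)
    >>> masked = augmentations.TimeMasking(time_mask_param=50)(
    ...     augmentations.FrequencyMasking(freq_mask_param=30)(specgram))
    >>> masked.shape
    torch.Size([1, 201, 308])

With the default iid_masks=False both masking modules draw a single mask interval that is shared across channels (mask_along_axis); passing iid_masks=True together with a batched (batch, channel, freq, time) input routes through mask_along_axis_iid instead, which draws an independent interval per example.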