diff --git a/torchaudio/functional.py b/torchaudio/functional.py
index 78e7a5b759..496de76357 100644
--- a/torchaudio/functional.py
+++ b/torchaudio/functional.py
@@ -469,13 +469,13 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
     factor of ``rate``.
 
     Args:
-        complex_specgrams (torch.Tensor): Dimension of `(channel, freq, time, complex=2)`
+        complex_specgrams (torch.Tensor): Dimension of `(..., freq, time, complex=2)`
         rate (float): Speed-up factor
        phase_advance (torch.Tensor): Expected phase advance in each bin. Dimension of (freq, 1)
 
     Returns:
-        complex_specgrams_stretch (torch.Tensor): Dimension of `(channel,
+        complex_specgrams_stretch (torch.Tensor): Dimension of `(...,
         freq, ceil(time/rate), complex=2)`
 
     Example
@@ -490,6 +490,10 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
         torch.Size([2, 1025, 231, 2])
     """
 
+    # pack batch
+    shape = complex_specgrams.size()
+    complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-3:]))
+
     time_steps = torch.arange(0,
                               complex_specgrams.size(-2),
                               rate,
@@ -527,6 +531,9 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
 
     complex_specgrams_stretch = torch.stack([real_stretch, imag_stretch], dim=-1)
 
+    # unpack batch
+    complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-3] + complex_specgrams_stretch.shape[1:])
+
     return complex_specgrams_stretch
 
 
@@ -775,6 +782,10 @@ def mask_along_axis(specgram, mask_param, mask_value, axis):
         torch.Tensor: Masked spectrogram of dimensions (channel, freq, time)
     """
 
+    # pack batch
+    shape = specgram.size()
+    specgram = specgram.reshape([-1] + list(shape[-2:]))
+
     value = torch.rand(1) * mask_param
     min_value = torch.rand(1) * (specgram.size(axis) - value)
 
@@ -789,7 +800,10 @@ def mask_along_axis(specgram, mask_param, mask_value, axis):
     else:
         raise ValueError('Only Frequency and Time masking are supported')
 
-    return specgram
+    # unpack batch
+    specgram = specgram.reshape(shape[:-2] + specgram.shape[-2:])
+
+    return specgram
 
 
 def compute_deltas(specgram, win_length=5, mode="replicate"):
diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py
index 3d0269fe1b..ed0fcece25 100644
--- a/torchaudio/transforms.py
+++ b/torchaudio/transforms.py
@@ -380,9 +380,9 @@ def __init__(self, power=1.0):
     def forward(self, complex_tensor):
         r"""
         Args:
-            complex_tensor (Tensor): Tensor shape of `(*, complex=2)`
+            complex_tensor (Tensor): Tensor shape of `(..., complex=2)`
 
         Returns:
-            Tensor: norm of the input tensor, shape of `(*, )`
+            Tensor: norm of the input tensor, shape of `(..., )`
         """
         return F.complex_norm(complex_tensor, self.power)
@@ -438,14 +438,14 @@ def forward(self, complex_specgrams, overriding_rate=None):
         # type: (Tensor, Optional[float]) -> Tensor
         r"""
         Args:
-            complex_specgrams (Tensor): complex spectrogram (*, channel, freq, time, complex=2)
+            complex_specgrams (Tensor): complex spectrogram (..., freq, time, complex=2)
             overriding_rate (float or None): speed up to apply to this batch.
                If no rate is passed, use ``self.fixed_rate``
 
         Returns:
-            (Tensor): Stretched complex spectrogram of dimension (*, channel, freq, ceil(time/rate), complex=2)
+            (Tensor): Stretched complex spectrogram of dimension (..., freq, ceil(time/rate), complex=2)
         """
-        assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (*, complex=2)"
+        assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (..., complex=2)"
 
         if overriding_rate is None:
             rate = self.fixed_rate
@@ -458,16 +458,12 @@ def forward(self, complex_specgrams, overriding_rate=None):
         if rate == 1.0:
             return complex_specgrams
 
-        shape = complex_specgrams.size()
-        complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-3:]))
-        complex_specgrams = F.phase_vocoder(complex_specgrams, rate, self.phase_advance)
-
-        return complex_specgrams.reshape(shape[:-3] + complex_specgrams.shape[-3:])
+        return F.phase_vocoder(complex_specgrams, rate, self.phase_advance)
 
 
 class _AxisMasking(torch.nn.Module):
-    r"""
-    Apply masking to a spectrogram.
+    r"""Apply masking to a spectrogram.
+
     Args:
         mask_param (int): Maximum possible length of the mask
         axis: What dimension the mask is applied on
@@ -486,26 +482,22 @@ def forward(self, specgram, mask_value=0.):
         # type: (Tensor, float) -> Tensor
         r"""
         Args:
-            specgram (torch.Tensor): Tensor of dimension (*, channel, freq, time)
+            specgram (torch.Tensor): Tensor of dimension (..., freq, time)
 
         Returns:
-            torch.Tensor: Masked spectrogram of dimensions (*, channel, freq, time)
+            torch.Tensor: Masked spectrogram of dimensions (..., freq, time)
         """
 
         # if iid_masks flag marked and specgram has a batch dimension
         if self.iid_masks and specgram.dim() == 4:
             return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1)
         else:
-            shape = specgram.size()
-            specgram = specgram.reshape([-1] + list(shape[-2:]))
-            specgram = F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis)
-
-            return specgram.reshape(shape[:-2] + specgram.shape[-2:])
+            return F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis)
 
 
 class FrequencyMasking(_AxisMasking):
-    r"""
-    Apply masking to a spectrogram in the frequency domain.
+    r"""Apply masking to a spectrogram in the frequency domain.
+
     Args:
         freq_mask_param (int): maximum possible length of the mask.
             Indices uniformly sampled from [0, freq_mask_param).
@@ -518,8 +510,8 @@ def __init__(self, freq_mask_param, iid_masks=False):
 
 
 class TimeMasking(_AxisMasking):
-    r"""
-    Apply masking to a spectrogram in the time domain.
+    r"""Apply masking to a spectrogram in the time domain.
+
     Args:
         time_mask_param (int): maximum possible length of the mask.
             Indices uniformly sampled from [0, time_mask_param).
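A minimal usage sketch (not part of the patch) of what this change is meant to allow: torchaudio.functional.phase_vocoder and torchaudio.functional.mask_along_axis now pack and unpack any leading dimensions themselves, so they can be called directly on tensors with extra batch dimensions. The batch/channel sizes, the rate of 1.3, and mask_param below are illustrative; phase_advance follows the construction in the phase_vocoder docstring example.

import math
import torch
import torchaudio.functional as F

freq, hop_length = 1025, 512
rate = 1.3
phase_advance = torch.linspace(0, math.pi * hop_length, freq)[..., None]  # (freq, 1)

# Input with leading (batch, channel) dims; phase_vocoder packs them internally.
complex_specgrams = torch.randn(3, 2, freq, 300, 2)
stretched = F.phase_vocoder(complex_specgrams, rate, phase_advance)
print(stretched.shape)  # expected: torch.Size([3, 2, 1025, 231, 2]), since ceil(300/1.3) == 231

# A single (channel, freq, time, complex=2) example should match its slice of the batched result.
single = F.phase_vocoder(complex_specgrams[0], rate, phase_advance)
assert torch.allclose(stretched[0], single)

# mask_along_axis likewise accepts leading dims; axis=2 masks along time for a (..., freq, time) input.
specgram = torch.randn(3, 2, freq, 300)
masked = F.mask_along_axis(specgram, mask_param=80, mask_value=0., axis=2)
print(masked.shape)  # expected: torch.Size([3, 2, 1025, 300])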