diff --git a/test/test_transforms.py b/test/test_transforms.py index d78c5b6000..a438c4481c 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -313,6 +313,18 @@ def test_compute_deltas_twochannel(self): computed = transform(specgram) self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape)) + def test_batch_spectrogram(self): + waveform, sample_rate = torchaudio.load(self.test_filepath) + + # Single then transform then batch + expected = transforms.Spectrogram()(waveform).repeat(3, 1, 1, 1) + + # Batch then transform + computed = transforms.Spectrogram()(waveform.repeat(3, 1, 1)) + + self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) + self.assertTrue(torch.allclose(computed, expected)) + if __name__ == '__main__': unittest.main() diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 7539f371d4..4397ddd51b 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -96,8 +96,7 @@ def istft( Args: stft_matrix (torch.Tensor): Output of stft where each row of a channel is a frequency and each - column is a window. it has a size of either (channel, fft_size, n_frame, 2) or ( - fft_size, n_frame, 2) + column is a window. It has a size of (..., fft_size, n_frame, 2) n_fft (int): Size of Fourier transform hop_length (Optional[int]): The distance between neighboring sliding window frames. (Default: ``win_length // 4``) @@ -218,14 +217,15 @@ def istft( def spectrogram( waveform, pad, window, n_fft, hop_length, win_length, power, normalized ): - # type: (Tensor, int, Tensor, int, int, int, int, bool) -> Tensor + # type: (Tensor, int, Tensor, int, int, int, Optional[int], bool) -> Tensor r""" spectrogram(waveform, pad, window, n_fft, hop_length, win_length, power, normalized) - Create a spectrogram from a raw audio signal. + Create a spectrogram or a batch of spectrograms from a raw audio signal. + The spectrogram can be either magnitude-only or complex. 
Args: - waveform (torch.Tensor): Tensor of audio of dimension (channel, time) + waveform (torch.Tensor): Tensor of audio of dimension (..., channel, time) pad (int): Two sided padding of signal window (torch.Tensor): Window tensor that is applied/multiplied to each frame/window n_fft (int): Size of FFT @@ -233,27 +233,36 @@ def spectrogram( win_length (int): Window size power (int): Exponent for the magnitude spectrogram, (must be > 0) e.g., 1 for energy, 2 for power, etc. + If None, then the complex spectrum is returned instead. normalized (bool): Whether to normalize by magnitude after stft Returns: - torch.Tensor: Dimension (channel, freq, time), where channel + torch.Tensor: Dimension (..., channel, freq, time), where channel is unchanged, freq is ``n_fft // 2 + 1`` and ``n_fft`` is the number of Fourier bins, and time is the number of window hops (n_frame). """ - assert waveform.dim() == 2 if pad > 0: # TODO add "with torch.no_grad():" back when JIT supports it waveform = torch.nn.functional.pad(waveform, (pad, pad), "constant") + # pack batch + shape = waveform.size() + waveform = waveform.reshape(-1, shape[-1]) + # default values are consistent with librosa.core.spectrum._spectrogram spec_f = _stft( waveform, n_fft, hop_length, win_length, window, True, "reflect", False, True ) + # unpack batch + spec_f = spec_f.reshape(shape[:-1] + spec_f.shape[-3:]) + if normalized: spec_f /= window.pow(2).sum().sqrt() - spec_f = spec_f.pow(power).sum(-1) # get power of "complex" tensor + if power is not None: + spec_f = spec_f.pow(power).sum(-1) # get power of "complex" tensor + return spec_f @@ -431,11 +440,11 @@ def complex_norm(complex_tensor, power=1.0): r"""Compute the norm of complex tensor input. Args: - complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)` + complex_tensor (torch.Tensor): Tensor shape of `(..., complex=2)` power (float): Power of the norm. (Default: `1.0`). Returns: - torch.Tensor: Power of the normed input tensor. 
Shape of `(*, )` + torch.Tensor: Power of the normed input tensor. Shape of `(..., )` """ if power == 1.0: return torch.norm(complex_tensor, 2, -1) @@ -448,10 +457,10 @@ def angle(complex_tensor): r"""Compute the angle of complex tensor input. Args: - complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)` + complex_tensor (torch.Tensor): Tensor shape of `(..., complex=2)` Return: - torch.Tensor: Angle of a complex tensor. Shape of `(*, )` + torch.Tensor: Angle of a complex tensor. Shape of `(..., )` """ return torch.atan2(complex_tensor[..., 1], complex_tensor[..., 0]) @@ -459,10 +468,10 @@ def angle(complex_tensor): @torch.jit.script def magphase(complex_tensor, power=1.0): # type: (Tensor, float) -> Tuple[Tensor, Tensor] - r"""Separate a complex-valued spectrogram with shape `(*, 2)` into its magnitude and phase. + r"""Separate a complex-valued spectrogram with shape `(..., 2)` into its magnitude and phase. Args: - complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)` + complex_tensor (torch.Tensor): Tensor shape of `(..., complex=2)` power (float): Power of the norm. (Default: `1.0`) Returns: