Skip to content

Commit 1500d4e

Browse files
authored
Complex STFT transform from spectrogram (#327)
* Add STFT transform and function from #285.
* Merge options into existing functionality.
* Remove dimension 2 check. Add test.
* Use `...` notation.
* Update spectrogram test.
1 parent 5211b84 commit 1500d4e

File tree

2 files changed

+35
-14
lines changed

2 files changed

+35
-14
lines changed

test/test_transforms.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,18 @@ def test_compute_deltas_twochannel(self):
313313
computed = transform(specgram)
314314
self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))
315315

316+
def test_batch_spectrogram(self):
317+
waveform, sample_rate = torchaudio.load(self.test_filepath)
318+
319+
# Single then transform then batch
320+
expected = transforms.Spectrogram()(waveform).repeat(3, 1, 1, 1)
321+
322+
# Batch then transform
323+
computed = transforms.Spectrogram()(waveform.repeat(3, 1, 1))
324+
325+
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
326+
self.assertTrue(torch.allclose(computed, expected))
327+
316328

317329
if __name__ == '__main__':
318330
unittest.main()

torchaudio/functional.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,7 @@ def istft(
9696
9797
Args:
9898
stft_matrix (torch.Tensor): Output of stft where each row of a channel is a frequency and each
99-
column is a window. it has a size of either (channel, fft_size, n_frame, 2) or (
100-
fft_size, n_frame, 2)
99+
column is a window. It has a size of (..., fft_size, n_frame, 2)
101100
n_fft (int): Size of Fourier transform
102101
hop_length (Optional[int]): The distance between neighboring sliding window frames.
103102
(Default: ``win_length // 4``)
@@ -218,42 +217,52 @@ def istft(
218217
def spectrogram(
219218
waveform, pad, window, n_fft, hop_length, win_length, power, normalized
220219
):
221-
# type: (Tensor, int, Tensor, int, int, int, int, bool) -> Tensor
220+
# type: (Tensor, int, Tensor, int, int, int, Optional[int], bool) -> Tensor
222221
r"""
223222
spectrogram(waveform, pad, window, n_fft, hop_length, win_length, power, normalized)
224223
225-
Create a spectrogram from a raw audio signal.
224+
Create a spectrogram or a batch of spectrograms from a raw audio signal.
225+
The spectrogram can be either magnitude-only or complex.
226226
227227
Args:
228-
waveform (torch.Tensor): Tensor of audio of dimension (channel, time)
228+
waveform (torch.Tensor): Tensor of audio of dimension (..., channel, time)
229229
pad (int): Two sided padding of signal
230230
window (torch.Tensor): Window tensor that is applied/multiplied to each frame/window
231231
n_fft (int): Size of FFT
232232
hop_length (int): Length of hop between STFT windows
233233
win_length (int): Window size
234234
power (Optional[int]): Exponent for the magnitude spectrogram,
235235
(must be > 0) e.g., 1 for energy, 2 for power, etc.
236+
If None, then the complex spectrum is returned instead.
236237
normalized (bool): Whether to normalize by magnitude after stft
237238
238239
Returns:
239-
torch.Tensor: Dimension (channel, freq, time), where channel
240+
torch.Tensor: Dimension (..., channel, freq, time), where channel
240241
is unchanged, freq is ``n_fft // 2 + 1`` and ``n_fft`` is the number of
241242
Fourier bins, and time is the number of window hops (n_frame). If ``power`` is None, the trailing complex dimension of size 2 is kept.
242243
"""
243-
assert waveform.dim() == 2
244244

245245
if pad > 0:
246246
# TODO add "with torch.no_grad():" back when JIT supports it
247247
waveform = torch.nn.functional.pad(waveform, (pad, pad), "constant")
248248

249+
# pack batch
250+
shape = waveform.size()
251+
waveform = waveform.reshape(-1, shape[-1])
252+
249253
# default values are consistent with librosa.core.spectrum._spectrogram
250254
spec_f = _stft(
251255
waveform, n_fft, hop_length, win_length, window, True, "reflect", False, True
252256
)
253257

258+
# unpack batch
259+
spec_f = spec_f.reshape(shape[:-1] + spec_f.shape[-3:])
260+
254261
if normalized:
255262
spec_f /= window.pow(2).sum().sqrt()
256-
spec_f = spec_f.pow(power).sum(-1) # get power of "complex" tensor
263+
if power is not None:
264+
spec_f = spec_f.pow(power).sum(-1) # get power of "complex" tensor
265+
257266
return spec_f
258267

259268

@@ -431,11 +440,11 @@ def complex_norm(complex_tensor, power=1.0):
431440
r"""Compute the norm of complex tensor input.
432441
433442
Args:
434-
complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)`
443+
complex_tensor (torch.Tensor): Tensor shape of `(..., complex=2)`
435444
power (float): Power of the norm. (Default: `1.0`).
436445
437446
Returns:
438-
torch.Tensor: Power of the normed input tensor. Shape of `(*, )`
447+
torch.Tensor: Power of the normed input tensor. Shape of `(..., )`
439448
"""
440449
if power == 1.0:
441450
return torch.norm(complex_tensor, 2, -1)
@@ -448,21 +457,21 @@ def angle(complex_tensor):
448457
r"""Compute the angle of complex tensor input.
449458
450459
Args:
451-
complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)`
460+
complex_tensor (torch.Tensor): Tensor shape of `(..., complex=2)`
452461
453462
Return:
454-
torch.Tensor: Angle of a complex tensor. Shape of `(*, )`
463+
torch.Tensor: Angle of a complex tensor. Shape of `(..., )`
455464
"""
456465
return torch.atan2(complex_tensor[..., 1], complex_tensor[..., 0])
457466

458467

459468
@torch.jit.script
460469
def magphase(complex_tensor, power=1.0):
461470
# type: (Tensor, float) -> Tuple[Tensor, Tensor]
462-
r"""Separate a complex-valued spectrogram with shape `(*, 2)` into its magnitude and phase.
471+
r"""Separate a complex-valued spectrogram with shape `(..., 2)` into its magnitude and phase.
463472
464473
Args:
465-
complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)`
474+
complex_tensor (torch.Tensor): Tensor shape of `(..., complex=2)`
466475
power (float): Power of the norm. (Default: `1.0`)
467476
468477
Returns:

0 commit comments

Comments
 (0)