@@ -96,8 +96,7 @@ def istft(
9696
9797 Args:
9898 stft_matrix (torch.Tensor): Output of stft where each row of a channel is a frequency and each
99- column is a window. it has a size of either (channel, fft_size, n_frame, 2) or (
100- fft_size, n_frame, 2)
99+ column is a window. It has a size of (..., fft_size, n_frame, 2)
101100 n_fft (int): Size of Fourier transform
102101 hop_length (Optional[int]): The distance between neighboring sliding window frames.
103102 (Default: ``win_length // 4``)
@@ -218,42 +217,52 @@ def istft(
218217def spectrogram (
219218 waveform , pad , window , n_fft , hop_length , win_length , power , normalized
220219):
221- # type: (Tensor, int, Tensor, int, int, int, int, bool) -> Tensor
220+ # type: (Tensor, int, Tensor, int, int, int, Optional[ int] , bool) -> Tensor
222221 r"""
223222 spectrogram(waveform, pad, window, n_fft, hop_length, win_length, power, normalized)
224223
225- Create a spectrogram from a raw audio signal.
224+ Create a spectrogram or a batch of spectrograms from a raw audio signal.
225+ The spectrogram can be either magnitude-only or complex.
226226
227227 Args:
228- waveform (torch.Tensor): Tensor of audio of dimension (channel, time)
228+ waveform (torch.Tensor): Tensor of audio of dimension (..., channel, time)
229229 pad (int): Two sided padding of signal
230230 window (torch.Tensor): Window tensor that is applied/multiplied to each frame/window
231231 n_fft (int): Size of FFT
232232 hop_length (int): Length of hop between STFT windows
233233 win_length (int): Window size
234234 power (Optional[int]): Exponent for the magnitude spectrogram,
235235 (must be > 0) e.g., 1 for energy, 2 for power, etc.
236+ If None, then the complex spectrum is returned instead.
236237 normalized (bool): Whether to normalize by magnitude after stft
237238
238239 Returns:
239- torch.Tensor: Dimension (channel, freq, time), where channel
240+ torch.Tensor: Dimension (..., channel, freq, time), where channel
240241 is unchanged, freq is ``n_fft // 2 + 1`` and ``n_fft`` is the number of
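241242 Fourier bins, and time is the number of window hops (n_frame). If ``power`` is None, the trailing complex dimension of size 2 is kept, i.e. (..., channel, freq, time, 2).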
242243 """
243- assert waveform .dim () == 2
244244
245245 if pad > 0 :
246246 # TODO add "with torch.no_grad():" back when JIT supports it
247247 waveform = torch .nn .functional .pad (waveform , (pad , pad ), "constant" )
248248
249+ # pack batch
250+ shape = waveform .size ()
251+ waveform = waveform .reshape (- 1 , shape [- 1 ])
252+
249253 # default values are consistent with librosa.core.spectrum._spectrogram
250254 spec_f = _stft (
251255 waveform , n_fft , hop_length , win_length , window , True , "reflect" , False , True
252256 )
253257
258+ # unpack batch
259+ spec_f = spec_f .reshape (shape [:- 1 ] + spec_f .shape [- 3 :])
260+
254261 if normalized :
255262 spec_f /= window .pow (2 ).sum ().sqrt ()
256- spec_f = spec_f .pow (power ).sum (- 1 ) # get power of "complex" tensor
263+ if power is not None :
264+ spec_f = spec_f .pow (power ).sum (- 1 ) # get power of "complex" tensor
265+
257266 return spec_f
258267
259268
@@ -431,11 +440,11 @@ def complex_norm(complex_tensor, power=1.0):
431440 r"""Compute the norm of complex tensor input.
432441
433442 Args:
434- complex_tensor (torch.Tensor): Tensor shape of `(* , complex=2)`
443+ complex_tensor (torch.Tensor): Tensor shape of `(... , complex=2)`
435444 power (float): Power of the norm. (Default: `1.0`).
436445
437446 Returns:
438- torch.Tensor: Power of the normed input tensor. Shape of `(* , )`
447+ torch.Tensor: Power of the normed input tensor. Shape of `(... , )`
439448 """
440449 if power == 1.0 :
441450 return torch .norm (complex_tensor , 2 , - 1 )
@@ -448,21 +457,21 @@ def angle(complex_tensor):
448457 r"""Compute the angle of complex tensor input.
449458
450459 Args:
451- complex_tensor (torch.Tensor): Tensor shape of `(* , complex=2)`
460+ complex_tensor (torch.Tensor): Tensor shape of `(... , complex=2)`
452461
453462 Returns:
454- torch.Tensor: Angle of a complex tensor. Shape of `(* , )`
463+ torch.Tensor: Angle of a complex tensor. Shape of `(... , )`
455464 """
456465 return torch .atan2 (complex_tensor [..., 1 ], complex_tensor [..., 0 ])
457466
458467
459468@torch .jit .script
460469def magphase (complex_tensor , power = 1.0 ):
461470 # type: (Tensor, float) -> Tuple[Tensor, Tensor]
462- r"""Separate a complex-valued spectrogram with shape `(* , 2)` into its magnitude and phase.
471+ r"""Separate a complex-valued spectrogram with shape `(... , 2)` into its magnitude and phase.
463472
464473 Args:
465- complex_tensor (torch.Tensor): Tensor shape of `(* , complex=2)`
474+ complex_tensor (torch.Tensor): Tensor shape of `(... , complex=2)`
466475 power (float): Power of the norm. (Default: `1.0`)
467476
468477 Returns:
0 commit comments