@@ -96,7 +96,7 @@ def istft(
9696
9797 Args:
9898 stft_matrix (torch.Tensor): Output of stft where each row of a channel is a frequency and each
99- column is a window. it has a size of either (..., fft_size, n_frame, 2)
99+ column is a window. It has a size of either (..., fft_size, n_frame, 2)
100100 n_fft (int): Size of Fourier transform
101101 hop_length (Optional[int]): The distance between neighboring sliding window frames.
102102 (Default: ``win_length // 4``)
@@ -229,7 +229,7 @@ def spectrogram(
229229 The spectrogram can be either magnitude-only or complex.
230230
231231 Args:
232- waveform (torch.Tensor): Tensor of audio of dimension (..., channel, time)
232+ waveform (torch.Tensor): Tensor of audio of dimension (..., time)
233233 pad (int): Two sided padding of signal
234234 window (torch.Tensor): Window tensor that is applied/multiplied to each frame/window
235235 n_fft (int): Size of FFT
@@ -241,8 +241,8 @@ def spectrogram(
241241 normalized (bool): Whether to normalize by magnitude after stft
242242
243243 Returns:
244- torch.Tensor: Dimension (..., channel, freq, time), where channel
245- is unchanged, freq is ``n_fft // 2 + 1`` and ``n_fft`` is the number of
244+ torch.Tensor: Dimension (..., freq, time), where freq is
245+ ``n_fft // 2 + 1`` and ``n_fft`` is the number of
246246 Fourier bins, and time is the number of window hops (n_frame).
247247 """
248248
@@ -613,7 +613,7 @@ def biquad(waveform, b0, b1, b2, a0, a1, a2):
613613 https://en.wikipedia.org/wiki/Digital_biquad_filter
614614
615615 Args:
616- waveform (torch.Tensor): audio waveform of dimension of `(channel , time)`
616+ waveform (torch.Tensor): audio waveform of dimension of `(... , time)`
617617 b0 (float): numerator coefficient of current input, x[n]
618618 b1 (float): numerator coefficient of input one time step ago x[n-1]
619619 b2 (float): numerator coefficient of input two time steps ago x[n-2]
@@ -622,7 +622,7 @@ def biquad(waveform, b0, b1, b2, a0, a1, a2):
622622 a2 (float): denominator coefficient of output two time steps ago y[n-2]
623623
624624 Returns:
625- output_waveform (torch.Tensor): Dimension of `(channel , time)`
625+ output_waveform (torch.Tensor): Dimension of `(... , time)`
626626 """
627627
628628 device = waveform .device
@@ -646,13 +646,13 @@ def highpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
646646 r"""Design biquad highpass filter and perform filtering. Similar to SoX implementation.
647647
648648 Args:
649- waveform (torch.Tensor): audio waveform of dimension of `(channel , time)`
649+ waveform (torch.Tensor): audio waveform of dimension of `(... , time)`
650650 sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
651651 cutoff_freq (float): filter cutoff frequency
652652 Q (float): https://en.wikipedia.org/wiki/Q_factor
653653
654654 Returns:
655- output_waveform (torch.Tensor): Dimension of `(channel , time)`
655+ output_waveform (torch.Tensor): Dimension of `(... , time)`
656656 """
657657
658658 GAIN = 1.
@@ -675,13 +675,13 @@ def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
675675 r"""Design biquad lowpass filter and perform filtering. Similar to SoX implementation.
676676
677677 Args:
678- waveform (torch.Tensor): audio waveform of dimension of `(channel , time)`
678+ waveform (torch.Tensor): audio waveform of dimension of `(... , time)`
679679 sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
680680 cutoff_freq (float): filter cutoff frequency
681681 Q (float): https://en.wikipedia.org/wiki/Q_factor
682682
683683 Returns:
684- output_waveform (torch.Tensor): Dimension of `(channel , time)`
684+ output_waveform (torch.Tensor): Dimension of `(... , time)`
685685 """
686686
687687 GAIN = 1.
@@ -704,14 +704,14 @@ def equalizer_biquad(waveform, sample_rate, center_freq, gain, Q=0.707):
704704 r"""Design biquad peaking equalizer filter and perform filtering. Similar to SoX implementation.
705705
706706 Args:
707- waveform (torch.Tensor): audio waveform of dimension of `(channel , time)`
707+ waveform (torch.Tensor): audio waveform of dimension of `(... , time)`
708708 sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
709709 center_freq (float): filter's central frequency
710710 gain (float): desired gain at the boost (or attenuation) in dB
711711 Q (float): https://en.wikipedia.org/wiki/Q_factor
712712
713713 Returns:
714- output_waveform (torch.Tensor): Dimension of `(channel , time)`
714+ output_waveform (torch.Tensor): Dimension of `(... , time)`
715715 """
716716 w0 = 2 * math .pi * center_freq / sample_rate
717717 A = math .exp (gain / 40.0 * math .log (10 ))
@@ -800,7 +800,7 @@ def mask_along_axis(specgram, mask_param, mask_value, axis):
800800 # unpack batch
801801 specgram = specgram .reshape (shape [:- 2 ] + specgram .shape [- 2 :])
802802
803- return specgram . reshape ( shape [: - 2 ] + specgram . shape [ - 2 :])
803+ return specgram
804804
805805
806806def compute_deltas (specgram , win_length = 5 , mode = "replicate" ):
@@ -860,7 +860,7 @@ def gain(waveform, gain_db=1.0):
860860 r"""Apply amplification or attenuation to the whole waveform.
861861
862862 Args:
863- waveform (torch.Tensor): Tensor of audio of dimension (channel , time).
863+ waveform (torch.Tensor): Tensor of audio of dimension (... , time).
864864 gain_db (float): Gain adjustment in decibels (dB) (Default: `1.0`).
865865
866866 Returns:
@@ -913,7 +913,7 @@ def _apply_probability_distribution(waveform, density_function="TPDF"):
913913 The relationship of probabilities of results follows a bell-shaped,
914914 or Gaussian curve, typical of dither generated by analog sources.
915915 Args:
916- waveform (torch.Tensor): Tensor of audio of dimension (channel , time)
916+ waveform (torch.Tensor): Tensor of audio of dimension (... , time)
917917 density_function (string): The density function of a
918918 continuous random variable (Default: `TPDF`)
919919 Options: Triangular Probability Density Function - `TPDF`
@@ -922,6 +922,8 @@ def _apply_probability_distribution(waveform, density_function="TPDF"):
922922 Returns:
923923 torch.Tensor: waveform dithered with TPDF
924924 """
925+
926+ # pack batch
925927 shape = waveform .size ()
926928 waveform = waveform .reshape (- 1 , shape [- 1 ])
927929
@@ -961,6 +963,8 @@ def _apply_probability_distribution(waveform, density_function="TPDF"):
961963
962964 quantised_signal_scaled = torch .round (signal_scaled_dis )
963965 quantised_signal = quantised_signal_scaled / down_scaling
966+
967+ # unpack batch
964968 return quantised_signal .reshape (shape [:- 1 ] + quantised_signal .shape [- 1 :])
965969
966970
@@ -970,7 +974,7 @@ def dither(waveform, density_function="TPDF", noise_shaping=False):
970974 particular bit-depth by eliminating nonlinear truncation distortion
971975 (i.e. adding minimally perceived noise to mask distortion caused by quantization).
972976 Args:
973- waveform (torch.Tensor): Tensor of audio of dimension (channel , time)
977+ waveform (torch.Tensor): Tensor of audio of dimension (... , time)
974978 density_function (string): The density function of a
975979 continuous random variable (Default: `TPDF`)
976980 Options: Triangular Probability Density Function - `TPDF`
0 commit comments