Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 66 additions & 29 deletions test/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,22 +75,17 @@ def test_compute_deltas_randn(self):
win_length = 2 * 7 + 1
specgram = torch.randn(channel, n_mfcc, time)
computed = F.compute_deltas(specgram, win_length=win_length)

self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))

_test_torchscript_functional(F.compute_deltas, specgram, win_length=win_length)

def test_batch_pitch(self):
    """Check that pitch detection gives the same result with and without a batch dim.

    The shared ``_test_batch`` helper is exercised first; afterwards the same
    property is verified by hand with a batch of three identical copies.
    """
    audio, rate = torchaudio.load(self.test_filepath)
    detect = F.detect_pitch_frequency
    self._test_batch(detect, audio, rate)

    def _tile3(t):
        # Prepend a batch dimension and copy the content three times.
        return t.unsqueeze(0).repeat(3, 1, 1)

    # Transform the single example, then batch the result.
    expected = _tile3(detect(audio, rate))

    # Batch the input, then transform.
    computed = detect(_tile3(audio), rate)

    shapes = (computed.shape, expected.shape)
    self.assertTrue(shapes[0] == shapes[1], shapes)
    self.assertTrue(torch.allclose(computed, expected))
def test_jit_pitch(self):
    """Run ``F.detect_pitch_frequency`` through the TorchScript functional check."""
    # torchaudio.load returns (waveform, sample_rate), which are exactly the
    # positional arguments forwarded to the functional under test.
    loaded = torchaudio.load(self.test_filepath)
    _test_torchscript_functional(F.detect_pitch_frequency, *loaded)

def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8):
Expand All @@ -106,22 +101,13 @@ def _test_istft_is_inverse_of_stft(self, kwargs):
for data_size in self.data_sizes:
for i in range(self.number_of_trials):

# Non-batch
sound = common_utils.random_float_tensor(i, data_size)

stft = torch.stft(sound, **kwargs)
estimate = torchaudio.functional.istft(stft, length=sound.size(1), **kwargs)

self._compare_estimate(sound, estimate)

# Batch
stft = torch.stft(sound, **kwargs)
stft = stft.repeat(3, 1, 1, 1, 1)
sound = sound.repeat(3, 1, 1)

estimate = torchaudio.functional.istft(stft, length=sound.size(1), **kwargs)
self._compare_estimate(sound, estimate)

def test_istft_is_inverse_of_stft1(self):
# hann_window, centered, normalized, onesided
kwargs1 = {
Expand Down Expand Up @@ -338,6 +324,16 @@ def test_linearity_of_istft4(self):
data_size = (2, 7, 3, 2)
self._test_linearity_of_istft(data_size, kwargs4, atol=1e-5, rtol=1e-8)

def test_batch_istft(self):
    """Run the shared batch-consistency check on ``F.istft`` with a tiny fixed input."""
    # Shape (3, 5, 2): every entry of the first row is [4., 0.]; the other
    # two rows are all zeros.
    spectrum = torch.zeros(3, 5, 2)
    spectrum[0, :, 0] = 4.

    self._test_batch(F.istft, spectrum, n_fft=4, length=4)

def _test_create_fb(self, n_mels=40, sample_rate=22050, n_fft=2048, fmin=0.0, fmax=8000.0):
# Using a decorator here causes parametrize to fail on Python 2
if not IMPORT_LIBROSA:
Expand Down Expand Up @@ -438,22 +434,63 @@ def test_pitch(self):
self.assertFalse(s)

# Convert to stereo and batch for testing purposes
freq = freq.repeat(3, 2, 1, 1)
waveform = waveform.repeat(3, 2, 1, 1)
self._test_batch(F.detect_pitch_frequency, waveform, sample_rate)

freq2 = torchaudio.functional.detect_pitch_frequency(waveform, sample_rate)
def _test_batch_shape(self, functional, tensor, *args, **kwargs):
    """Check that ``functional`` is unaffected by two leading singleton dims.

    The functional is applied once to ``tensor`` as given and once to
    ``tensor`` with two size-1 dimensions prepended; the two results must
    agree according to ``self._compare_estimate``.

    Args:
        functional: Callable under test, invoked as
            ``functional(tensor, *args, **kwargs)``.
        tensor (torch.Tensor): Un-batched input.
        *args: Extra positional arguments forwarded to ``functional``.
        **kwargs: Extra keyword arguments forwarded to ``functional``,
            except ``atol`` and ``rtol``, which are forwarded to
            ``self._compare_estimate`` instead.

    Returns:
        tuple: ``(tensors, expected)`` - the input and the reference output,
        each with two singleton dimensions prepended.
    """
    # Tolerances configure the comparison, not the functional under test.
    kwargs_compare = {key: kwargs.pop(key) for key in ('atol', 'rtol') if key in kwargs}

    # Single then transform then batch
    # Re-seeding before each call keeps any randomness inside the functional
    # identical between the two runs.
    torch.random.manual_seed(42)
    expected = functional(tensor.clone(), *args, **kwargs)
    expected = expected.unsqueeze(0).unsqueeze(0)

    # 1-Batch then transform
    tensors = tensor.unsqueeze(0).unsqueeze(0)

    torch.random.manual_seed(42)
    computed = functional(tensors.clone(), *args, **kwargs)

    self._compare_estimate(computed, expected, **kwargs_compare)

    return tensors, expected

def _test_batch(self, functional, tensor, *args, **kwargs):
    """Check that ``functional`` gives consistent results on batched input.

    First delegates to ``self._test_batch_shape`` for the singleton-batch
    case, then repeats both the input and the reference output three times
    along the leading dimension and verifies that applying ``functional`` to
    the repeated input reproduces the repeated reference.

    Args:
        functional: Callable under test, invoked as
            ``functional(tensor, *args, **kwargs)``.
        tensor (torch.Tensor): Un-batched input.
        *args: Extra positional arguments forwarded to ``functional``.
        **kwargs: Extra keyword arguments forwarded to ``functional``,
            except ``atol`` and ``rtol``, which are forwarded to
            ``self._compare_estimate`` instead.
    """
    # ``**kwargs`` hands the helper its own dict, so any ``atol``/``rtol``
    # it strips are still present in this frame and must be stripped again.
    tensors, expected = self._test_batch_shape(functional, tensor, *args, **kwargs)

    kwargs_compare = {key: kwargs.pop(key) for key in ('atol', 'rtol') if key in kwargs}

    # 3-Batch then transform
    ind = [3] + [1] * (int(tensors.dim()) - 1)
    tensors = tensors.repeat(*ind)

    ind = [3] + [1] * (int(expected.dim()) - 1)
    expected = expected.repeat(*ind)

    torch.random.manual_seed(42)
    computed = functional(tensors.clone(), *args, **kwargs)

    # Without this comparison the 3-batch result is computed but never
    # checked, so the test could not fail on a batching regression.
    self._compare_estimate(computed, expected, **kwargs_compare)


def _num_stft_bins(signal_len, fft_len, hop_length, pad):
Expand Down
37 changes: 37 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,19 @@ def test_compute_deltas_twochannel(self):
computed = transform(specgram)
self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))

def test_batch_MelScale(self):
    """Check that ``MelScale`` gives the same output for batched and un-batched input."""
    spec = torch.randn(2, 31, 2786)
    batched_spec = spec.repeat(3, 1, 1, 1)

    # Transform the single spectrogram, then replicate the result three times.
    single_out = transforms.MelScale()(spec)
    expected = single_out.repeat(3, 1, 1, 1)

    # Replicate first, then transform with a fresh module instance.
    computed = transforms.MelScale()(batched_spec)

    shapes = (computed.shape, expected.shape)
    self.assertTrue(shapes[0] == shapes[1], shapes)
    self.assertTrue(torch.allclose(computed, expected))

def test_batch_compute_deltas(self):
specgram = torch.randn(2, 31, 2786)

Expand Down Expand Up @@ -422,6 +435,30 @@ def test_batch_spectrogram(self):
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected))

def test_batch_melspectrogram(self):
    """Check that ``MelSpectrogram`` gives the same output for batched and un-batched audio."""
    # Only the samples are needed; the sample rate returned by load is unused.
    audio, _ = torchaudio.load(self.test_filepath)
    batched_audio = audio.repeat(3, 1, 1)

    # Transform the single waveform, then replicate the result three times.
    single_out = transforms.MelSpectrogram()(audio)
    expected = single_out.repeat(3, 1, 1, 1)

    # Replicate first, then transform with a fresh module instance.
    computed = transforms.MelSpectrogram()(batched_audio)

    shapes = (computed.shape, expected.shape)
    self.assertTrue(shapes[0] == shapes[1], shapes)
    self.assertTrue(torch.allclose(computed, expected))

def test_batch_mfcc(self):
    """Check that ``MFCC`` gives the same output for batched and un-batched audio."""
    # Only the samples are needed; the sample rate returned by load is unused.
    audio, _ = torchaudio.load(self.test_filepath)
    batched_audio = audio.repeat(3, 1, 1)

    # Transform the single waveform, then replicate the result three times.
    single_out = transforms.MFCC()(audio)
    expected = single_out.repeat(3, 1, 1, 1)

    # Replicate first, then transform with a fresh module instance.
    computed = transforms.MFCC()(batched_audio)

    shapes = (computed.shape, expected.shape)
    self.assertTrue(shapes[0] == shapes[1], shapes)
    # Compared with an explicit absolute tolerance rather than allclose's default.
    self.assertTrue(torch.allclose(computed, expected, atol=1e-5))

def test_scriptmodule_TimeStretch(self):
n_freq = 400
hop_length = 512
Expand Down
36 changes: 20 additions & 16 deletions torchaudio/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def istft(

Args:
stft_matrix (torch.Tensor): Output of stft where each row of a channel is a frequency and each
column is a window. it has a size of either (..., fft_size, n_frame, 2)
column is a window. It has a size of either (..., fft_size, n_frame, 2)
n_fft (int): Size of Fourier transform
hop_length (Optional[int]): The distance between neighboring sliding window frames.
(Default: ``win_length // 4``)
Expand Down Expand Up @@ -229,7 +229,7 @@ def spectrogram(
The spectrogram can be either magnitude-only or complex.

Args:
waveform (torch.Tensor): Tensor of audio of dimension (..., channel, time)
waveform (torch.Tensor): Tensor of audio of dimension (..., time)
pad (int): Two sided padding of signal
window (torch.Tensor): Window tensor that is applied/multiplied to each frame/window
n_fft (int): Size of FFT
Expand All @@ -241,8 +241,8 @@ def spectrogram(
normalized (bool): Whether to normalize by magnitude after stft

Returns:
torch.Tensor: Dimension (..., channel, freq, time), where channel
is unchanged, freq is ``n_fft // 2 + 1`` and ``n_fft`` is the number of
torch.Tensor: Dimension (..., freq, time), freq is
``n_fft // 2 + 1`` and ``n_fft`` is the number of
Fourier bins, and time is the number of window hops (n_frame).
"""

Expand Down Expand Up @@ -613,7 +613,7 @@ def biquad(waveform, b0, b1, b2, a0, a1, a2):
https://en.wikipedia.org/wiki/Digital_biquad_filter

Args:
waveform (torch.Tensor): audio waveform of dimension of `(channel, time)`
waveform (torch.Tensor): audio waveform of dimension of `(..., time)`
b0 (float): numerator coefficient of current input, x[n]
b1 (float): numerator coefficient of input one time step ago x[n-1]
b2 (float): numerator coefficient of input two time steps ago x[n-2]
Expand All @@ -622,7 +622,7 @@ def biquad(waveform, b0, b1, b2, a0, a1, a2):
a2 (float): denominator coefficient of output two time steps ago y[n-2]

Returns:
output_waveform (torch.Tensor): Dimension of `(channel, time)`
output_waveform (torch.Tensor): Dimension of `(..., time)`
"""

device = waveform.device
Expand All @@ -646,13 +646,13 @@ def highpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
r"""Design biquad highpass filter and perform filtering. Similar to SoX implementation.

Args:
waveform (torch.Tensor): audio waveform of dimension of `(channel, time)`
waveform (torch.Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
cutoff_freq (float): filter cutoff frequency
Q (float): https://en.wikipedia.org/wiki/Q_factor

Returns:
output_waveform (torch.Tensor): Dimension of `(channel, time)`
output_waveform (torch.Tensor): Dimension of `(..., time)`
"""

GAIN = 1.
Expand All @@ -675,13 +675,13 @@ def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
r"""Design biquad lowpass filter and perform filtering. Similar to SoX implementation.

Args:
waveform (torch.Tensor): audio waveform of dimension of `(channel, time)`
waveform (torch.Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
cutoff_freq (float): filter cutoff frequency
Q (float): https://en.wikipedia.org/wiki/Q_factor

Returns:
output_waveform (torch.Tensor): Dimension of `(channel, time)`
output_waveform (torch.Tensor): Dimension of `(..., time)`
"""

GAIN = 1.
Expand All @@ -704,14 +704,14 @@ def equalizer_biquad(waveform, sample_rate, center_freq, gain, Q=0.707):
r"""Design biquad peaking equalizer filter and perform filtering. Similar to SoX implementation.

Args:
waveform (torch.Tensor): audio waveform of dimension of `(channel, time)`
waveform (torch.Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
center_freq (float): filter's central frequency
gain (float): desired gain at the boost (or attenuation) in dB
Q (float): https://en.wikipedia.org/wiki/Q_factor

Returns:
output_waveform (torch.Tensor): Dimension of `(channel, time)`
output_waveform (torch.Tensor): Dimension of `(..., time)`
"""
w0 = 2 * math.pi * center_freq / sample_rate
A = math.exp(gain / 40.0 * math.log(10))
Expand Down Expand Up @@ -800,7 +800,7 @@ def mask_along_axis(specgram, mask_param, mask_value, axis):
# unpack batch
specgram = specgram.reshape(shape[:-2] + specgram.shape[-2:])

return specgram.reshape(shape[:-2] + specgram.shape[-2:])
return specgram


def compute_deltas(specgram, win_length=5, mode="replicate"):
Expand Down Expand Up @@ -860,7 +860,7 @@ def gain(waveform, gain_db=1.0):
r"""Apply amplification or attenuation to the whole waveform.

Args:
waveform (torch.Tensor): Tensor of audio of dimension (channel, time).
waveform (torch.Tensor): Tensor of audio of dimension (..., time).
gain_db (float): Gain adjustment in decibels (dB) (Default: `1.0`).

Returns:
Expand Down Expand Up @@ -913,7 +913,7 @@ def _apply_probability_distribution(waveform, density_function="TPDF"):
The relationship of probabilities of results follows a bell-shaped,
or Gaussian curve, typical of dither generated by analog sources.
Args:
waveform (torch.Tensor): Tensor of audio of dimension (channel, time)
waveform (torch.Tensor): Tensor of audio of dimension (..., time)
density_function (string): The density function of a
continuous random variable (Default: `TPDF`)
Options: Triangular Probability Density Function - `TPDF`
Expand All @@ -922,6 +922,8 @@ def _apply_probability_distribution(waveform, density_function="TPDF"):
Returns:
torch.Tensor: waveform dithered with the selected probability density function
"""

# pack batch
shape = waveform.size()
waveform = waveform.reshape(-1, shape[-1])

Expand Down Expand Up @@ -961,6 +963,8 @@ def _apply_probability_distribution(waveform, density_function="TPDF"):

quantised_signal_scaled = torch.round(signal_scaled_dis)
quantised_signal = quantised_signal_scaled / down_scaling

# unpack batch
return quantised_signal.reshape(shape[:-1] + quantised_signal.shape[-1:])


Expand All @@ -970,7 +974,7 @@ def dither(waveform, density_function="TPDF", noise_shaping=False):
particular bit-depth by eliminating nonlinear truncation distortion
(i.e. adding minimally perceived noise to mask distortion caused by quantization).
Args:
waveform (torch.Tensor): Tensor of audio of dimension (channel, time)
waveform (torch.Tensor): Tensor of audio of dimension (..., time)
density_function (string): The density function of a
continuous random variable (Default: `TPDF`)
Options: Triangular Probability Density Function - `TPDF`
Expand Down
Loading