
Commit efe1644

committed: moved the tests back to the existing test files; added them in a new class
1 parent c3d9b8e

File tree: 6 files changed (+97 −126 lines)

test/torchaudio_unittest/batch_consistency_test.py

Lines changed: 37 additions & 0 deletions
@@ -280,3 +280,40 @@ def test_batch_Vol(self):
         # Batch then transform
         computed = torchaudio.transforms.Vol(gain=1.1)(waveform.repeat(3, 1, 1))
         self.assertEqual(computed, expected)
+
+
+class TestTransformsWithComplexTensors(common_utils.TorchaudioTestCase):
+    def test_batch_TimeStretch(self):
+        test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
+        waveform, _ = common_utils.load_wav(test_filepath)  # (2, 278756), 44100
+
+        kwargs = {
+            'n_fft': 2048,
+            'hop_length': 512,
+            'win_length': 2048,
+            'window': torch.hann_window(2048),
+            'center': True,
+            'pad_mode': 'reflect',
+            'normalized': True,
+            'onesided': True,
+        }
+        rate = 2
+
+        complex_specgrams = torch.stft(waveform, **kwargs)
+        complex_specgrams = torch.view_as_complex(complex_specgrams)
+
+        # Single then transform then batch
+        expected = torchaudio.transforms.TimeStretch(
+            fixed_rate=rate,
+            n_freq=1025,
+            hop_length=512,
+        )(complex_specgrams).repeat(3, 1, 1, 1, 1)
+
+        # Batch then transform
+        computed = torchaudio.transforms.TimeStretch(
+            fixed_rate=rate,
+            n_freq=1025,
+            hop_length=512,
+        )(complex_specgrams.repeat(3, 1, 1, 1, 1))
+
+        self.assertEqual(computed, expected, atol=1e-5, rtol=1e-5)
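
In plain terms, the new test checks that stretching a batch of complex spectrograms gives the same result as stretching one item and then batching it. A standalone sketch of that check outside the test harness (the random stand-in waveform and the use of return_complex=True are illustrative assumptions, not part of the commit):

import torch
import torchaudio

# Stand-in for the asset WAV used by the test: 2-channel random audio (assumption).
waveform = torch.randn(2, 44100)

# Complex spectrogram of shape (channel, freq, time), with freq = 2048 // 2 + 1 = 1025.
complex_specgrams = torch.stft(
    waveform, n_fft=2048, hop_length=512, win_length=2048,
    window=torch.hann_window(2048), center=True, pad_mode='reflect',
    normalized=True, onesided=True, return_complex=True,
)

stretch = torchaudio.transforms.TimeStretch(fixed_rate=2, n_freq=1025, hop_length=512)

# Transform a single item, then repeat it into a batch ...
expected = stretch(complex_specgrams).repeat(3, 1, 1, 1, 1)
# ... versus repeat into a batch first, then transform the whole batch.
computed = stretch(complex_specgrams.repeat(3, 1, 1, 1, 1))

assert torch.allclose(computed, expected, atol=1e-5, rtol=1e-5)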

test/torchaudio_unittest/complex_batch_consistency_test.py

Lines changed: 0 additions & 49 deletions
This file was deleted.

test/torchaudio_unittest/complex_librosa_compatibility_test.py

Lines changed: 0 additions & 57 deletions
This file was deleted.

test/torchaudio_unittest/librosa_compatibility_test.py

Lines changed: 38 additions & 0 deletions
@@ -2,6 +2,7 @@
 import os
 import unittest
 from distutils.version import StrictVersion
+from parameterized import parameterized
 
 import torch
 import torchaudio
@@ -111,6 +112,43 @@ def test_amplitude_to_DB(self):
         self.assertEqual(ta_out, lr_out, atol=5e-5, rtol=1e-5)
 
 
+@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
+class TestFunctionalWithComplexTensors(common_utils.TorchaudioTestCase):
+    """Test suite for functions in the `functional` module that take complex-dtype tensors as input."""
+    @parameterized.expand([
+        (0.5,), (1.01,), (1.3,)
+    ])
+    def test_phase_vocoder(self, rate):
+        torch.random.manual_seed(48)
+        complex_specgrams = torch.randn(2, 1025, 400, dtype=torch.cdouble)
+        hop_length = 256
+
+        # Due to the cumulative sum, numerical error with torch.float32 would
+        # cause the bottom-right values of the stretched spectrogram to not
+        # match librosa, so double precision is used here.
+
+        phase_advance = torch.linspace(0, np.pi * hop_length,
+                                       complex_specgrams.shape[-2], dtype=torch.double)[..., None]
+
+        complex_specgrams_stretch = F.phase_vocoder(complex_specgrams, rate=rate, phase_advance=phase_advance)
+
+        # == Test shape
+        expected_size = list(complex_specgrams.size())
+        expected_size[-1] = int(np.ceil(expected_size[-1] / rate))
+
+        assert complex_specgrams.dim() == complex_specgrams_stretch.dim()
+        assert complex_specgrams_stretch.size() == torch.Size(expected_size)
+
+        # == Test values
+        index = [0] + [slice(None)] * 2
+        mono_complex_specgram = complex_specgrams[index].numpy()
+        expected_complex_stretch = librosa.phase_vocoder(mono_complex_specgram,
+                                                         rate=rate,
+                                                         hop_length=hop_length)
+
+        self.assertEqual(complex_specgrams_stretch[index], torch.from_numpy(expected_complex_stretch))
+
+
 @pytest.mark.parametrize('complex_specgrams', [
     torch.randn(2, 1025, 400, 2)
 ])
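
The librosa cross-check added here boils down to the following standalone sketch (the seed, shapes, and allclose tolerance are illustrative assumptions; double precision is used for the reason stated in the test comment):

import numpy as np
import torch
import librosa
import torchaudio.functional as F

rate, hop_length = 1.3, 256
torch.random.manual_seed(48)

# Double precision keeps the cumulative phase sum numerically close to librosa.
complex_specgrams = torch.randn(2, 1025, 400, dtype=torch.cdouble)
phase_advance = torch.linspace(
    0, np.pi * hop_length, complex_specgrams.shape[-2], dtype=torch.double)[..., None]

stretched = F.phase_vocoder(complex_specgrams, rate=rate, phase_advance=phase_advance)

# Shape: the time axis shrinks from 400 to ceil(400 / rate) frames.
assert stretched.shape[-1] == int(np.ceil(400 / rate))

# Values: channel 0 should match librosa's phase vocoder on the same data.
expected = librosa.phase_vocoder(complex_specgrams[0].numpy(), rate=rate, hop_length=hop_length)
assert np.allclose(stretched[0].numpy(), expected, atol=1e-5)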

test/torchaudio_unittest/torchscript_consistency_impl.py

Lines changed: 14 additions & 12 deletions
@@ -529,7 +529,7 @@ def func(tensor):
         self._assert_consistency(func, waveform)
 
 
-class TransformsWithComplexDtypes(common_utils.TestBaseMixin):
+class TransformsMixin:
     """Implements test for Transforms that are performed for different devices"""
     def _assert_consistency(self, transform, tensor):
         tensor = tensor.to(device=self.device, dtype=self.dtype)
@@ -540,18 +540,8 @@ def _assert_consistency(self, transform, tensor):
         ts_output = ts_transform(tensor)
         self.assertEqual(ts_output, output)
 
-    def test_TimeStretch(self):
-        n_freq = 400
-        hop_length = 512
-        fixed_rate = 1.3
-        tensor = torch.rand((10, 2, n_freq, 10), dtype=torch.cdouble)
-        self._assert_consistency(
-            T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
-            tensor,
-        )
 
-
-class Transforms(common_utils.TestBaseMixin):
+class TransformsWithComplexDtypes(TransformsMixin, common_utils.TestBaseMixin):
     """Implements test for Transforms that are performed for different devices"""
     def _assert_consistency(self, transform, tensor):
         tensor = tensor.to(device=self.device, dtype=self.dtype)
@@ -562,6 +552,18 @@ def _assert_consistency(self, transform, tensor):
         ts_output = ts_transform(tensor)
         self.assertEqual(ts_output, output)
 
+    def test_TimeStretch(self):
+        n_freq = 400
+        hop_length = 512
+        fixed_rate = 1.3
+        tensor = torch.rand((10, 2, n_freq, 10))
+        self._assert_consistency(
+            T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
+            tensor,
+        )
+
+
+class Transforms(TransformsMixin, common_utils.TestBaseMixin):
     def test_Spectrogram(self):
         tensor = torch.rand((1, 1000))
         self._assert_consistency(T.Spectrogram(), tensor)
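
For context, the scripted-vs-eager check that _assert_consistency performs reduces to roughly this (a sketch; the complex dtype and the TimeStretch parameters are taken from the test, the rest is illustrative):

import torch
import torchaudio.transforms as T

transform = T.TimeStretch(n_freq=400, hop_length=512, fixed_rate=1.3)
tensor = torch.rand((10, 2, 400, 10), dtype=torch.cdouble)

ts_transform = torch.jit.script(transform)  # TorchScript-compiled copy of the module

output = transform(tensor)        # eager execution
ts_output = ts_transform(tensor)  # scripted execution
assert torch.allclose(output, ts_output)  # both paths should produce the same stretch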

torchaudio/functional.py

Lines changed: 8 additions & 8 deletions
@@ -468,27 +468,27 @@ def phase_vocoder(
             `(..., freq, ceil(time/rate), complex=2)` or
             a complex dtype and dimension of `(..., freq, ceil(time/rate))`.
 
-    Example - old API
+    Example - New API (using tensors with complex dtype)
         >>> freq, hop_length = 1025, 512
-        >>> # (channel, freq, time, complex=2)
-        >>> complex_specgrams = torch.randn(2, freq, 300, 2)
+        >>> # (channel, freq, time)
+        >>> complex_specgrams = torch.randn(2, freq, 300, dtype=torch.cfloat)
         >>> rate = 1.3 # Speed up by 30%
         >>> phase_advance = torch.linspace(
         >>>    0, math.pi * hop_length, freq)[..., None]
         >>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
         >>> x.shape # with 231 == ceil(300 / 1.3)
-        torch.Size([2, 1025, 231, 2])
+        torch.Size([2, 1025, 231])
 
-    Example - New API (using tensors with complex dtype)
+    Example - Old API (using real tensors with shape (..., complex=2))
         >>> freq, hop_length = 1025, 512
-        >>> # (channel, freq, time)
-        >>> complex_specgrams = torch.randn(2, freq, 300, dtype=torch.cfloat)
+        >>> # (channel, freq, time, complex=2)
+        >>> complex_specgrams = torch.randn(2, freq, 300, 2)
         >>> rate = 1.3 # Speed up by 30%
        >>> phase_advance = torch.linspace(
         >>>    0, math.pi * hop_length, freq)[..., None]
         >>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
         >>> x.shape # with 231 == ceil(300 / 1.3)
-        torch.Size([2, 1025, 231])
+        torch.Size([2, 1025, 231, 2])
     """
     use_complex = complex_specgrams.is_complex()
     shape = complex_specgrams.size()
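
The two docstring examples describe the same call on two representations of the same data, and moving between them is a view rather than a copy. A brief sketch (shapes follow the docstring; nothing here is new library behaviour):

import math
import torch
from torchaudio.functional import phase_vocoder

freq, hop_length, rate = 1025, 512, 1.3
phase_advance = torch.linspace(0, math.pi * hop_length, freq)[..., None]

# Old-style real tensor of shape (channel, freq, time, complex=2) ...
real_specgrams = torch.randn(2, freq, 300, 2)
# ... viewed as the new complex dtype with shape (channel, freq, time).
complex_specgrams = torch.view_as_complex(real_specgrams)

stretched = phase_vocoder(complex_specgrams, rate, phase_advance)
print(stretched.shape)                      # torch.Size([2, 1025, 231])
print(torch.view_as_real(stretched).shape)  # torch.Size([2, 1025, 231, 2])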
