Skip to content

Commit dff919c

Browse files
committed
moved the tests back to the existing test files; added them in a new class
1 parent c3d9b8e commit dff919c

File tree

6 files changed

+92
-127
lines changed

6 files changed

+92
-127
lines changed

test/torchaudio_unittest/batch_consistency_test.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,3 +280,39 @@ def test_batch_Vol(self):
280280
# Batch then transform
281281
computed = torchaudio.transforms.Vol(gain=1.1)(waveform.repeat(3, 1, 1))
282282
self.assertEqual(computed, expected)
283+
284+
class TestTransformsWithComplexTensors(common_utils.TorchaudioTestCase):
285+
def test_batch_TimeStretch(self):
286+
test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
287+
waveform, _ = common_utils.load_wav(test_filepath) # (2, 278756), 44100
288+
289+
kwargs = {
290+
'n_fft': 2048,
291+
'hop_length': 512,
292+
'win_length': 2048,
293+
'window': torch.hann_window(2048),
294+
'center': True,
295+
'pad_mode': 'reflect',
296+
'normalized': True,
297+
'onesided': True,
298+
}
299+
rate = 2
300+
301+
complex_specgrams = torch.stft(waveform, **kwargs)
302+
complex_specgrams = torch.view_as_complex(complex_specgrams)
303+
304+
# Single then transform then batch
305+
expected = torchaudio.transforms.TimeStretch(
306+
fixed_rate=rate,
307+
n_freq=1025,
308+
hop_length=512,
309+
)(complex_specgrams).repeat(3, 1, 1, 1, 1)
310+
311+
# Batch then transform
312+
computed = torchaudio.transforms.TimeStretch(
313+
fixed_rate=rate,
314+
n_freq=1025,
315+
hop_length=512,
316+
)(complex_specgrams.repeat(3, 1, 1, 1, 1))
317+
318+
self.assertEqual(computed, expected, atol=1e-5, rtol=1e-5)

test/torchaudio_unittest/complex_batch_consistency_test.py

Lines changed: 0 additions & 49 deletions
This file was deleted.

test/torchaudio_unittest/complex_librosa_compatibility_test.py

Lines changed: 0 additions & 57 deletions
This file was deleted.

test/torchaudio_unittest/librosa_compatibility_test.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,41 @@ def test_amplitude_to_DB(self):
110110

111111
self.assertEqual(ta_out, lr_out, atol=5e-5, rtol=1e-5)
112112

113+
@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
114+
class TestFunctionalWithComplexTensors(common_utils.TorchaudioTestCase):
115+
"""Test suite for functions in `functional` module using as input tensors with complex dtypes."""
116+
@parameterized.expand([
117+
(0.5,), (1.01,), (1.3,)
118+
])
119+
def test_phase_vocoder(self, rate):
120+
torch.random.manual_seed(48)
121+
complex_specgrams = torch.randn(2, 1025, 400, dtype=torch.cdouble)
122+
hop_length = 256
123+
124+
# Due to cumulative sum, numerical error in using torch.float32 will
125+
# result in bottom right values of the stretched spectrogram to not
126+
# match with librosa.
127+
128+
phase_advance = torch.linspace(0, np.pi * hop_length,
129+
complex_specgrams.shape[-2], dtype=torch.double)[..., None]
130+
131+
complex_specgrams_stretch = F.phase_vocoder(complex_specgrams, rate=rate, phase_advance=phase_advance)
132+
133+
# == Test shape
134+
expected_size = list(complex_specgrams.size())
135+
expected_size[-1] = int(np.ceil(expected_size[-1] / rate))
136+
137+
assert complex_specgrams.dim() == complex_specgrams_stretch.dim()
138+
assert complex_specgrams_stretch.size() == torch.Size(expected_size)
139+
140+
# == Test values
141+
index = [0] + [slice(None)] * 2
142+
mono_complex_specgram = complex_specgrams[index].numpy()
143+
expected_complex_stretch = librosa.phase_vocoder(mono_complex_specgram,
144+
rate=rate,
145+
hop_length=hop_length)
146+
147+
self.assertEqual(complex_specgrams_stretch[index], torch.from_numpy(expected_complex_stretch))
113148

114149
@pytest.mark.parametrize('complex_specgrams', [
115150
torch.randn(2, 1025, 400, 2)

test/torchaudio_unittest/torchscript_consistency_impl.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -529,7 +529,7 @@ def func(tensor):
529529
self._assert_consistency(func, waveform)
530530

531531

532-
class TransformsWithComplexDtypes(common_utils.TestBaseMixin):
532+
class TransformsMixin:
533533
"""Implements test for Transforms that are performed for different devices"""
534534
def _assert_consistency(self, transform, tensor):
535535
tensor = tensor.to(device=self.device, dtype=self.dtype)
@@ -540,18 +540,7 @@ def _assert_consistency(self, transform, tensor):
540540
ts_output = ts_transform(tensor)
541541
self.assertEqual(ts_output, output)
542542

543-
def test_TimeStretch(self):
544-
n_freq = 400
545-
hop_length = 512
546-
fixed_rate = 1.3
547-
tensor = torch.rand((10, 2, n_freq, 10), dtype=torch.cdouble)
548-
self._assert_consistency(
549-
T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
550-
tensor,
551-
)
552-
553-
554-
class Transforms(common_utils.TestBaseMixin):
543+
class TransformsWithComplexDtypes(TransformsMixin, common_utils.TestBaseMixin):
555544
"""Implements test for Transforms that are performed for different devices"""
556545
def _assert_consistency(self, transform, tensor):
557546
tensor = tensor.to(device=self.device, dtype=self.dtype)
@@ -562,6 +551,17 @@ def _assert_consistency(self, transform, tensor):
562551
ts_output = ts_transform(tensor)
563552
self.assertEqual(ts_output, output)
564553

554+
def test_TimeStretch(self):
555+
n_freq = 400
556+
hop_length = 512
557+
fixed_rate = 1.3
558+
tensor = torch.rand((10, 2, n_freq, 10))
559+
self._assert_consistency(
560+
T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
561+
tensor,
562+
)
563+
564+
class Transforms(TransformsMixin, common_utils.TestBaseMixin):
565565
def test_Spectrogram(self):
566566
tensor = torch.rand((1, 1000))
567567
self._assert_consistency(T.Spectrogram(), tensor)

torchaudio/functional.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -468,27 +468,27 @@ def phase_vocoder(
468468
`(..., freq, ceil(time/rate), complex=2)` or
469469
a complex dtype and dimension of `(..., freq, ceil(time/rate))`.
470470
471-
Example - old API
471+
Example - New API (using tensors with complex dtype)
472472
>>> freq, hop_length = 1025, 512
473-
>>> # (channel, freq, time, complex=2)
474-
>>> complex_specgrams = torch.randn(2, freq, 300, 2)
473+
>>> # (channel, freq, time)
474+
>>> complex_specgrams = torch.randn(2, freq, 300, dtype=torch.cfloat)
475475
>>> rate = 1.3 # Speed up by 30%
476476
>>> phase_advance = torch.linspace(
477477
>>> 0, math.pi * hop_length, freq)[..., None]
478478
>>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
479479
>>> x.shape # with 231 == ceil(300 / 1.3)
480-
torch.Size([2, 1025, 231, 2])
480+
torch.Size([2, 1025, 231])
481481
482-
Example - New API (using tensors with complex dtype)
482+
Example - Old API (using real tensors with shape (..., complex=2))
483483
>>> freq, hop_length = 1025, 512
484-
>>> # (channel, freq, time)
485-
>>> complex_specgrams = torch.randn(2, freq, 300, dtype=torch.cfloat)
484+
>>> # (channel, freq, time, complex=2)
485+
>>> complex_specgrams = torch.randn(2, freq, 300, 2)
486486
>>> rate = 1.3 # Speed up by 30%
487487
>>> phase_advance = torch.linspace(
488488
>>> 0, math.pi * hop_length, freq)[..., None]
489489
>>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
490490
>>> x.shape # with 231 == ceil(300 / 1.3)
491-
torch.Size([2, 1025, 231])
491+
torch.Size([2, 1025, 231, 2])
492492
"""
493493
use_complex = complex_specgrams.is_complex()
494494
shape = complex_specgrams.size()

0 commit comments

Comments
 (0)