Skip to content

Commit 0433b7a

Browse files
authored
Make F.phase_vocoder and T.TimeStretch handle complex dtype (#1410)
1. `F.phase_vocoder` accepts Tensor with complex dtype. * The implementation path has been updated from #758 so that they share the same code path by internally converting the input Tensor to complex dtype and performing all the operation on top of it. * Adopted `torch.polar` for simpler Tensor generation from magnitude and angle. 2. Updated tests * librosa compatibility test for complex dtype and pseudo complex dtype * Extracted the output shape check test and moved it to functional so that it will be tested on all the combination of `{CPU | CUDA} x {complex64 | complex128}` * TorchScript compatibility test for `F.phase_vocoder` and `T.TimeStretch`. * batch consistency test for `T.TimeStretch`.
1 parent a6cdd6c commit 0433b7a

13 files changed

+291
-117
lines changed

test/torchaudio_unittest/functional/functional_cpu_test.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
skipIfNoSox,
1515
)
1616

17-
from .functional_impl import Lfilter, Spectrogram
17+
from .functional_impl import Lfilter, Spectrogram, FunctionalComplex
1818

1919

2020
class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase):
@@ -41,6 +41,18 @@ class TestSpectrogramFloat64(Spectrogram, common_utils.PytorchTestCase):
4141
device = torch.device('cpu')
4242

4343

44+
class TestFunctionalComplex64(FunctionalComplex, common_utils.PytorchTestCase):
45+
complex_dtype = torch.complex64
46+
real_dtype = torch.float32
47+
device = torch.device('cpu')
48+
49+
50+
class TestFunctionalComplex128(FunctionalComplex, common_utils.PytorchTestCase):
51+
complex_dtype = torch.complex128
52+
real_dtype = torch.float64
53+
device = torch.device('cpu')
54+
55+
4456
class TestCreateFBMatrix(common_utils.TorchaudioTestCase):
4557
def test_no_warning_high_n_freq(self):
4658
with warnings.catch_warnings(record=True) as w:

test/torchaudio_unittest/functional/functional_cuda_test.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import unittest
33

44
from torchaudio_unittest import common_utils
5-
from .functional_impl import Lfilter, Spectrogram
5+
from .functional_impl import Lfilter, Spectrogram, FunctionalComplex
66

77

88
@common_utils.skipIfNoCuda
@@ -31,3 +31,17 @@ class TestSpectrogramFloat32(Spectrogram, common_utils.PytorchTestCase):
3131
class TestSpectrogramFloat64(Spectrogram, common_utils.PytorchTestCase):
3232
dtype = torch.float64
3333
device = torch.device('cuda')
34+
35+
36+
@common_utils.skipIfNoCuda
37+
class TestFunctionalComplex64(FunctionalComplex, common_utils.PytorchTestCase):
38+
complex_dtype = torch.complex64
39+
real_dtype = torch.float32
40+
device = torch.device('cuda')
41+
42+
43+
@common_utils.skipIfNoCuda
44+
class TestFunctionalComplex128(FunctionalComplex, common_utils.PytorchTestCase):
45+
complex_dtype = torch.complex64
46+
real_dtype = torch.float32
47+
device = torch.device('cuda')

test/torchaudio_unittest/functional/functional_impl.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
import torch
33
import torchaudio.functional as F
44
from parameterized import parameterized
5+
import numpy as np
56
from scipy import signal
67

78
from torchaudio_unittest import common_utils
9+
from torchaudio_unittest.common_utils import nested_params
810

911

1012
class Lfilter(common_utils.TestBaseMixin):
@@ -89,3 +91,39 @@ def test_grad_at_zero(self, power):
8991
)
9092
spec.sum().backward()
9193
assert not x.grad.isnan().sum()
94+
95+
96+
class FunctionalComplex(common_utils.TestBaseMixin):
97+
complex_dtype = None
98+
real_dtype = None
99+
device = None
100+
101+
@nested_params(
102+
[0.5, 1.01, 1.3],
103+
[True, False],
104+
)
105+
def test_phase_vocoder_shape(self, rate, test_pseudo_complex):
106+
"""Verify the output shape of phase vocoder"""
107+
hop_length = 256
108+
num_freq = 1025
109+
num_frames = 400
110+
batch_size = 2
111+
112+
torch.random.manual_seed(42)
113+
spec = torch.randn(
114+
batch_size, num_freq, num_frames, dtype=self.complex_dtype, device=self.device)
115+
if test_pseudo_complex:
116+
spec = torch.view_as_real(spec)
117+
118+
phase_advance = torch.linspace(
119+
0,
120+
np.pi * hop_length,
121+
num_freq,
122+
dtype=self.real_dtype, device=self.device)[..., None]
123+
124+
spec_stretch = F.phase_vocoder(spec, rate=rate, phase_advance=phase_advance)
125+
126+
assert spec.dim() == spec_stretch.dim()
127+
expected_shape = torch.Size([batch_size, num_freq, int(np.ceil(num_frames / rate))])
128+
output_shape = (torch.view_as_complex(spec_stretch) if test_pseudo_complex else spec_stretch).shape
129+
assert output_shape == expected_shape

test/torchaudio_unittest/functional/librosa_compatibility_test.py

Lines changed: 23 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import itertools
21
import unittest
32
from distutils.version import StrictVersion
43

@@ -15,6 +14,9 @@
1514
import librosa
1615

1716
from torchaudio_unittest import common_utils
17+
from torchaudio_unittest.common_utils import (
18+
nested_params,
19+
)
1820

1921

2022
@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
@@ -130,45 +132,36 @@ def test_resample(self):
130132

131133

132134
@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
133-
class TestPhaseVocoder(common_utils.TorchaudioTestCase):
134-
@parameterized.expand(list(itertools.product(
135-
[(2, 1025, 400, 2)],
135+
class TestFunctionalComplex(common_utils.TorchaudioTestCase):
136+
@nested_params(
136137
[0.5, 1.01, 1.3],
137-
[256]
138-
)))
139-
def test_phase_vocoder(self, shape, rate, hop_length):
138+
[True, False],
139+
)
140+
def test_phase_vocoder(self, rate, test_pseudo_complex):
141+
hop_length = 256
142+
num_freq = 1025
143+
num_frames = 400
144+
torch.random.manual_seed(42)
145+
140146
# Due to cummulative sum, numerical error in using torch.float32 will
141147
# result in bottom right values of the stretched sectrogram to not
142148
# match with librosa.
143-
torch.random.manual_seed(42)
144-
complex_specgrams = torch.randn(*shape)
145-
complex_specgrams = complex_specgrams.type(torch.float64)
149+
spec = torch.randn(num_freq, num_frames, dtype=torch.complex128)
146150
phase_advance = torch.linspace(
147151
0,
148152
np.pi * hop_length,
149-
complex_specgrams.shape[-3],
153+
num_freq,
150154
dtype=torch.float64)[..., None]
151155

152-
complex_specgrams_stretch = F.phase_vocoder(complex_specgrams, rate=rate, phase_advance=phase_advance)
156+
stretched = F.phase_vocoder(
157+
torch.view_as_real(spec) if test_pseudo_complex else spec,
158+
rate=rate, phase_advance=phase_advance)
153159

154-
# == Test shape
155-
expected_size = list(complex_specgrams.size())
156-
expected_size[-2] = int(np.ceil(expected_size[-2] / rate))
157-
158-
assert complex_specgrams.dim() == complex_specgrams_stretch.dim()
159-
assert complex_specgrams_stretch.size() == torch.Size(expected_size)
160-
161-
# == Test values
162-
index = [0] * (complex_specgrams.dim() - 3) + [slice(None)] * 3
163-
mono_complex_specgram = complex_specgrams[index].numpy()
164-
mono_complex_specgram = mono_complex_specgram[..., 0] + \
165-
mono_complex_specgram[..., 1] * 1j
166-
expected_complex_stretch = librosa.phase_vocoder(
167-
mono_complex_specgram,
160+
expected_stretched = librosa.phase_vocoder(
161+
spec.numpy(),
168162
rate=rate,
169163
hop_length=hop_length)
170164

171-
complex_stretch = complex_specgrams_stretch[index].numpy()
172-
complex_stretch = complex_stretch[..., 0] + 1j * complex_stretch[..., 1]
173-
174-
self.assertEqual(complex_stretch, torch.from_numpy(expected_complex_stretch), atol=1e-5, rtol=1e-5)
165+
self.assertEqual(
166+
torch.view_as_complex(stretched) if test_pseudo_complex else stretched,
167+
torch.from_numpy(expected_stretched))

test/torchaudio_unittest/functional/torchscript_consistency_cpu_test.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22

33
from torchaudio_unittest.common_utils import PytorchTestCase
4-
from .torchscript_consistency_impl import Functional
4+
from .torchscript_consistency_impl import Functional, FunctionalComplex
55

66

77
class TestFunctionalFloat32(Functional, PytorchTestCase):
@@ -12,3 +12,15 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
1212
class TestFunctionalFloat64(Functional, PytorchTestCase):
1313
dtype = torch.float64
1414
device = torch.device('cpu')
15+
16+
17+
class TestFunctionalComplex64(FunctionalComplex, PytorchTestCase):
18+
complex_dtype = torch.complex64
19+
real_dtype = torch.float32
20+
device = torch.device('cpu')
21+
22+
23+
class TestFunctionalComplex128(FunctionalComplex, PytorchTestCase):
24+
complex_dtype = torch.complex128
25+
real_dtype = torch.float64
26+
device = torch.device('cpu')

test/torchaudio_unittest/functional/torchscript_consistency_cuda_test.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22

33
from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
4-
from .torchscript_consistency_impl import Functional
4+
from .torchscript_consistency_impl import Functional, FunctionalComplex
55

66

77
@skipIfNoCuda
@@ -14,3 +14,17 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
1414
class TestFunctionalFloat64(Functional, PytorchTestCase):
1515
dtype = torch.float64
1616
device = torch.device('cuda')
17+
18+
19+
@skipIfNoCuda
20+
class TestFunctionalComplex64(FunctionalComplex, PytorchTestCase):
21+
complex_dtype = torch.complex64
22+
real_dtype = torch.float32
23+
device = torch.device('cuda')
24+
25+
26+
@skipIfNoCuda
27+
class TestFunctionalComplex128(FunctionalComplex, PytorchTestCase):
28+
complex_dtype = torch.complex128
29+
real_dtype = torch.float64
30+
device = torch.device('cuda')

test/torchaudio_unittest/functional/torchscript_consistency_impl.py

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import torch
55
import torchaudio.functional as F
6+
from parameterized import parameterized
67

78
from torchaudio_unittest import common_utils
89
from torchaudio_unittest.common_utils import (
@@ -551,21 +552,6 @@ def func(tensor):
551552
tensor = common_utils.get_whitenoise(sample_rate=44100)
552553
self._assert_consistency(func, tensor)
553554

554-
def test_phase_vocoder(self):
555-
def func(tensor, device: torch.device = self.device):
556-
rate = 0.5
557-
hop_length = 256
558-
phase_advance = torch.linspace(
559-
0,
560-
3.14 * hop_length,
561-
tensor.shape[-3],
562-
dtype=torch.float64,
563-
).to(device)[..., None]
564-
return F.phase_vocoder(tensor, rate, phase_advance)
565-
566-
tensor = torch.randn(2, 1025, 400, 2)
567-
self._assert_consistency(func, tensor)
568-
569555
@common_utils.skipIfNoKaldi
570556
def test_compute_kaldi_pitch(self):
571557
if self.dtype != torch.float32 or self.device != torch.device('cpu'):
@@ -577,3 +563,40 @@ def func(tensor):
577563

578564
tensor = common_utils.get_whitenoise(sample_rate=44100)
579565
self._assert_consistency(func, tensor)
566+
567+
568+
class FunctionalComplex:
569+
complex_dtype = None
570+
real_dtype = None
571+
device = None
572+
573+
def _assert_consistency(self, func, tensor, test_pseudo_complex=False):
574+
assert tensor.is_complex()
575+
tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
576+
ts_func = torch.jit.script(func)
577+
578+
if test_pseudo_complex:
579+
tensor = torch.view_as_real(tensor)
580+
output = func(tensor)
581+
ts_output = ts_func(tensor)
582+
self.assertEqual(ts_output, output)
583+
584+
@parameterized.expand([(True, ), (False, )])
585+
def test_phase_vocoder(self, test_paseudo_complex):
586+
def func(tensor):
587+
is_complex = tensor.is_complex()
588+
589+
n_freq = tensor.size(-2 if is_complex else -3)
590+
rate = 0.5
591+
hop_length = 256
592+
phase_advance = torch.linspace(
593+
0,
594+
3.14 * hop_length,
595+
n_freq,
596+
dtype=(torch.real(tensor) if is_complex else tensor).dtype,
597+
device=tensor.device,
598+
)[..., None]
599+
return F.phase_vocoder(tensor, rate, phase_advance)
600+
601+
tensor = torch.view_as_complex(torch.randn(2, 1025, 400, 2))
602+
self._assert_consistency(func, tensor, test_paseudo_complex)

test/torchaudio_unittest/transforms/batch_consistency_test.py

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Test numerical consistency among single input and batched input."""
22
import torch
33
import torchaudio
4+
from parameterized import parameterized
45

56
from torchaudio_unittest import common_utils
67

@@ -130,40 +131,31 @@ def test_batch_mfcc(self):
130131
computed = torchaudio.transforms.MFCC()(waveform.repeat(3, 1, 1))
131132
self.assertEqual(computed, expected, atol=1e-4, rtol=1e-5)
132133

133-
def test_batch_TimeStretch(self):
134-
test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
135-
waveform, _ = torchaudio.load(test_filepath) # (2, 278756), 44100
136-
134+
@parameterized.expand([(True, ), (False, )])
135+
def test_batch_TimeStretch(self, test_pseudo_complex):
137136
rate = 2
137+
num_freq = 1025
138+
num_frames = 400
138139

139-
complex_specgrams = torch.view_as_real(
140-
torch.stft(
141-
input=waveform,
142-
n_fft=2048,
143-
hop_length=512,
144-
win_length=2048,
145-
window=torch.hann_window(2048),
146-
center=True,
147-
pad_mode='reflect',
148-
normalized=True,
149-
onesided=True,
150-
return_complex=True,
151-
)
152-
)
140+
spec = torch.randn(num_freq, num_frames, dtype=torch.complex64)
141+
pattern = [3, 1, 1, 1]
142+
if test_pseudo_complex:
143+
spec = torch.view_as_real(spec)
144+
pattern += [1]
153145

154146
# Single then transform then batch
155147
expected = torchaudio.transforms.TimeStretch(
156148
fixed_rate=rate,
157-
n_freq=1025,
149+
n_freq=num_freq,
158150
hop_length=512,
159-
)(complex_specgrams).repeat(3, 1, 1, 1, 1)
151+
)(spec).repeat(*pattern)
160152

161153
# Batch then transform
162154
computed = torchaudio.transforms.TimeStretch(
163155
fixed_rate=rate,
164-
n_freq=1025,
156+
n_freq=num_freq,
165157
hop_length=512,
166-
)(complex_specgrams.repeat(3, 1, 1, 1, 1))
158+
)(spec.repeat(*pattern))
167159

168160
self.assertEqual(computed, expected, atol=1e-5, rtol=1e-5)
169161

test/torchaudio_unittest/transforms/torchscript_consistency_cpu_test.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22

33
from torchaudio_unittest.common_utils import PytorchTestCase
4-
from .torchscript_consistency_impl import Transforms
4+
from .torchscript_consistency_impl import Transforms, TransformsComplex
55

66

77
class TestTransformsFloat32(Transforms, PytorchTestCase):
@@ -12,3 +12,15 @@ class TestTransformsFloat32(Transforms, PytorchTestCase):
1212
class TestTransformsFloat64(Transforms, PytorchTestCase):
1313
dtype = torch.float64
1414
device = torch.device('cpu')
15+
16+
17+
class TestTransformsComplex64(TransformsComplex, PytorchTestCase):
18+
complex_dtype = torch.complex64
19+
real_dtype = torch.float32
20+
device = torch.device('cpu')
21+
22+
23+
class TestTransformsComplex128(TransformsComplex, PytorchTestCase):
24+
complex_dtype = torch.complex128
25+
real_dtype = torch.float64
26+
device = torch.device('cpu')

0 commit comments

Comments
 (0)