Skip to content

Commit 76bf776

Browse files
author
Caroline Chen
committed
add kaiser window
1 parent 52e7bfd commit 76bf776

File tree

5 files changed

+64
-33
lines changed

5 files changed

+64
-33
lines changed

test/torchaudio_unittest/compliance_kaldi_test.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from torchaudio_unittest import common_utils
99
from .compliance import utils as compliance_utils
10+
from parameterized import parameterized
1011

1112

1213
def extract_window(window, wave, f, frame_length, frame_shift, snip_edges):
@@ -182,20 +183,26 @@ def get_output_fn(sound, args):
182183

183184
self._compliance_test_helper(self.test2_filepath, 'resample', 32, 3, get_output_fn, atol=1e-2, rtol=1e-5)
184185

185-
def test_resample_waveform_upsample_size(self):
186-
upsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr * 2)
186+
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
187+
def test_resample_waveform_upsample_size(self, resampling_method):
188+
upsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr * 2,
189+
resampling_method=resampling_method)
187190
self.assertTrue(upsample_sound.size(-1) == self.test1_signal.size(-1) * 2)
188191

189-
def test_resample_waveform_downsample_size(self):
190-
downsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr // 2)
192+
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
193+
def test_resample_waveform_downsample_size(self, resampling_method):
194+
downsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr // 2,
195+
resampling_method=resampling_method)
191196
self.assertTrue(downsample_sound.size(-1) == self.test1_signal.size(-1) // 2)
192197

193-
def test_resample_waveform_identity_size(self):
194-
downsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr)
198+
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
199+
def test_resample_waveform_identity_size(self, resampling_method):
200+
downsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr,
201+
resampling_method=resampling_method)
195202
self.assertTrue(downsample_sound.size(-1) == self.test1_signal.size(-1))
196203

197204
def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_factor=None,
198-
atol=1e-1, rtol=1e-4):
205+
resampling_method="sinc_interpolation", atol=1e-1, rtol=1e-4):
199206
# resample the signal and compare it to the ground truth
200207
n_to_trim = 20
201208
sample_rate = 1000
@@ -211,7 +218,8 @@ def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_fact
211218
original_timestamps = torch.arange(0, duration, 1.0 / sample_rate)
212219

213220
sound = 123 * torch.cos(2 * math.pi * 3 * original_timestamps).unsqueeze(0)
214-
estimate = kaldi.resample_waveform(sound, sample_rate, new_sample_rate).squeeze()
221+
estimate = kaldi.resample_waveform(sound, sample_rate, new_sample_rate,
222+
resampling_method=resampling_method).squeeze()
215223

216224
new_timestamps = torch.arange(0, duration, 1.0 / new_sample_rate)[:estimate.size(0)]
217225
ground_truth = 123 * torch.cos(2 * math.pi * 3 * new_timestamps)
@@ -222,27 +230,32 @@ def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_fact
222230

223231
self.assertEqual(estimate, ground_truth, atol=atol, rtol=rtol)
224232

225-
def test_resample_waveform_downsample_accuracy(self):
233+
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
234+
def test_resample_waveform_downsample_accuracy(self, resampling_method):
226235
for i in range(1, 20):
227-
self._test_resample_waveform_accuracy(down_scale_factor=i * 2)
236+
self._test_resample_waveform_accuracy(down_scale_factor=i * 2, resampling_method=resampling_method)
228237

229-
def test_resample_waveform_upsample_accuracy(self):
238+
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
239+
def test_resample_waveform_upsample_accuracy(self, resampling_method):
230240
for i in range(1, 20):
231-
self._test_resample_waveform_accuracy(up_scale_factor=1.0 + i / 20.0)
241+
self._test_resample_waveform_accuracy(up_scale_factor=1.0 + i / 20.0, resampling_method=resampling_method)
232242

233-
def test_resample_waveform_multi_channel(self):
243+
@parameterized.expand([("sinc_interpolation"), ("kaiser_window")])
244+
def test_resample_waveform_multi_channel(self, resampling_method):
234245
num_channels = 3
235246

236247
multi_sound = self.test1_signal.repeat(num_channels, 1) # (num_channels, 8000 smp)
237248

238249
for i in range(num_channels):
239250
multi_sound[i, :] *= (i + 1) * 1.5
240251

241-
multi_sound_sampled = kaldi.resample_waveform(multi_sound, self.test1_signal_sr, self.test1_signal_sr // 2)
252+
multi_sound_sampled = kaldi.resample_waveform(multi_sound, self.test1_signal_sr, self.test1_signal_sr // 2,
253+
resampling_method=resampling_method)
242254

243255
# check that sampling is same whether using separately or in a tensor of size (c, n)
244256
for i in range(num_channels):
245257
single_channel = self.test1_signal * (i + 1) * 1.5
246258
single_channel_sampled = kaldi.resample_waveform(single_channel, self.test1_signal_sr,
247-
self.test1_signal_sr // 2)
259+
self.test1_signal_sr // 2,
260+
resampling_method=resampling_method)
248261
self.assertEqual(multi_sound_sampled[i, :], single_channel_sampled[0], rtol=1e-4, atol=1e-7)

test/torchaudio_unittest/transforms/transforms_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,11 @@ def test_resample_size(self):
169169

170170
upsample_rate = sample_rate * 2
171171
downsample_rate = sample_rate // 2
172-
invalid_resample = torchaudio.transforms.Resample(sample_rate, upsample_rate, resampling_method='foo')
172+
invalid_resampling_method = 'foo'
173173

174-
self.assertRaises(ValueError, invalid_resample, waveform)
174+
with self.assertRaises(ValueError):
175+
torchaudio.transforms.Resample(sample_rate, upsample_rate,
176+
resampling_method=invalid_resampling_method)
175177

176178
upsample_resample = torchaudio.transforms.Resample(
177179
sample_rate, upsample_rate, resampling_method='sinc_interpolation')

torchaudio/compliance/kaldi.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -756,7 +756,8 @@ def resample_waveform(waveform: Tensor,
756756
orig_freq: float,
757757
new_freq: float,
758758
lowpass_filter_width: int = 6,
759-
rolloff: float = 0.99) -> Tensor:
759+
rolloff: float = 0.99,
760+
resampling_method: str = "sinc_interpolation") -> Tensor:
760761
r"""Resamples the waveform at the new frequency.
761762
762763
This is a wrapper around ``torchaudio.functional.resample``.
@@ -773,4 +774,5 @@ def resample_waveform(waveform: Tensor,
773774
Returns:
774775
Tensor: The waveform at the new frequency
775776
"""
776-
return torchaudio.functional.resample(waveform, orig_freq, new_freq, lowpass_filter_width, rolloff)
777+
return torchaudio.functional.resample(waveform, orig_freq, new_freq, lowpass_filter_width,
778+
rolloff, resampling_method)

torchaudio/functional/functional.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1303,7 +1303,9 @@ def _get_sinc_resample_kernel(
13031303
new_freq: float,
13041304
gcd: int,
13051305
lowpass_filter_width: int,
1306-
rolloff: float):
1306+
rolloff: float,
1307+
resampling_method: str,
1308+
beta: float = 6.):
13071309

13081310
if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq):
13091311
warnings.warn(
@@ -1352,15 +1354,20 @@ def _get_sinc_resample_kernel(
13521354
# they will have a lot of almost zero values to the left or to the right...
13531355
# There is probably a way to evaluate those filters more efficiently, but this is kept for
13541356
# future work.
1355-
idx = torch.arange(-width, width + orig_freq)
1357+
idx = torch.arange(-width, width + orig_freq, dtype=torch.float64)
13561358

13571359
for i in range(new_freq):
13581360
t = (-i / new_freq + idx / orig_freq) * base_freq
13591361
t = t.clamp_(-lowpass_filter_width, lowpass_filter_width)
1360-
t *= math.pi
1361-
# we do not use torch.hann_window here as we need to evaluate the window
1362+
1363+
# we do not use built in torch windows here as we need to evaluate the window
13621364
# at specific positions, not over a regular grid.
1363-
window = torch.cos(t / lowpass_filter_width / 2)**2
1365+
if resampling_method == "sinc_interpolation":
1366+
window = torch.cos(t * math.pi / lowpass_filter_width / 2)**2
1367+
elif resampling_method == "kaiser_window":
1368+
beta = torch.tensor(beta, dtype=float)
1369+
window = torch.i0(beta * torch.sqrt(1 - (t / lowpass_filter_width) ** 2)) / torch.i0(beta)
1370+
t *= math.pi
13641371
kernel = torch.where(t == 0, torch.tensor(1.).to(t), torch.sin(t) / t)
13651372
kernel.mul_(window)
13661373
kernels.append(kernel)
@@ -1403,6 +1410,8 @@ def resample(
14031410
new_freq: float,
14041411
lowpass_filter_width: int = 6,
14051412
rolloff: float = 0.99,
1413+
resampling_method: str = "sinc_interpolation",
1414+
**kwargs,
14061415
) -> Tensor:
14071416
r"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform
14081417
which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample
@@ -1421,6 +1430,7 @@ def resample(
14211430
but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
14221431
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
14231432
Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
1433+
resampling_method (str, optional):
14241434
14251435
Returns:
14261436
Tensor: The waveform at the new frequency of dimension (..., time).
@@ -1433,6 +1443,7 @@ def resample(
14331443

14341444
gcd = math.gcd(int(orig_freq), int(new_freq))
14351445

1436-
kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, gcd, lowpass_filter_width, rolloff)
1446+
kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, gcd, lowpass_filter_width, rolloff,
1447+
resampling_method, **kwargs)
14371448
resampled = _apply_sinc_resample_kernel(waveform, orig_freq, new_freq, gcd, kernel, width)
14381449
return resampled

torchaudio/transforms.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -657,7 +657,8 @@ class Resample(torch.nn.Module):
657657
Args:
658658
orig_freq (float, optional): The original frequency of the signal. (Default: ``16000``)
659659
new_freq (float, optional): The desired frequency. (Default: ``16000``)
660-
resampling_method (str, optional): The resampling method. (Default: ``'sinc_interpolation'``)
660+
resampling_method (str, optional): The resampling method.
661+
Options: [``sinc_interpolation``, ``kaiser_window``] (Default: ``'sinc_interpolation'``)
661662
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
662663
but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
663664
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
@@ -669,7 +670,8 @@ def __init__(self,
669670
new_freq: float = 16000,
670671
resampling_method: str = 'sinc_interpolation',
671672
lowpass_filter_width: int = 6,
672-
rolloff: float = 0.99) -> None:
673+
rolloff: float = 0.99,
674+
**kwargs) -> None:
673675
super(Resample, self).__init__()
674676

675677
self.orig_freq = orig_freq
@@ -679,8 +681,12 @@ def __init__(self,
679681
self.lowpass_filter_width = lowpass_filter_width
680682
self.rolloff = rolloff
681683

684+
if self.resampling_method not in ['sinc_interpolation', 'kaiser_window']:
685+
raise ValueError('Invalid resampling method: {}'.format(self.resampling_method))
686+
682687
self.kernel, self.width = _get_sinc_resample_kernel(self.orig_freq, self.new_freq, self.gcd,
683-
self.lowpass_filter_width, self.rolloff)
688+
self.lowpass_filter_width, self.rolloff,
689+
self.resampling_method, **kwargs)
684690

685691
def forward(self, waveform: Tensor) -> Tensor:
686692
r"""
@@ -690,11 +696,8 @@ def forward(self, waveform: Tensor) -> Tensor:
690696
Returns:
691697
Tensor: Output signal of dimension (..., time).
692698
"""
693-
if self.resampling_method == 'sinc_interpolation':
694-
return _apply_sinc_resample_kernel(waveform, self.orig_freq, self.new_freq, self.gcd,
695-
self.kernel, self.width)
696-
697-
raise ValueError('Invalid resampling method: {}'.format(self.resampling_method))
699+
return _apply_sinc_resample_kernel(waveform, self.orig_freq, self.new_freq, self.gcd,
700+
self.kernel, self.width)
698701

699702

700703
class ComplexNorm(torch.nn.Module):

0 commit comments

Comments
 (0)