diff --git a/test/kaldi_compatibility_impl.py b/test/kaldi_compatibility_impl.py index 9e0ecfe1ca..f4c7986269 100644 --- a/test/kaldi_compatibility_impl.py +++ b/test/kaldi_compatibility_impl.py @@ -3,6 +3,7 @@ import shutil import unittest import subprocess +import math import kaldi_io import torch @@ -107,3 +108,89 @@ def test_mfcc(self, kwargs): command = ['compute-mfcc-feats'] + _convert_args(**kwargs) + ['scp:-', 'ark:-'] kaldi_result = _run_kaldi(command, 'scp', wave_file) self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8) + + def test_mfcc_empty(self): + # Passing in an empty tensor should result in an error + input = torch.empty(0).to(dtype=self.dtype, device=self.device) + self.assertRaises(AssertionError, torchaudio.compliance.kaldi.mfcc, input) + + def test_resample_waveform_upsample_size(self): + test_8000_filepath = common_utils.get_asset_path('kaldi_file_8000.wav') + sound, sample_rate = torchaudio.load_wav(test_8000_filepath) + sound = sound.to(dtype=self.dtype, device=self.device) + upsample_sound = torchaudio.compliance.kaldi.resample_waveform(sound, sample_rate, sample_rate * 2) + self.assertTrue(upsample_sound.size(-1) == sound.size(-1) * 2) + + def test_resample_waveform_downsample_size(self): + test_8000_filepath = common_utils.get_asset_path('kaldi_file_8000.wav') + sound, sample_rate = torchaudio.load_wav(test_8000_filepath) + sound = sound.to(dtype=self.dtype, device=self.device) + downsample_sound = torchaudio.compliance.kaldi.resample_waveform(sound, sample_rate, sample_rate // 2) + self.assertTrue(downsample_sound.size(-1) == sound.size(-1) // 2) + + def test_resample_waveform_identity_size(self): + test_8000_filepath = common_utils.get_asset_path('kaldi_file_8000.wav') + sound, sample_rate = torchaudio.load_wav(test_8000_filepath) + sound = sound.to(dtype=self.dtype, device=self.device) + downsample_sound = torchaudio.compliance.kaldi.resample_waveform(sound, sample_rate, sample_rate) + self.assertTrue(downsample_sound.size(-1) == sound.size(-1)) + + def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_factor=None, + atol=1e-1, rtol=1e-4): + # resample the signal and compare it to the ground truth + n_to_trim = 20 + sample_rate = 1000 + new_sample_rate = sample_rate + + if up_scale_factor is not None: + new_sample_rate *= up_scale_factor + + if down_scale_factor is not None: + new_sample_rate //= down_scale_factor + + duration = 5 # seconds + original_timestamps = torch.arange(0, duration, 1.0 / sample_rate) + original_timestamps = original_timestamps.to(dtype=self.dtype, device=self.device) + + sound = 123 * torch.cos(2 * math.pi * 3 * original_timestamps).unsqueeze(0) + estimate = torchaudio.compliance.kaldi.resample_waveform(sound, sample_rate, new_sample_rate).squeeze() + + new_timestamps = torch.arange(0, duration, 1.0 / new_sample_rate)[:estimate.size(0)] + new_timestamps = new_timestamps.to(dtype=self.dtype, device=self.device) + ground_truth = 123 * torch.cos(2 * math.pi * 3 * new_timestamps) + + # trim the first/last n samples as these points have boundary effects + ground_truth = ground_truth[..., n_to_trim:-n_to_trim] + estimate = estimate[..., n_to_trim:-n_to_trim] + + self.assert_equal(estimate, expected=ground_truth, atol=atol, rtol=rtol) + + def test_resample_waveform_downsample_accuracy(self): + for i in range(1, 20): + self._test_resample_waveform_accuracy(down_scale_factor=i * 2) + + def test_resample_waveform_upsample_accuracy(self): + for i in range(1, 20): + self._test_resample_waveform_accuracy(up_scale_factor=1.0 + i / 20.0) + + def test_resample_waveform_multi_channel(self): + num_channels = 3 + + test_8000_filepath = common_utils.get_asset_path('kaldi_file_8000.wav') + # (1, 8000) + sound, sample_rate = torchaudio.load_wav(test_8000_filepath) + sound = sound.to(dtype=self.dtype, device=self.device) + multi_sound = sound.repeat(num_channels, 1) # (num_channels, 8000) + + for i in range(num_channels): + multi_sound[i, :] *= (i + 1) * 1.5 + + multi_sound_sampled = torchaudio.compliance.kaldi.resample_waveform(multi_sound, sample_rate, sample_rate // 2) + + # check that sampling is same whether using separately or in a tensor of size (c, n) + for i in range(num_channels): + single_channel = sound * (i + 1) * 1.5 + single_channel_sampled = torchaudio.compliance.kaldi.resample_waveform(single_channel, + sample_rate, + sample_rate // 2) + self.assert_equal(multi_sound_sampled[i, :], expected=single_channel_sampled[0], rtol=1e-4, atol=1e-8)