From 8bbb80482b35a801008a78358cb680e02fafbb20 Mon Sep 17 00:00:00 2001 From: Bhargav Kathivarapu Date: Fri, 5 Jun 2020 09:52:24 +0530 Subject: [PATCH 1/3] Migrate kaldi resample accuracy and size tests Signed-off-by: Bhargav Kathivarapu --- test/kaldi_compatibility_impl.py | 79 ++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/test/kaldi_compatibility_impl.py b/test/kaldi_compatibility_impl.py index 9e0ecfe1ca..0f85699070 100644 --- a/test/kaldi_compatibility_impl.py +++ b/test/kaldi_compatibility_impl.py @@ -3,6 +3,7 @@ import shutil import unittest import subprocess +import math import kaldi_io import torch @@ -107,3 +108,81 @@ def test_mfcc(self, kwargs): command = ['compute-mfcc-feats'] + _convert_args(**kwargs) + ['scp:-', 'ark:-'] kaldi_result = _run_kaldi(command, 'scp', wave_file) self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8) + + def test_mfcc_empty(self): + # Passing in an empty tensor should result in an error + self.assertRaises(AssertionError, torchaudio.compliance.kaldi.mfcc, torch.empty(0)) + + def test_resample_waveform_upsample_size(self): + test_8000_filepath = common_utils.get_asset_path('kaldi_file_8000.wav') + sound, sample_rate = torchaudio.load_wav(test_8000_filepath) + upsample_sound = torchaudio.compliance.kaldi.resample_waveform(sound, sample_rate, sample_rate * 2) + self.assertTrue(upsample_sound.size(-1) == sound.size(-1) * 2) + + def test_resample_waveform_downsample_size(self): + test_8000_filepath = common_utils.get_asset_path('kaldi_file_8000.wav') + sound, sample_rate = torchaudio.load_wav(test_8000_filepath) + downsample_sound = torchaudio.compliance.kaldi.resample_waveform(sound, sample_rate, sample_rate // 2) + self.assertTrue(downsample_sound.size(-1) == sound.size(-1) // 2) + + def test_resample_waveform_identity_size(self): + test_8000_filepath = common_utils.get_asset_path('kaldi_file_8000.wav') + sound, sample_rate = torchaudio.load_wav(test_8000_filepath) + downsample_sound = torchaudio.compliance.kaldi.resample_waveform(sound, sample_rate, sample_rate) + self.assertTrue(downsample_sound.size(-1) == sound.size(-1)) + + def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_factor=None, + atol=1e-1, rtol=1e-4): + # resample the signal and compare it to the ground truth + n_to_trim = 20 + sample_rate = 1000 + new_sample_rate = sample_rate + + if up_scale_factor is not None: + new_sample_rate *= up_scale_factor + + if down_scale_factor is not None: + new_sample_rate //= down_scale_factor + + duration = 5 # seconds + original_timestamps = torch.arange(0, duration, 1.0 / sample_rate) + + sound = 123 * torch.cos(2 * math.pi * 3 * original_timestamps).unsqueeze(0) + estimate = torchaudio.compliance.kaldi.resample_waveform(sound, sample_rate, new_sample_rate).squeeze() + + new_timestamps = torch.arange(0, duration, 1.0 / new_sample_rate)[:estimate.size(0)] + ground_truth = 123 * torch.cos(2 * math.pi * 3 * new_timestamps) + + # trim the first/last n samples as these points have boundary effects + ground_truth = ground_truth[..., n_to_trim:-n_to_trim] + estimate = estimate[..., n_to_trim:-n_to_trim] + + torch.assert_equal(estimate, ground_truth, atol=atol, rtol=rtol) + + def test_resample_waveform_downsample_accuracy(self): + for i in range(1, 20): + self._test_resample_waveform_accuracy(down_scale_factor=i * 2) + + def test_resample_waveform_upsample_accuracy(self): + for i in range(1, 20): + self._test_resample_waveform_accuracy(up_scale_factor=1.0 + i / 20.0) + + def test_resample_waveform_multi_channel(self): + num_channels = 3 + + test_8000_filepath = common_utils.get_asset_path('kaldi_file_8000.wav') + sound, sample_rate = torchaudio.load_wav(test_8000_filepath) # (1, 8000) + multi_sound = sound.repeat(num_channels, 1) # (num_channels, 8000) + + for i in range(num_channels): + multi_sound[i, :] *= (i + 1) * 1.5 + + multi_sound_sampled = torchaudio.compliance.kaldi.resample_waveform(multi_sound, sample_rate, sample_rate // 2) + + # check that sampling is same whether using separately or in a tensor of size (c, n) + for i in range(num_channels): + single_channel = sound * (i + 1) * 1.5 + single_channel_sampled = torchaudio.compliance.kaldi.resample_waveform(single_channel, + sample_rate, + sample_rate // 2) + self.assert_equal(multi_sound_sampled[i, :], single_channel_sampled[0], rtol=1e-4, atol=1e-8) From 94cf6e952f76334feb6b3d3ed60c6f1652c3421b Mon Sep 17 00:00:00 2001 From: Bhargav Kathivarapu Date: Fri, 5 Jun 2020 10:15:48 +0530 Subject: [PATCH 2/3] Minor fix Signed-off-by: Bhargav Kathivarapu --- test/kaldi_compatibility_impl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/kaldi_compatibility_impl.py b/test/kaldi_compatibility_impl.py index 0f85699070..52c3fa94f3 100644 --- a/test/kaldi_compatibility_impl.py +++ b/test/kaldi_compatibility_impl.py @@ -157,7 +157,7 @@ def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_fact ground_truth = ground_truth[..., n_to_trim:-n_to_trim] estimate = estimate[..., n_to_trim:-n_to_trim] - torch.assert_equal(estimate, ground_truth, atol=atol, rtol=rtol) + self.assert_equal(estimate, expected=ground_truth, atol=atol, rtol=rtol) def test_resample_waveform_downsample_accuracy(self): for i in range(1, 20): @@ -185,4 +185,4 @@ def test_resample_waveform_multi_channel(self): single_channel_sampled = torchaudio.compliance.kaldi.resample_waveform(single_channel, sample_rate, sample_rate // 2) - self.assert_equal(multi_sound_sampled[i, :], single_channel_sampled[0], rtol=1e-4, atol=1e-8) + self.assert_equal(multi_sound_sampled[i, :], expected=single_channel_sampled[0], rtol=1e-4, atol=1e-8) From 3599ef5201426b4b7bf17f84d880da996b8c5877 Mon Sep 17 00:00:00 2001 From: Bhargav Kathivarapu Date: Fri, 5 Jun 2020 11:11:29 +0530 Subject: [PATCH 3/3] Add device and dtype to tests Signed-off-by: Bhargav Kathivarapu --- test/kaldi_compatibility_impl.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/test/kaldi_compatibility_impl.py b/test/kaldi_compatibility_impl.py index 52c3fa94f3..f4c7986269 100644 --- a/test/kaldi_compatibility_impl.py +++ b/test/kaldi_compatibility_impl.py @@ -111,23 +111,27 @@ def test_mfcc(self, kwargs): def test_mfcc_empty(self): # Passing in an empty tensor should result in an error - self.assertRaises(AssertionError, torchaudio.compliance.kaldi.mfcc, torch.empty(0)) + input = torch.empty(0).to(dtype=self.dtype, device=self.device) + self.assertRaises(AssertionError, torchaudio.compliance.kaldi.mfcc, input) def test_resample_waveform_upsample_size(self): test_8000_filepath = common_utils.get_asset_path('kaldi_file_8000.wav') sound, sample_rate = torchaudio.load_wav(test_8000_filepath) + sound = sound.to(dtype=self.dtype, device=self.device) upsample_sound = torchaudio.compliance.kaldi.resample_waveform(sound, sample_rate, sample_rate * 2) self.assertTrue(upsample_sound.size(-1) == sound.size(-1) * 2) def test_resample_waveform_downsample_size(self): test_8000_filepath = common_utils.get_asset_path('kaldi_file_8000.wav') sound, sample_rate = torchaudio.load_wav(test_8000_filepath) + sound = sound.to(dtype=self.dtype, device=self.device) downsample_sound = torchaudio.compliance.kaldi.resample_waveform(sound, sample_rate, sample_rate // 2) self.assertTrue(downsample_sound.size(-1) == sound.size(-1) // 2) def test_resample_waveform_identity_size(self): test_8000_filepath = common_utils.get_asset_path('kaldi_file_8000.wav') sound, sample_rate = torchaudio.load_wav(test_8000_filepath) + sound = sound.to(dtype=self.dtype, device=self.device) downsample_sound = torchaudio.compliance.kaldi.resample_waveform(sound, sample_rate, sample_rate) self.assertTrue(downsample_sound.size(-1) == sound.size(-1)) @@ -146,11 +150,13 @@ def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_fact duration = 5 # seconds original_timestamps = torch.arange(0, duration, 1.0 / sample_rate) + original_timestamps = original_timestamps.to(dtype=self.dtype, device=self.device) sound = 123 * torch.cos(2 * math.pi * 3 * original_timestamps).unsqueeze(0) estimate = torchaudio.compliance.kaldi.resample_waveform(sound, sample_rate, new_sample_rate).squeeze() new_timestamps = torch.arange(0, duration, 1.0 / new_sample_rate)[:estimate.size(0)] + new_timestamps = new_timestamps.to(dtype=self.dtype, device=self.device) ground_truth = 123 * torch.cos(2 * math.pi * 3 * new_timestamps) # trim the first/last n samples as these points have boundary effects @@ -171,7 +177,9 @@ def test_resample_waveform_multi_channel(self): num_channels = 3 test_8000_filepath = common_utils.get_asset_path('kaldi_file_8000.wav') - sound, sample_rate = torchaudio.load_wav(test_8000_filepath) # (1, 8000) + # (1, 8000) + sound, sample_rate = torchaudio.load_wav(test_8000_filepath) + sound = sound.to(dtype=self.dtype, device=self.device) multi_sound = sound.repeat(num_channels, 1) # (num_channels, 8000) for i in range(num_channels):