Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions test/kaldi_compatibility_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import shutil
import unittest
import subprocess
import math

import kaldi_io
import torch
Expand Down Expand Up @@ -107,3 +108,89 @@ def test_mfcc(self, kwargs):
command = ['compute-mfcc-feats'] + _convert_args(**kwargs) + ['scp:-', 'ark:-']
kaldi_result = _run_kaldi(command, 'scp', wave_file)
self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8)

def test_mfcc_empty(self):
# Passing in an empty tensor should result in an error
input = torch.empty(0).to(dtype=self.dtype, device=self.device)
self.assertRaises(AssertionError, torchaudio.compliance.kaldi.mfcc, input)

def test_resample_waveform_upsample_size(self):
test_8000_filepath = common_utils.get_asset_path('kaldi_file_8000.wav')
sound, sample_rate = torchaudio.load_wav(test_8000_filepath)
sound = sound.to(dtype=self.dtype, device=self.device)
upsample_sound = torchaudio.compliance.kaldi.resample_waveform(sound, sample_rate, sample_rate * 2)
self.assertTrue(upsample_sound.size(-1) == sound.size(-1) * 2)

def test_resample_waveform_downsample_size(self):
test_8000_filepath = common_utils.get_asset_path('kaldi_file_8000.wav')
sound, sample_rate = torchaudio.load_wav(test_8000_filepath)
sound = sound.to(dtype=self.dtype, device=self.device)
downsample_sound = torchaudio.compliance.kaldi.resample_waveform(sound, sample_rate, sample_rate // 2)
self.assertTrue(downsample_sound.size(-1) == sound.size(-1) // 2)

def test_resample_waveform_identity_size(self):
test_8000_filepath = common_utils.get_asset_path('kaldi_file_8000.wav')
sound, sample_rate = torchaudio.load_wav(test_8000_filepath)
sound = sound.to(dtype=self.dtype, device=self.device)
downsample_sound = torchaudio.compliance.kaldi.resample_waveform(sound, sample_rate, sample_rate)
self.assertTrue(downsample_sound.size(-1) == sound.size(-1))

def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_factor=None,
atol=1e-1, rtol=1e-4):
# resample the signal and compare it to the ground truth
n_to_trim = 20
sample_rate = 1000
new_sample_rate = sample_rate

if up_scale_factor is not None:
new_sample_rate *= up_scale_factor

if down_scale_factor is not None:
new_sample_rate //= down_scale_factor

duration = 5 # seconds
original_timestamps = torch.arange(0, duration, 1.0 / sample_rate)
original_timestamps = original_timestamps.to(dtype=self.dtype, device=self.device)

sound = 123 * torch.cos(2 * math.pi * 3 * original_timestamps).unsqueeze(0)
estimate = torchaudio.compliance.kaldi.resample_waveform(sound, sample_rate, new_sample_rate).squeeze()

new_timestamps = torch.arange(0, duration, 1.0 / new_sample_rate)[:estimate.size(0)]
new_timestamps = new_timestamps.to(dtype=self.dtype, device=self.device)
ground_truth = 123 * torch.cos(2 * math.pi * 3 * new_timestamps)

# trim the first/last n samples as these points have boundary effects
ground_truth = ground_truth[..., n_to_trim:-n_to_trim]
estimate = estimate[..., n_to_trim:-n_to_trim]

self.assert_equal(estimate, expected=ground_truth, atol=atol, rtol=rtol)

def test_resample_waveform_downsample_accuracy(self):
for i in range(1, 20):
self._test_resample_waveform_accuracy(down_scale_factor=i * 2)

def test_resample_waveform_upsample_accuracy(self):
for i in range(1, 20):
self._test_resample_waveform_accuracy(up_scale_factor=1.0 + i / 20.0)

def test_resample_waveform_multi_channel(self):
num_channels = 3

test_8000_filepath = common_utils.get_asset_path('kaldi_file_8000.wav')
# (1, 8000)
sound, sample_rate = torchaudio.load_wav(test_8000_filepath)
sound = sound.to(dtype=self.dtype, device=self.device)
multi_sound = sound.repeat(num_channels, 1) # (num_channels, 8000)

for i in range(num_channels):
multi_sound[i, :] *= (i + 1) * 1.5

multi_sound_sampled = torchaudio.compliance.kaldi.resample_waveform(multi_sound, sample_rate, sample_rate // 2)

# check that sampling is same whether using separately or in a tensor of size (c, n)
for i in range(num_channels):
single_channel = sound * (i + 1) * 1.5
single_channel_sampled = torchaudio.compliance.kaldi.resample_waveform(single_channel,
sample_rate,
sample_rate // 2)
self.assert_equal(multi_sound_sampled[i, :], expected=single_channel_sampled[0], rtol=1e-4, atol=1e-8)