Skip to content

Commit b158e7c

Browse files
committed
fixup! Add README
1 parent 3862160 commit b158e7c

File tree

5 files changed

+97
-74
lines changed

5 files changed

+97
-74
lines changed
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import subprocess
2+
3+
import torch
4+
5+
6+
def convert_args(**kwargs):
7+
args = []
8+
for key, value in kwargs.items():
9+
if key == 'sample_rate':
10+
key = 'sample_frequency'
11+
key = '--' + key.replace('_', '-')
12+
value = str(value).lower() if value in [True, False] else str(value)
13+
args.append('%s=%s' % (key, value))
14+
return args
15+
16+
17+
def run_kaldi(command, input_type, input_value):
18+
"""Run provided Kaldi command, pass a tensor and get the resulting tensor
19+
20+
Args:
21+
input_type: str
22+
'ark' or 'scp'
23+
input_value:
24+
Tensor for 'ark'
25+
string for 'scp' (path to an audio file)
26+
"""
27+
import kaldi_io
28+
29+
key = 'foo'
30+
process = subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
31+
if input_type == 'ark':
32+
kaldi_io.write_mat(process.stdin, input_value.cpu().numpy(), key=key)
33+
elif input_type == 'scp':
34+
process.stdin.write(f'{key} {input_value}'.encode('utf8'))
35+
else:
36+
raise NotImplementedError('Unexpected type')
37+
process.stdin.close()
38+
result = dict(kaldi_io.read_mat_ark(process.stdout))['foo']
39+
return torch.from_numpy(result.copy()) # copy supresses some torch warning
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from torchaudio_unittest.common_utils import PytorchTestCase
2+
from .kaldi_compatibility_test_impl import KaldiCPUOnly
3+
4+
5+
class TestKaldiCPUOnly(KaldiCPUOnly, PytorchTestCase):
6+
dtype = torch.float32
7+
device = torch.device('cpu')
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from parameterized import parameterized
2+
import torchaudio.functional as F
3+
4+
from torchaudio_unittest.common_utils import (
5+
get_sinusoid,
6+
load_params,
7+
save_wav,
8+
skipIfNoExec,
9+
TempDirMixin,
10+
TestBaseMixin,
11+
)
12+
from torchaudio_unittest.common_utils.kaldi_utils import (
13+
convert_args,
14+
run_kaldi,
15+
)
16+
17+
18+
class KaldiCPUOnly(TempDirMixin, TestBaseMixin):
19+
def assert_equal(self, output, *, expected, rtol=None, atol=None):
20+
expected = expected.to(dtype=self.dtype, device=self.device)
21+
self.assertEqual(output, expected, rtol=rtol, atol=atol)
22+
23+
@parameterized.expand(load_params('kaldi_test_pitch_args.json'))
24+
@skipIfNoExec('compute-kaldi-pitch-feats')
25+
def test_pitch_feats(self, kwargs):
26+
"""compute_kaldi_pitch produces numerically compatible result with compute-kaldi-pitch-feats"""
27+
sample_rate = kwargs['sample_rate']
28+
waveform = get_sinusoid(dtype='float32', sample_rate=sample_rate)
29+
result = F.compute_kaldi_pitch(waveform[0], **kwargs)
30+
31+
waveform = get_sinusoid(dtype='int16', sample_rate=sample_rate)
32+
wave_file = self.get_temp_path('test.wav')
33+
save_wav(wave_file, waveform, sample_rate)
34+
35+
command = ['compute-kaldi-pitch-feats'] + convert_args(**kwargs) + ['scp:-', 'ark:-']
36+
kaldi_result = run_kaldi(command, 'scp', wave_file)
37+
self.assert_equal(result, expected=kaldi_result)
Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22

33
from torchaudio_unittest import common_utils
4-
from .kaldi_compatibility_impl import Kaldi, KaldiCPUOnly
4+
from .kaldi_compatibility_impl import Kaldi
55

66

77
class TestKaldiFloat32(Kaldi, common_utils.PytorchTestCase):
@@ -12,8 +12,3 @@ class TestKaldiFloat32(Kaldi, common_utils.PytorchTestCase):
1212
class TestKaldiFloat64(Kaldi, common_utils.PytorchTestCase):
1313
dtype = torch.float64
1414
device = torch.device('cpu')
15-
16-
17-
class TestKaldiCPUOnly(KaldiCPUOnly, common_utils.PytorchTestCase):
18-
dtype = torch.float32
19-
device = torch.device('cpu')
Lines changed: 13 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
11
"""Test suites for checking numerical compatibility against Kaldi"""
2-
import subprocess
3-
4-
import kaldi_io
52
import torch
63
import torchaudio.functional as F
74
import torchaudio.compliance.kaldi
@@ -14,52 +11,18 @@
1411
skipIfNoExec,
1512
get_asset_path,
1613
load_wav,
17-
save_wav,
18-
get_sinusoid,
14+
)
15+
from torchaudio_unittest.common_utils.kaldi_utils import (
16+
convert_args,
17+
run_kaldi,
1918
)
2019

2120

22-
def _convert_args(**kwargs):
23-
args = []
24-
for key, value in kwargs.items():
25-
if key == 'sample_rate':
26-
key = 'sample_frequency'
27-
key = '--' + key.replace('_', '-')
28-
value = str(value).lower() if value in [True, False] else str(value)
29-
args.append('%s=%s' % (key, value))
30-
return args
31-
32-
33-
def _run_kaldi(command, input_type, input_value):
34-
"""Run provided Kaldi command, pass a tensor and get the resulting tensor
35-
36-
Args:
37-
input_type: str
38-
'ark' or 'scp'
39-
input_value:
40-
Tensor for 'ark'
41-
string for 'scp' (path to an audio file)
42-
"""
43-
key = 'foo'
44-
process = subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
45-
if input_type == 'ark':
46-
kaldi_io.write_mat(process.stdin, input_value.cpu().numpy(), key=key)
47-
elif input_type == 'scp':
48-
process.stdin.write(f'{key} {input_value}'.encode('utf8'))
49-
else:
50-
raise NotImplementedError('Unexpected type')
51-
process.stdin.close()
52-
result = dict(kaldi_io.read_mat_ark(process.stdout))['foo']
53-
return torch.from_numpy(result.copy()) # copy supresses some torch warning
54-
55-
56-
class KaldiTestBase(TempDirMixin, TestBaseMixin):
21+
class Kaldi(TempDirMixin, TestBaseMixin):
5722
def assert_equal(self, output, *, expected, rtol=None, atol=None):
5823
expected = expected.to(dtype=self.dtype, device=self.device)
5924
self.assertEqual(output, expected, rtol=rtol, atol=atol)
6025

61-
62-
class Kaldi(KaldiTestBase):
6326
@skipIfNoExec('apply-cmvn-sliding')
6427
def test_sliding_window_cmn(self):
6528
"""sliding_window_cmn should be numerically compatible with apply-cmvn-sliding"""
@@ -72,8 +35,8 @@ def test_sliding_window_cmn(self):
7235

7336
tensor = torch.randn(40, 10, dtype=self.dtype, device=self.device)
7437
result = F.sliding_window_cmn(tensor, **kwargs)
75-
command = ['apply-cmvn-sliding'] + _convert_args(**kwargs) + ['ark:-', 'ark:-']
76-
kaldi_result = _run_kaldi(command, 'ark', tensor)
38+
command = ['apply-cmvn-sliding'] + convert_args(**kwargs) + ['ark:-', 'ark:-']
39+
kaldi_result = run_kaldi(command, 'ark', tensor)
7740
self.assert_equal(result, expected=kaldi_result)
7841

7942
@parameterized.expand(load_params('kaldi_test_fbank_args.json'))
@@ -83,8 +46,8 @@ def test_fbank(self, kwargs):
8346
wave_file = get_asset_path('kaldi_file.wav')
8447
waveform = load_wav(wave_file, normalize=False)[0].to(dtype=self.dtype, device=self.device)
8548
result = torchaudio.compliance.kaldi.fbank(waveform, **kwargs)
86-
command = ['compute-fbank-feats'] + _convert_args(**kwargs) + ['scp:-', 'ark:-']
87-
kaldi_result = _run_kaldi(command, 'scp', wave_file)
49+
command = ['compute-fbank-feats'] + convert_args(**kwargs) + ['scp:-', 'ark:-']
50+
kaldi_result = run_kaldi(command, 'scp', wave_file)
8851
self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8)
8952

9053
@parameterized.expand(load_params('kaldi_test_spectrogram_args.json'))
@@ -94,8 +57,8 @@ def test_spectrogram(self, kwargs):
9457
wave_file = get_asset_path('kaldi_file.wav')
9558
waveform = load_wav(wave_file, normalize=False)[0].to(dtype=self.dtype, device=self.device)
9659
result = torchaudio.compliance.kaldi.spectrogram(waveform, **kwargs)
97-
command = ['compute-spectrogram-feats'] + _convert_args(**kwargs) + ['scp:-', 'ark:-']
98-
kaldi_result = _run_kaldi(command, 'scp', wave_file)
60+
command = ['compute-spectrogram-feats'] + convert_args(**kwargs) + ['scp:-', 'ark:-']
61+
kaldi_result = run_kaldi(command, 'scp', wave_file)
9962
self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8)
10063

10164
@parameterized.expand(load_params('kaldi_test_mfcc_args.json'))
@@ -105,24 +68,6 @@ def test_mfcc(self, kwargs):
10568
wave_file = get_asset_path('kaldi_file.wav')
10669
waveform = load_wav(wave_file, normalize=False)[0].to(dtype=self.dtype, device=self.device)
10770
result = torchaudio.compliance.kaldi.mfcc(waveform, **kwargs)
108-
command = ['compute-mfcc-feats'] + _convert_args(**kwargs) + ['scp:-', 'ark:-']
109-
kaldi_result = _run_kaldi(command, 'scp', wave_file)
71+
command = ['compute-mfcc-feats'] + convert_args(**kwargs) + ['scp:-', 'ark:-']
72+
kaldi_result = run_kaldi(command, 'scp', wave_file)
11073
self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8)
111-
112-
113-
class KaldiCPUOnly(KaldiTestBase):
114-
@parameterized.expand(load_params('kaldi_test_pitch_args.json'))
115-
@skipIfNoExec('compute-kaldi-pitch-feats')
116-
def test_pitch_feats(self, kwargs):
117-
"""compute_kaldi_pitch produces numerically compatible result with compute-kaldi-pitch-feats"""
118-
sample_rate = kwargs['sample_rate']
119-
waveform = get_sinusoid(dtype='float32', sample_rate=sample_rate)
120-
result = F.compute_kaldi_pitch(waveform[0], **kwargs)
121-
122-
waveform = get_sinusoid(dtype='int16', sample_rate=sample_rate)
123-
wave_file = self.get_temp_path('test.wav')
124-
save_wav(wave_file, waveform, sample_rate)
125-
126-
command = ['compute-kaldi-pitch-feats'] + _convert_args(**kwargs) + ['scp:-', 'ark:-']
127-
kaldi_result = _run_kaldi(command, 'scp', wave_file)
128-
self.assert_equal(result, expected=kaldi_result)

0 commit comments

Comments
 (0)