11"""Test suites for checking numerical compatibility against Kaldi"""
2- import subprocess
3-
4- import kaldi_io
52import torch
63import torchaudio .functional as F
74import torchaudio .compliance .kaldi
1411 skipIfNoExec ,
1512 get_asset_path ,
1613 load_wav ,
17- save_wav ,
18- get_sinusoid ,
14+ )
15+ from torchaudio_unittest .common_utils .kaldi_utils import (
16+ convert_args ,
17+ run_kaldi ,
1918)
2019
2120
def _convert_args(**kwargs):
    """Turn keyword arguments into ``--option=value`` command-line strings.

    For every keyword argument, underscores in the name are replaced with
    hyphens and the name is prefixed with ``--``.  The name ``sample_rate``
    is renamed to ``sample_frequency`` first.  Values equal to ``True`` or
    ``False`` are rendered in lowercase; every other value is rendered with
    ``str``.

    Args:
        **kwargs: Option names mapped to their values.

    Returns:
        list of str: One ``--option=value`` entry per keyword argument, in
        the order the arguments were passed.
    """
    def _as_option(name, val):
        # The command-line spelling of this option differs from the keyword.
        if name == 'sample_rate':
            name = 'sample_frequency'
        flag = '--' + name.replace('_', '-')
        text = str(val)
        if val in [True, False]:
            # 'True' / 'False' are emitted as 'true' / 'false'.
            text = text.lower()
        return f'{flag}={text}'

    return [_as_option(name, val) for name, val in kwargs.items()]
31-
32-
def _run_kaldi(command, input_type, input_value):
    """Run provided Kaldi command, pass a tensor and get the resulting tensor

    Args:
        command: list of str
            The executable and its arguments, handed to ``subprocess.Popen``.
        input_type: str
            'ark' or 'scp'
        input_value:
            Tensor for 'ark'
            string for 'scp' (path to an audio file)

    Returns:
        Tensor built from the matrix that the command wrote to its standard
        output under the key that was sent to it.

    Raises:
        NotImplementedError: If ``input_type`` is neither 'ark' nor 'scp'.
    """
    # Validate before spawning so that a bad argument never leaves a child
    # process behind.
    if input_type not in ('ark', 'scp'):
        raise NotImplementedError('Unexpected type')

    key = 'foo'
    # The context manager closes the pipes and waits for the child on exit,
    # so neither the process nor its stdout descriptor is leaked.
    with subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE) as process:
        if input_type == 'ark':
            kaldi_io.write_mat(process.stdin, input_value.cpu().numpy(), key=key)
        else:
            process.stdin.write(f'{key} {input_value}'.encode('utf8'))
        # Close stdin explicitly so the command sees end-of-input before we
        # start reading its output.
        process.stdin.close()
        result = dict(kaldi_io.read_mat_ark(process.stdout))[key]
    return torch.from_numpy(result.copy())  # copy suppresses some torch warning
54-
55-
56- class KaldiTestBase (TempDirMixin , TestBaseMixin ):
21+ class Kaldi (TempDirMixin , TestBaseMixin ):
def assert_equal(self, output, *, expected, rtol=None, atol=None):
    """Compare ``output`` with ``expected`` using this test case's dtype and device.

    ``expected`` is first converted to ``self.dtype`` / ``self.device``; the
    comparison itself is delegated to ``self.assertEqual`` with the given
    tolerances.
    """
    reference = expected.to(dtype=self.dtype, device=self.device)
    self.assertEqual(output, reference, rtol=rtol, atol=atol)
6025
61-
62- class Kaldi (KaldiTestBase ):
6326 @skipIfNoExec ('apply-cmvn-sliding' )
6427 def test_sliding_window_cmn (self ):
6528 """sliding_window_cmn should be numerically compatible with apply-cmvn-sliding"""
@@ -72,8 +35,8 @@ def test_sliding_window_cmn(self):
7235
7336 tensor = torch .randn (40 , 10 , dtype = self .dtype , device = self .device )
7437 result = F .sliding_window_cmn (tensor , ** kwargs )
75- command = ['apply-cmvn-sliding' ] + _convert_args (** kwargs ) + ['ark:-' , 'ark:-' ]
76- kaldi_result = _run_kaldi (command , 'ark' , tensor )
38+ command = ['apply-cmvn-sliding' ] + convert_args (** kwargs ) + ['ark:-' , 'ark:-' ]
39+ kaldi_result = run_kaldi (command , 'ark' , tensor )
7740 self .assert_equal (result , expected = kaldi_result )
7841
7942 @parameterized .expand (load_params ('kaldi_test_fbank_args.json' ))
@@ -83,8 +46,8 @@ def test_fbank(self, kwargs):
8346 wave_file = get_asset_path ('kaldi_file.wav' )
8447 waveform = load_wav (wave_file , normalize = False )[0 ].to (dtype = self .dtype , device = self .device )
8548 result = torchaudio .compliance .kaldi .fbank (waveform , ** kwargs )
86- command = ['compute-fbank-feats' ] + _convert_args (** kwargs ) + ['scp:-' , 'ark:-' ]
87- kaldi_result = _run_kaldi (command , 'scp' , wave_file )
49+ command = ['compute-fbank-feats' ] + convert_args (** kwargs ) + ['scp:-' , 'ark:-' ]
50+ kaldi_result = run_kaldi (command , 'scp' , wave_file )
8851 self .assert_equal (result , expected = kaldi_result , rtol = 1e-4 , atol = 1e-8 )
8952
9053 @parameterized .expand (load_params ('kaldi_test_spectrogram_args.json' ))
@@ -94,8 +57,8 @@ def test_spectrogram(self, kwargs):
9457 wave_file = get_asset_path ('kaldi_file.wav' )
9558 waveform = load_wav (wave_file , normalize = False )[0 ].to (dtype = self .dtype , device = self .device )
9659 result = torchaudio .compliance .kaldi .spectrogram (waveform , ** kwargs )
97- command = ['compute-spectrogram-feats' ] + _convert_args (** kwargs ) + ['scp:-' , 'ark:-' ]
98- kaldi_result = _run_kaldi (command , 'scp' , wave_file )
60+ command = ['compute-spectrogram-feats' ] + convert_args (** kwargs ) + ['scp:-' , 'ark:-' ]
61+ kaldi_result = run_kaldi (command , 'scp' , wave_file )
9962 self .assert_equal (result , expected = kaldi_result , rtol = 1e-4 , atol = 1e-8 )
10063
10164 @parameterized .expand (load_params ('kaldi_test_mfcc_args.json' ))
@@ -105,24 +68,6 @@ def test_mfcc(self, kwargs):
10568 wave_file = get_asset_path ('kaldi_file.wav' )
10669 waveform = load_wav (wave_file , normalize = False )[0 ].to (dtype = self .dtype , device = self .device )
10770 result = torchaudio .compliance .kaldi .mfcc (waveform , ** kwargs )
108- command = ['compute-mfcc-feats' ] + _convert_args (** kwargs ) + ['scp:-' , 'ark:-' ]
109- kaldi_result = _run_kaldi (command , 'scp' , wave_file )
71+ command = ['compute-mfcc-feats' ] + convert_args (** kwargs ) + ['scp:-' , 'ark:-' ]
72+ kaldi_result = run_kaldi (command , 'scp' , wave_file )
11073 self .assert_equal (result , expected = kaldi_result , rtol = 1e-4 , atol = 1e-8 )
111-
112-
# NOTE(review): the class name suggests these tests are meant to run on CPU
# only — confirm against the concrete subclasses that set dtype/device.
class KaldiCPUOnly(KaldiTestBase):
    @parameterized.expand(load_params('kaldi_test_pitch_args.json'))
    @skipIfNoExec('compute-kaldi-pitch-feats')
    def test_pitch_feats(self, kwargs):
        """compute_kaldi_pitch produces numerically compatible result with compute-kaldi-pitch-feats"""
        rate = kwargs['sample_rate']

        # Library side: pitch features computed from a float32 sinusoid.
        float_wave = get_sinusoid(dtype='float32', sample_rate=rate)
        result = F.compute_kaldi_pitch(float_wave[0], **kwargs)

        # Command-line side: an int16 sinusoid is written to a temporary wav
        # file, whose path is then handed to the executable in 'scp' mode.
        int_wave = get_sinusoid(dtype='int16', sample_rate=rate)
        wave_file = self.get_temp_path('test.wav')
        save_wav(wave_file, int_wave, rate)

        option_args = _convert_args(**kwargs)
        command = ['compute-kaldi-pitch-feats', *option_args, 'scp:-', 'ark:-']
        kaldi_result = _run_kaldi(command, 'scp', wave_file)

        self.assert_equal(result, expected=kaldi_result)
0 commit comments