55import torch
66import torchaudio .functional as F
77import torchaudio .compliance .kaldi
8-
9- from . import common_utils
10- from .common_utils import load_params
118from parameterized import parameterized
129
10+ from .common_utils import (
11+ TestBaseMixin ,
12+ load_params ,
13+ skipIfNoExec ,
14+ get_asset_path ,
15+ load_wav
16+ )
17+
1318
1419def _convert_args (** kwargs ):
1520 args = []
@@ -43,14 +48,12 @@ def _run_kaldi(command, input_type, input_value):
4348 return torch .from_numpy (result .copy ()) # copy supresses some torch warning
4449
4550
46- class Kaldi (common_utils .TestBaseMixin ):
47- backend = 'sox'
48-
51+ class Kaldi (TestBaseMixin ):
4952 def assert_equal (self , output , * , expected , rtol = None , atol = None ):
5053 expected = expected .to (dtype = self .dtype , device = self .device )
5154 self .assertEqual (output , expected , rtol = rtol , atol = atol )
5255
53- @common_utils . skipIfNoExec ('apply-cmvn-sliding' )
56+ @skipIfNoExec ('apply-cmvn-sliding' )
5457 def test_sliding_window_cmn (self ):
5558 """sliding_window_cmn should be numerically compatible with apply-cmvn-sliding"""
5659 kwargs = {
@@ -67,33 +70,33 @@ def test_sliding_window_cmn(self):
6770 self .assert_equal (result , expected = kaldi_result )
6871
6972 @parameterized .expand (load_params ('kaldi_test_fbank_args.json' ))
70- @common_utils . skipIfNoExec ('compute-fbank-feats' )
73+ @skipIfNoExec ('compute-fbank-feats' )
7174 def test_fbank (self , kwargs ):
7275 """fbank should be numerically compatible with compute-fbank-feats"""
73- wave_file = common_utils . get_asset_path ('kaldi_file.wav' )
74- waveform = torchaudio . load_wav (wave_file )[0 ].to (dtype = self .dtype , device = self .device )
76+ wave_file = get_asset_path ('kaldi_file.wav' )
77+ waveform = load_wav (wave_file , normalize = False )[0 ].to (dtype = self .dtype , device = self .device )
7578 result = torchaudio .compliance .kaldi .fbank (waveform , ** kwargs )
7679 command = ['compute-fbank-feats' ] + _convert_args (** kwargs ) + ['scp:-' , 'ark:-' ]
7780 kaldi_result = _run_kaldi (command , 'scp' , wave_file )
7881 self .assert_equal (result , expected = kaldi_result , rtol = 1e-4 , atol = 1e-8 )
7982
8083 @parameterized .expand (load_params ('kaldi_test_spectrogram_args.json' ))
81- @common_utils . skipIfNoExec ('compute-spectrogram-feats' )
84+ @skipIfNoExec ('compute-spectrogram-feats' )
8285 def test_spectrogram (self , kwargs ):
8386 """spectrogram should be numerically compatible with compute-spectrogram-feats"""
84- wave_file = common_utils . get_asset_path ('kaldi_file.wav' )
85- waveform = torchaudio . load_wav (wave_file )[0 ].to (dtype = self .dtype , device = self .device )
87+ wave_file = get_asset_path ('kaldi_file.wav' )
88+ waveform = load_wav (wave_file , normalize = False )[0 ].to (dtype = self .dtype , device = self .device )
8689 result = torchaudio .compliance .kaldi .spectrogram (waveform , ** kwargs )
8790 command = ['compute-spectrogram-feats' ] + _convert_args (** kwargs ) + ['scp:-' , 'ark:-' ]
8891 kaldi_result = _run_kaldi (command , 'scp' , wave_file )
8992 self .assert_equal (result , expected = kaldi_result , rtol = 1e-4 , atol = 1e-8 )
9093
9194 @parameterized .expand (load_params ('kaldi_test_mfcc_args.json' ))
92- @common_utils . skipIfNoExec ('compute-mfcc-feats' )
95+ @skipIfNoExec ('compute-mfcc-feats' )
9396 def test_mfcc (self , kwargs ):
9497 """mfcc should be numerically compatible with compute-mfcc-feats"""
95- wave_file = common_utils . get_asset_path ('kaldi_file.wav' )
96- waveform = torchaudio . load_wav (wave_file )[0 ].to (dtype = self .dtype , device = self .device )
98+ wave_file = get_asset_path ('kaldi_file.wav' )
99+ waveform = load_wav (wave_file , normalize = False )[0 ].to (dtype = self .dtype , device = self .device )
97100 result = torchaudio .compliance .kaldi .mfcc (waveform , ** kwargs )
98101 command = ['compute-mfcc-feats' ] + _convert_args (** kwargs ) + ['scp:-' , 'ark:-' ]
99102 kaldi_result = _run_kaldi (command , 'scp' , wave_file )
0 commit comments