Skip to content

Commit 60ae951

Browse files
committed
Replace load wav with scipy
1 parent a20da5e commit 60ae951

File tree

5 files changed

+29
-29
lines changed

5 files changed

+29
-29
lines changed

test/functional_cpu_test.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -299,8 +299,6 @@ def test_linearity_of_istft4(self):
299299

300300

301301
class TestDetectPitchFrequency(common_utils.TorchaudioTestCase):
302-
backend = 'default'
303-
304302
def test_pitch(self):
305303
test_filepath_100 = common_utils.get_asset_path("100Hz_44100Hz_16bit_05sec.wav")
306304
test_filepath_440 = common_utils.get_asset_path("440Hz_44100Hz_16bit_05sec.wav")
@@ -312,7 +310,7 @@ def test_pitch(self):
312310
]
313311

314312
for filename, freq_ref in tests:
315-
waveform, sample_rate = torchaudio.load(filename)
313+
waveform, sample_rate = common_utils.load_wav(filename)
316314

317315
freq = torchaudio.functional.detect_pitch_frequency(waveform, sample_rate)
318316

test/kaldi_compatibility_impl.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,16 @@
55
import torch
66
import torchaudio.functional as F
77
import torchaudio.compliance.kaldi
8-
9-
from . import common_utils
10-
from .common_utils import load_params
118
from parameterized import parameterized
129

10+
from .common_utils import (
11+
TestBaseMixin,
12+
load_params,
13+
skipIfNoExec,
14+
get_asset_path,
15+
load_wav
16+
)
17+
1318

1419
def _convert_args(**kwargs):
1520
args = []
@@ -43,14 +48,12 @@ def _run_kaldi(command, input_type, input_value):
4348
return torch.from_numpy(result.copy()) # copy supresses some torch warning
4449

4550

46-
class Kaldi(common_utils.TestBaseMixin):
47-
backend = 'sox'
48-
51+
class Kaldi(TestBaseMixin):
4952
def assert_equal(self, output, *, expected, rtol=None, atol=None):
5053
expected = expected.to(dtype=self.dtype, device=self.device)
5154
self.assertEqual(output, expected, rtol=rtol, atol=atol)
5255

53-
@common_utils.skipIfNoExec('apply-cmvn-sliding')
56+
@skipIfNoExec('apply-cmvn-sliding')
5457
def test_sliding_window_cmn(self):
5558
"""sliding_window_cmn should be numerically compatible with apply-cmvn-sliding"""
5659
kwargs = {
@@ -67,33 +70,33 @@ def test_sliding_window_cmn(self):
6770
self.assert_equal(result, expected=kaldi_result)
6871

6972
@parameterized.expand(load_params('kaldi_test_fbank_args.json'))
70-
@common_utils.skipIfNoExec('compute-fbank-feats')
73+
@skipIfNoExec('compute-fbank-feats')
7174
def test_fbank(self, kwargs):
7275
"""fbank should be numerically compatible with compute-fbank-feats"""
73-
wave_file = common_utils.get_asset_path('kaldi_file.wav')
74-
waveform = torchaudio.load_wav(wave_file)[0].to(dtype=self.dtype, device=self.device)
76+
wave_file = get_asset_path('kaldi_file.wav')
77+
waveform = load_wav(wave_file, normalize=False)[0].to(dtype=self.dtype, device=self.device)
7578
result = torchaudio.compliance.kaldi.fbank(waveform, **kwargs)
7679
command = ['compute-fbank-feats'] + _convert_args(**kwargs) + ['scp:-', 'ark:-']
7780
kaldi_result = _run_kaldi(command, 'scp', wave_file)
7881
self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8)
7982

8083
@parameterized.expand(load_params('kaldi_test_spectrogram_args.json'))
81-
@common_utils.skipIfNoExec('compute-spectrogram-feats')
84+
@skipIfNoExec('compute-spectrogram-feats')
8285
def test_spectrogram(self, kwargs):
8386
"""spectrogram should be numerically compatible with compute-spectrogram-feats"""
84-
wave_file = common_utils.get_asset_path('kaldi_file.wav')
85-
waveform = torchaudio.load_wav(wave_file)[0].to(dtype=self.dtype, device=self.device)
87+
wave_file = get_asset_path('kaldi_file.wav')
88+
waveform = load_wav(wave_file, normalize=False)[0].to(dtype=self.dtype, device=self.device)
8689
result = torchaudio.compliance.kaldi.spectrogram(waveform, **kwargs)
8790
command = ['compute-spectrogram-feats'] + _convert_args(**kwargs) + ['scp:-', 'ark:-']
8891
kaldi_result = _run_kaldi(command, 'scp', wave_file)
8992
self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8)
9093

9194
@parameterized.expand(load_params('kaldi_test_mfcc_args.json'))
92-
@common_utils.skipIfNoExec('compute-mfcc-feats')
95+
@skipIfNoExec('compute-mfcc-feats')
9396
def test_mfcc(self, kwargs):
9497
"""mfcc should be numerically compatible with compute-mfcc-feats"""
95-
wave_file = common_utils.get_asset_path('kaldi_file.wav')
96-
waveform = torchaudio.load_wav(wave_file)[0].to(dtype=self.dtype, device=self.device)
98+
wave_file = get_asset_path('kaldi_file.wav')
99+
waveform = load_wav(wave_file, normalize=False)[0].to(dtype=self.dtype, device=self.device)
97100
result = torchaudio.compliance.kaldi.mfcc(waveform, **kwargs)
98101
command = ['compute-mfcc-feats'] + _convert_args(**kwargs) + ['scp:-', 'ark:-']
99102
kaldi_result = _run_kaldi(command, 'scp', wave_file)

test/test_librosa_compatibility.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,8 @@ class TestTransforms(common_utils.TorchaudioTestCase):
160160
"""Test suite for functions in `transforms` module."""
161161
def assert_compatibilities(self, n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
162162
common_utils.set_audio_backend('default')
163-
sound, sample_rate = _load_audio_asset('sinewave.wav')
163+
path = common_utils.get_asset_path('sinewave.wav')
164+
sound, sample_rate = common_utils.load_wav(path)
164165
sound_librosa = sound.cpu().numpy().squeeze() # (64000)
165166

166167
# test core spectrogram
@@ -300,9 +301,9 @@ def test_InverseMelScale(self):
300301
hop_length = n_fft // 4
301302

302303
# Prepare mel spectrogram input. We use torchaudio to compute one.
303-
common_utils.set_audio_backend('default')
304-
sound, sample_rate = _load_audio_asset(
305-
'steam-train-whistle-daniel_simon.wav', offset=2**10, num_frames=2**14)
304+
path = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
305+
sound, sample_rate = common_utils.load_wav(path)
306+
sound = sound[:, 2**10:2**10 + 2**14]
306307
sound = sound.mean(dim=0, keepdim=True)
307308
spec_orig = F.spectrogram(
308309
sound, pad=0, window=torch.hann_window(n_fft), n_fft=n_fft,

test/test_transforms.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def test_mu_law_companding(self):
4545

4646
def test_AmplitudeToDB(self):
4747
filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
48-
waveform, sample_rate = torchaudio.load(filepath)
48+
waveform = common_utils.load_wav(filepath)[0]
4949

5050
mag_to_db_transform = transforms.AmplitudeToDB('magnitude', 80.)
5151
power_to_db_transform = transforms.AmplitudeToDB('power', 80.)
@@ -115,7 +115,7 @@ def test_mel2(self):
115115
self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.).all())
116116
# check on multi-channel audio
117117
filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
118-
x_stereo, sr_stereo = torchaudio.load(filepath) # (2, 278756), 44100
118+
x_stereo = common_utils.load_wav(filepath)[0] # (2, 278756), 44100
119119
spectrogram_stereo = s2db(mel_transform(x_stereo)) # (2, 128, 1394)
120120
self.assertTrue(spectrogram_stereo.dim() == 3)
121121
self.assertTrue(spectrogram_stereo.size(0) == 2)
@@ -166,7 +166,7 @@ def test_mfcc(self):
166166

167167
def test_resample_size(self):
168168
input_path = common_utils.get_asset_path('sinewave.wav')
169-
waveform, sample_rate = torchaudio.load(input_path)
169+
waveform, sample_rate = common_utils.load_wav(input_path)
170170

171171
upsample_rate = sample_rate * 2
172172
downsample_rate = sample_rate // 2

test/torchscript_consistency_impl.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import unittest
33

44
import torch
5-
import torchaudio
65
import torchaudio.functional as F
76
import torchaudio.transforms as T
87

@@ -616,6 +615,5 @@ def test_SlidingWindowCmn(self):
616615

617616
def test_Vad(self):
618617
filepath = common_utils.get_asset_path("vad-go-mono-32000.wav")
619-
common_utils.set_audio_backend('default')
620-
waveform, sample_rate = torchaudio.load(filepath)
618+
waveform, sample_rate = common_utils.load_wav(filepath)
621619
self._assert_consistency(T.Vad(sample_rate=sample_rate), waveform)

0 commit comments

Comments
 (0)