Skip to content

Commit 19d8f1c

Browse files
authored
Refactor integration test (#1922)
- Make the test support other languages - Fetch test asset on-the-fly
1 parent 716aa41 commit 19d8f1c

File tree

3 files changed

+34
-17
lines changed

3 files changed

+34
-17
lines changed

test/integration_tests/conftest.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import torch
2-
from torchaudio_unittest.common_utils import get_asset_path
2+
import requests
33
import pytest
44

55

@@ -32,6 +32,22 @@ def ctc_decoder():
3232
return GreedyCTCDecoder
3333

3434

35+
_FILES = {
36+
'en': 'Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.flac',
37+
}
38+
39+
3540
@pytest.fixture
36-
def sample_speech_16000_en():
37-
return get_asset_path('Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.flac')
41+
def sample_speech(tmp_path, lang):
42+
if lang not in _FILES:
43+
raise NotImplementedError(f'Unexpected lang: {lang}')
44+
filename = _FILES[lang]
45+
path = tmp_path.parent / filename
46+
if not path.exists():
47+
url = f'https://download.pytorch.org/torchaudio/test-assets/{filename}'
48+
print(f'downloading from {url}')
49+
with open(path, 'wb') as file:
50+
with requests.get(url) as resp:
51+
resp.raise_for_status()
52+
file.write(resp.content)
53+
return path

test/integration_tests/wav2vec2_pipeline_test.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,30 +40,31 @@ def test_pretraining_models(bundle):
4040

4141

4242
@pytest.mark.parametrize(
43-
"bundle,expected",
43+
"bundle,lang,expected",
4444
[
45-
(WAV2VEC2_ASR_BASE_10M, 'I|HAD|THAT|CURIYOSSITY|BESID|ME|AT|THIS|MOMENT|'),
46-
(WAV2VEC2_ASR_BASE_100H, 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
47-
(WAV2VEC2_ASR_BASE_960H, 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
48-
(WAV2VEC2_ASR_LARGE_10M, 'I|HAD|THAT|CURIOUSITY|BESIDE|ME|AT|THIS|MOMENT|'),
49-
(WAV2VEC2_ASR_LARGE_100H, 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
50-
(WAV2VEC2_ASR_LARGE_960H, 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
51-
(WAV2VEC2_ASR_LARGE_LV60K_10M, 'I|HAD|THAT|CURIOUSSITY|BESID|ME|AT|THISS|MOMENT|'),
52-
(WAV2VEC2_ASR_LARGE_LV60K_100H, 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
53-
(WAV2VEC2_ASR_LARGE_LV60K_960H, 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
54-
(HUBERT_ASR_LARGE, 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
55-
(HUBERT_ASR_XLARGE, 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|')
45+
(WAV2VEC2_ASR_BASE_10M, 'en', 'I|HAD|THAT|CURIYOSSITY|BESID|ME|AT|THIS|MOMENT|'),
46+
(WAV2VEC2_ASR_BASE_100H, 'en', 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
47+
(WAV2VEC2_ASR_BASE_960H, 'en', 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
48+
(WAV2VEC2_ASR_LARGE_10M, 'en', 'I|HAD|THAT|CURIOUSITY|BESIDE|ME|AT|THIS|MOMENT|'),
49+
(WAV2VEC2_ASR_LARGE_100H, 'en', 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
50+
(WAV2VEC2_ASR_LARGE_960H, 'en', 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
51+
(WAV2VEC2_ASR_LARGE_LV60K_10M, 'en', 'I|HAD|THAT|CURIOUSSITY|BESID|ME|AT|THISS|MOMENT|'),
52+
(WAV2VEC2_ASR_LARGE_LV60K_100H, 'en', 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
53+
(WAV2VEC2_ASR_LARGE_LV60K_960H, 'en', 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
54+
(HUBERT_ASR_LARGE, 'en', 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
55+
(HUBERT_ASR_XLARGE, 'en', 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'),
5656
]
5757
)
5858
def test_finetune_asr_model(
5959
bundle,
60+
lang,
6061
expected,
61-
sample_speech_16000_en,
62+
sample_speech,
6263
ctc_decoder,
6364
):
6465
"""Smoke test of downloading weights for fine-tuning models and simple transcription"""
6566
model = bundle.get_model().eval()
66-
waveform, sample_rate = torchaudio.load(sample_speech_16000_en)
67+
waveform, sample_rate = torchaudio.load(sample_speech)
6768
emission, _ = model(waveform)
6869
decoder = ctc_decoder(bundle.get_labels())
6970
result = decoder(emission[0])
Binary file not shown.

0 commit comments

Comments
 (0)