Skip to content

Commit 102174e

Browse files
authored
Generate YESNO dataset on-the-fly for test (#792)
1 parent 02b898f commit 102174e

File tree

4 files changed

+58
-17
lines changed

4 files changed

+58
-17
lines changed
-84 Bytes
Binary file not shown.

test/common_utils/backend_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ def supports_mp3(backend):
2929
def set_audio_backend(backend):
3030
"""Allow additional backend value, 'default'"""
3131
if backend == 'default':
32-
if 'sox' in BACKENDS:
33-
be = 'sox'
32+
if 'sox_io' in BACKENDS:
33+
be = 'sox_io'
3434
elif 'soundfile' in BACKENDS:
3535
be = 'soundfile'
3636
else:

test/common_utils/test_case_utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,16 @@ class TempDirMixin:
1515
"""Mixin to provide easy access to temp dir"""
1616
temp_dir_ = None
1717

18-
@property
19-
def base_temp_dir(self):
18+
@classmethod
19+
def get_base_temp_dir(cls):
2020
# If TORCHAUDIO_TEST_TEMP_DIR is set, use it instead of temporary directory.
2121
# this is handy for debugging.
2222
key = 'TORCHAUDIO_TEST_TEMP_DIR'
2323
if key in os.environ:
2424
return os.environ[key]
25-
if self.__class__.temp_dir_ is None:
26-
self.__class__.temp_dir_ = tempfile.TemporaryDirectory()
27-
return self.__class__.temp_dir_.name
25+
if cls.temp_dir_ is None:
26+
cls.temp_dir_ = tempfile.TemporaryDirectory()
27+
return cls.temp_dir_.name
2828

2929
@classmethod
3030
def tearDownClass(cls):
@@ -34,7 +34,7 @@ def tearDownClass(cls):
3434
cls.temp_dir_ = None
3535

3636
def get_temp_path(self, *paths):
37-
temp_dir = os.path.join(self.base_temp_dir, self.id())
37+
temp_dir = os.path.join(self.get_base_temp_dir(), self.id())
3838
path = os.path.join(temp_dir, *paths)
3939
os.makedirs(os.path.dirname(path), exist_ok=True)
4040
return path

test/test_datasets.py

Lines changed: 50 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
import unittest
23

34
from torchaudio.datasets.commonvoice import COMMONVOICE
@@ -10,16 +11,19 @@
1011
from torchaudio.datasets.gtzan import GTZAN
1112
from torchaudio.datasets.cmuarctic import CMUARCTIC
1213

13-
from . import common_utils
14+
from .common_utils import (
15+
TempDirMixin,
16+
TorchaudioTestCase,
17+
get_asset_path,
18+
get_whitenoise,
19+
save_wav,
20+
normalize_wav,
21+
)
1422

1523

16-
class TestDatasets(common_utils.TorchaudioTestCase):
24+
class TestDatasets(TorchaudioTestCase):
1725
backend = 'default'
18-
path = common_utils.get_asset_path()
19-
20-
def test_yesno(self):
21-
data = YESNO(self.path)
22-
data[0]
26+
path = get_asset_path()
2327

2428
def test_vctk(self):
2529
data = VCTK(self.path)
@@ -46,9 +50,9 @@ def test_cmuarctic(self):
4650
data[0]
4751

4852

49-
class TestCommonVoice(common_utils.TorchaudioTestCase):
53+
class TestCommonVoice(TorchaudioTestCase):
5054
backend = 'default'
51-
path = common_utils.get_asset_path()
55+
path = get_asset_path()
5256

5357
def test_commonvoice(self):
5458
data = COMMONVOICE(self.path, url="tatar")
@@ -69,5 +73,42 @@ def test_commonvoice_bg(self):
6973
pass
7074

7175

76+
class TestYesNo(TempDirMixin, TorchaudioTestCase):
    """Check the YESNO dataset against wav files generated at test time.

    ``setUpClass`` writes one wav file per entry of ``labels`` into a
    ``waves_yesno`` directory under the shared temp dir. ``test_yesno`` then
    loads that directory through ``YESNO`` and compares every sample with
    what was written.
    """
    backend = 'default'

    root_dir = None
    # Expected waveforms, index-aligned with ``labels``.
    # Rebound to a fresh list in ``setUpClass``; never mutated in place.
    data = []
    # Kept in ascending (lexicographic) order: ``test_yesno`` sorts the loaded
    # samples by label and pairs them with these entries by position.
    labels = [
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 1, 1, 1],
        [0, 1, 0, 1, 0, 1, 1, 0],
        [1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1],
    ]

    @classmethod
    def setUpClass(cls):
        """Generate the on-disk dataset and record the expected waveforms."""
        super().setUpClass()
        cls.root_dir = cls.get_base_temp_dir()
        base_dir = os.path.join(cls.root_dir, 'waves_yesno')
        os.makedirs(base_dir, exist_ok=True)

        # Build into a local list and rebind, so that running the class set-up
        # again (or from a subclass) cannot accumulate stale entries in the
        # list object shared through the class attribute.
        expected = []
        for label in cls.labels:
            # File name encodes the label, e.g. ``0_1_0_1_0_1_1_0.wav``.
            filename = f'{"_".join(str(digit) for digit in label)}.wav'
            path = os.path.join(base_dir, filename)
            # NOTE(review): if get_whitenoise is deterministic for identical
            # arguments, every file holds the same waveform and a mix-up
            # between files would go unnoticed -- confirm, and vary the input
            # per file if the helper allows it.
            data = get_whitenoise(sample_rate=8000, duration=6, n_channels=1, dtype='int16')
            save_wav(path, data, 8000)
            expected.append(normalize_wav(data))
        cls.data = expected

    def test_yesno(self):
        """Every generated file is loaded with the right data, rate and label."""
        dataset = YESNO(self.root_dir)
        samples = sorted(dataset, key=lambda sample: sample[2])

        # Without this check the loop below would pass vacuously when the
        # dataset yields no (or too few) samples.
        self.assertEqual(len(samples), len(self.labels))

        for i, (waveform, sample_rate, label) in enumerate(samples):
            self.assertEqual(self.data[i], waveform, atol=5e-5, rtol=1e-8)
            # unittest assertions instead of bare ``assert``: they survive
            # ``python -O`` and report both values on failure.
            self.assertEqual(sample_rate, 8000)
            self.assertEqual(label, self.labels[i])
111+
112+
72113
if __name__ == "__main__":
73114
unittest.main()

0 commit comments

Comments
 (0)