1+ import os
12import unittest
23
34from torchaudio .datasets .commonvoice import COMMONVOICE
1011from torchaudio .datasets .gtzan import GTZAN
1112from torchaudio .datasets .cmuarctic import CMUARCTIC
1213
13- from . import common_utils
14+ from .common_utils import (
15+ TempDirMixin ,
16+ TorchaudioTestCase ,
17+ get_asset_path ,
18+ get_whitenoise ,
19+ save_wav ,
20+ normalize_wav ,
21+ )
1422
1523
16- class TestDatasets (common_utils . TorchaudioTestCase ):
24+ class TestDatasets (TorchaudioTestCase ):
1725 backend = 'default'
18- path = common_utils .get_asset_path ()
19-
20- def test_yesno (self ):
21- data = YESNO (self .path )
22- data [0 ]
26+ path = get_asset_path ()
2327
2428 def test_vctk (self ):
2529 data = VCTK (self .path )
@@ -46,9 +50,9 @@ def test_cmuarctic(self):
4650 data [0 ]
4751
4852
49- class TestCommonVoice (common_utils . TorchaudioTestCase ):
53+ class TestCommonVoice (TorchaudioTestCase ):
5054 backend = 'default'
51- path = common_utils . get_asset_path ()
55+ path = get_asset_path ()
5256
5357 def test_commonvoice (self ):
5458 data = COMMONVOICE (self .path , url = "tatar" )
@@ -69,5 +73,42 @@ def test_commonvoice_bg(self):
6973 pass
7074
7175
76+ class TestYesNo (TempDirMixin , TorchaudioTestCase ):
77+ backend = 'default'
78+
79+ root_dir = None
80+ data = []
81+ labels = [
82+ [0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ],
83+ [0 , 0 , 0 , 0 , 1 , 1 , 1 , 1 ],
84+ [0 , 1 , 0 , 1 , 0 , 1 , 1 , 0 ],
85+ [1 , 1 , 1 , 1 , 0 , 0 , 0 , 0 ],
86+ [1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ],
87+ ]
88+
89+ @classmethod
90+ def setUpClass (cls ):
91+ cls .root_dir = cls .get_base_temp_dir ()
92+ base_dir = os .path .join (cls .root_dir , 'waves_yesno' )
93+ os .makedirs (base_dir , exist_ok = True )
94+ for label in cls .labels :
95+ filename = f'{ "_" .join (str (l ) for l in label )} .wav'
96+ path = os .path .join (base_dir , filename )
97+ data = get_whitenoise (sample_rate = 8000 , duration = 6 , n_channels = 1 , dtype = 'int16' )
98+ save_wav (path , data , 8000 )
99+ cls .data .append (normalize_wav (data ))
100+
101+ def test_yesno (self ):
102+ dataset = YESNO (self .root_dir )
103+ samples = list (dataset )
104+ samples .sort (key = lambda s : s [2 ])
105+ for i , (waveform , sample_rate , label ) in enumerate (samples ):
106+ expected_label = self .labels [i ]
107+ expected_data = self .data [i ]
108+ self .assertEqual (expected_data , waveform , atol = 5e-5 , rtol = 1e-8 )
109+ assert sample_rate == 8000
110+ assert label == expected_label
111+
112+
72113if __name__ == "__main__" :
73114 unittest .main ()
0 commit comments