@@ -53,47 +53,60 @@ class TestSpeechCommands(TempDirMixin, TorchaudioTestCase):
5353
5454 root_dir = None
5555 samples = []
56+ train_samples = []
57+ valid_samples = []
58+ test_samples = []
5659
5760 @classmethod
58- def setUp (cls ):
61+ def setUpClass (cls ):
5962 cls .root_dir = cls .get_base_temp_dir ()
6063 dataset_dir = os .path .join (
6164 cls .root_dir , speechcommands .FOLDER_IN_ARCHIVE , speechcommands .URL
6265 )
6366 os .makedirs (dataset_dir , exist_ok = True )
6467 sample_rate = 16000 # 16kHz sample rate
6568 seed = 0
66- for label in LABELS :
67- path = os .path .join (dataset_dir , label )
68- os .makedirs (path , exist_ok = True )
69- for j in range (2 ):
70- # generate hash ID for speaker
71- speaker = "{:08x}" .format (j )
72-
73- for utterance in range (3 ):
74- filename = f"{ speaker } { speechcommands .HASH_DIVIDER } { utterance } .wav"
75- file_path = os .path .join (path , filename )
76- seed += 1
77- data = get_whitenoise (
78- sample_rate = sample_rate ,
79- duration = 0.01 ,
80- n_channels = 1 ,
81- dtype = "int16" ,
82- seed = seed ,
83- )
84- save_wav (file_path , data , sample_rate )
85- sample = (
86- normalize_wav (data ),
87- sample_rate ,
88- label ,
89- speaker ,
90- utterance ,
91- )
92- cls .samples .append (sample )
69+ valid_file = os .path .join (dataset_dir , "validation_list.txt" )
70+ test_file = os .path .join (dataset_dir , "testing_list.txt" )
71+ with open (valid_file , "w" ) as valid , open (test_file , "w" ) as test :
72+ for label in LABELS :
73+ path = os .path .join (dataset_dir , label )
74+ os .makedirs (path , exist_ok = True )
75+ for j in range (6 ):
76+ # generate hash ID for speaker
77+ speaker = "{:08x}" .format (j )
78+
79+ for utterance in range (3 ):
80+ filename = f"{ speaker } { speechcommands .HASH_DIVIDER } { utterance } .wav"
81+ file_path = os .path .join (path , filename )
82+ seed += 1
83+ data = get_whitenoise (
84+ sample_rate = sample_rate ,
85+ duration = 0.01 ,
86+ n_channels = 1 ,
87+ dtype = "int16" ,
88+ seed = seed ,
89+ )
90+ save_wav (file_path , data , sample_rate )
91+ sample = (
92+ normalize_wav (data ),
93+ sample_rate ,
94+ label ,
95+ speaker ,
96+ utterance ,
97+ )
98+ cls .samples .append (sample )
99+ if j < 2 :
100+ cls .train_samples .append (sample )
101+ elif j < 4 :
102+ valid .write (f'{ label } /{ filename } \n ' )
103+ cls .valid_samples .append (sample )
104+ elif j < 6 :
105+ test .write (f'{ label } /{ filename } \n ' )
106+ cls .test_samples .append (sample )
93107
94108 def testSpeechCommands (self ):
95109 dataset = speechcommands .SPEECHCOMMANDS (self .root_dir )
96- print (dataset ._path )
97110
98111 num_samples = 0
99112 for i , (data , sample_rate , label , speaker_id , utterance_number ) in enumerate (
@@ -107,3 +120,75 @@ def testSpeechCommands(self):
107120 num_samples += 1
108121
109122 assert num_samples == len (self .samples )
123+
124+ def testSpeechCommandsNone (self ):
125+ dataset = speechcommands .SPEECHCOMMANDS (self .root_dir , subset = None )
126+
127+ num_samples = 0
128+ for i , (data , sample_rate , label , speaker_id , utterance_number ) in enumerate (
129+ dataset
130+ ):
131+ self .assertEqual (data , self .samples [i ][0 ], atol = 5e-5 , rtol = 1e-8 )
132+ assert sample_rate == self .samples [i ][1 ]
133+ assert label == self .samples [i ][2 ]
134+ assert speaker_id == self .samples [i ][3 ]
135+ assert utterance_number == self .samples [i ][4 ]
136+ num_samples += 1
137+
138+ assert num_samples == len (self .samples )
139+
140+ def testSpeechCommandsSubsetTrain (self ):
141+ dataset = speechcommands .SPEECHCOMMANDS (self .root_dir , subset = "training" )
142+
143+ num_samples = 0
144+ for i , (data , sample_rate , label , speaker_id , utterance_number ) in enumerate (
145+ dataset
146+ ):
147+ self .assertEqual (data , self .train_samples [i ][0 ], atol = 5e-5 , rtol = 1e-8 )
148+ assert sample_rate == self .train_samples [i ][1 ]
149+ assert label == self .train_samples [i ][2 ]
150+ assert speaker_id == self .train_samples [i ][3 ]
151+ assert utterance_number == self .train_samples [i ][4 ]
152+ num_samples += 1
153+
154+ assert num_samples == len (self .train_samples )
155+
156+ def testSpeechCommandsSubsetValid (self ):
157+ dataset = speechcommands .SPEECHCOMMANDS (self .root_dir , subset = "validation" )
158+
159+ num_samples = 0
160+ for i , (data , sample_rate , label , speaker_id , utterance_number ) in enumerate (
161+ dataset
162+ ):
163+ self .assertEqual (data , self .valid_samples [i ][0 ], atol = 5e-5 , rtol = 1e-8 )
164+ assert sample_rate == self .valid_samples [i ][1 ]
165+ assert label == self .valid_samples [i ][2 ]
166+ assert speaker_id == self .valid_samples [i ][3 ]
167+ assert utterance_number == self .valid_samples [i ][4 ]
168+ num_samples += 1
169+
170+ assert num_samples == len (self .valid_samples )
171+
172+ def testSpeechCommandsSubsetTest (self ):
173+ dataset = speechcommands .SPEECHCOMMANDS (self .root_dir , subset = "testing" )
174+
175+ num_samples = 0
176+ for i , (data , sample_rate , label , speaker_id , utterance_number ) in enumerate (
177+ dataset
178+ ):
179+ self .assertEqual (data , self .test_samples [i ][0 ], atol = 5e-5 , rtol = 1e-8 )
180+ assert sample_rate == self .test_samples [i ][1 ]
181+ assert label == self .test_samples [i ][2 ]
182+ assert speaker_id == self .test_samples [i ][3 ]
183+ assert utterance_number == self .test_samples [i ][4 ]
184+ num_samples += 1
185+
186+ assert num_samples == len (self .test_samples )
187+
188+ def testSpeechCommandsSum (self ):
189+ dataset_all = speechcommands .SPEECHCOMMANDS (self .root_dir )
190+ dataset_train = speechcommands .SPEECHCOMMANDS (self .root_dir , subset = "training" )
191+ dataset_valid = speechcommands .SPEECHCOMMANDS (self .root_dir , subset = "validation" )
192+ dataset_test = speechcommands .SPEECHCOMMANDS (self .root_dir , subset = "testing" )
193+
194+ assert len (dataset_train ) + len (dataset_valid ) + len (dataset_test ) == len (dataset_all )
0 commit comments