Skip to content

Commit b34bc7d

Browse files
authored
Add SpeechCommands train/valid/test split (#966)
1 parent 51e7796 commit b34bc7d

File tree

2 files changed

+153
-34
lines changed

2 files changed

+153
-34
lines changed

test/torchaudio_unittest/datasets/speechcommands_test.py

Lines changed: 114 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -53,47 +53,60 @@ class TestSpeechCommands(TempDirMixin, TorchaudioTestCase):
5353

5454
root_dir = None
5555
samples = []
56+
train_samples = []
57+
valid_samples = []
58+
test_samples = []
5659

5760
@classmethod
def setUpClass(cls):
    """Build a fake SpeechCommands layout in a temp dir, once per class.

    For each label, 6 speakers x 3 utterances of short white noise are written.
    Speakers 0-1 form the training split, 2-3 the validation split, and 4-5 the
    testing split; validation/testing file paths are recorded in
    ``validation_list.txt`` and ``testing_list.txt`` exactly the way the real
    dataset archive ships them (``<label>/<filename>`` relative paths).
    """
    cls.root_dir = cls.get_base_temp_dir()
    dataset_dir = os.path.join(
        cls.root_dir, speechcommands.FOLDER_IN_ARCHIVE, speechcommands.URL
    )
    os.makedirs(dataset_dir, exist_ok=True)
    sample_rate = 16000  # 16kHz sample rate
    seed = 0

    valid_file = os.path.join(dataset_dir, "validation_list.txt")
    test_file = os.path.join(dataset_dir, "testing_list.txt")
    with open(valid_file, "w") as valid, open(test_file, "w") as test:
        for label in LABELS:
            path = os.path.join(dataset_dir, label)
            os.makedirs(path, exist_ok=True)
            for j in range(6):
                # generate hash ID for speaker
                speaker = "{:08x}".format(j)

                for utterance in range(3):
                    filename = f"{speaker}{speechcommands.HASH_DIVIDER}{utterance}.wav"
                    file_path = os.path.join(path, filename)
                    # distinct seed per file so every clip's content is unique
                    seed += 1
                    data = get_whitenoise(
                        sample_rate=sample_rate,
                        duration=0.01,
                        n_channels=1,
                        dtype="int16",
                        seed=seed,
                    )
                    save_wav(file_path, data, sample_rate)
                    sample = (
                        normalize_wav(data),
                        sample_rate,
                        label,
                        speaker,
                        utterance,
                    )
                    cls.samples.append(sample)
                    if j < 2:
                        cls.train_samples.append(sample)
                    elif j < 4:
                        # FIX: write the actual relative file path; the source
                        # had the placeholder "(unknown)" here, which would make
                        # the list reference files that do not exist.
                        valid.write(f'{label}/{filename}\n')
                        cls.valid_samples.append(sample)
                    elif j < 6:
                        test.write(f'{label}/{filename}\n')
                        cls.test_samples.append(sample)
93107

94108
def testSpeechCommands(self):
95109
dataset = speechcommands.SPEECHCOMMANDS(self.root_dir)
96-
print(dataset._path)
97110

98111
num_samples = 0
99112
for i, (data, sample_rate, label, speaker_id, utterance_number) in enumerate(
@@ -107,3 +120,75 @@ def testSpeechCommands(self):
107120
num_samples += 1
108121

109122
assert num_samples == len(self.samples)
123+
124+
def testSpeechCommandsNone(self):
    """subset=None must yield every generated sample, in generation order."""
    dataset = speechcommands.SPEECHCOMMANDS(self.root_dir, subset=None)

    num_samples = 0
    for i, (data, sample_rate, label, speaker_id, utterance_number) in enumerate(dataset):
        exp_data, exp_rate, exp_label, exp_speaker, exp_utterance = self.samples[i]
        self.assertEqual(data, exp_data, atol=5e-5, rtol=1e-8)
        assert sample_rate == exp_rate
        assert label == exp_label
        assert speaker_id == exp_speaker
        assert utterance_number == exp_utterance
        num_samples += 1

    # No extra or missing items beyond what setUpClass produced.
    assert num_samples == len(self.samples)
139+
140+
def testSpeechCommandsSubsetTrain(self):
    """subset="training" must yield exactly the samples not listed in the split files."""
    dataset = speechcommands.SPEECHCOMMANDS(self.root_dir, subset="training")

    num_samples = 0
    for i, (data, sample_rate, label, speaker_id, utterance_number) in enumerate(dataset):
        exp_data, exp_rate, exp_label, exp_speaker, exp_utterance = self.train_samples[i]
        self.assertEqual(data, exp_data, atol=5e-5, rtol=1e-8)
        assert sample_rate == exp_rate
        assert label == exp_label
        assert speaker_id == exp_speaker
        assert utterance_number == exp_utterance
        num_samples += 1

    assert num_samples == len(self.train_samples)
155+
156+
def testSpeechCommandsSubsetValid(self):
    """subset="validation" must yield exactly the samples in validation_list.txt."""
    dataset = speechcommands.SPEECHCOMMANDS(self.root_dir, subset="validation")

    num_samples = 0
    for i, (data, sample_rate, label, speaker_id, utterance_number) in enumerate(dataset):
        exp_data, exp_rate, exp_label, exp_speaker, exp_utterance = self.valid_samples[i]
        self.assertEqual(data, exp_data, atol=5e-5, rtol=1e-8)
        assert sample_rate == exp_rate
        assert label == exp_label
        assert speaker_id == exp_speaker
        assert utterance_number == exp_utterance
        num_samples += 1

    assert num_samples == len(self.valid_samples)
171+
172+
def testSpeechCommandsSubsetTest(self):
    """subset="testing" must yield exactly the samples in testing_list.txt."""
    dataset = speechcommands.SPEECHCOMMANDS(self.root_dir, subset="testing")

    num_samples = 0
    for i, (data, sample_rate, label, speaker_id, utterance_number) in enumerate(dataset):
        exp_data, exp_rate, exp_label, exp_speaker, exp_utterance = self.test_samples[i]
        self.assertEqual(data, exp_data, atol=5e-5, rtol=1e-8)
        assert sample_rate == exp_rate
        assert label == exp_label
        assert speaker_id == exp_speaker
        assert utterance_number == exp_utterance
        num_samples += 1

    assert num_samples == len(self.test_samples)
187+
188+
def testSpeechCommandsSum(self):
    """The three subsets must partition the full dataset (lengths add up)."""
    dataset_all = speechcommands.SPEECHCOMMANDS(self.root_dir)
    subsets = [
        speechcommands.SPEECHCOMMANDS(self.root_dir, subset=name)
        for name in ("training", "validation", "testing")
    ]

    assert sum(len(subset) for subset in subsets) == len(dataset_all)

torchaudio/datasets/speechcommands.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from typing import Tuple
2+
from typing import Tuple, Optional
33

44
import torchaudio
55
from torch.utils.data import Dataset
@@ -22,6 +22,15 @@
2222
}
2323

2424

25+
def _load_list(root, *filenames):
26+
output = []
27+
for filename in filenames:
28+
filepath = os.path.join(root, filename)
29+
with open(filepath) as fileobj:
30+
output += [os.path.normpath(os.path.join(root, line.strip())) for line in fileobj]
31+
return output
32+
33+
2534
def load_speechcommands_item(filepath: str, path: str) -> Tuple[Tensor, int, str, str, int]:
2635
relpath = os.path.relpath(filepath, path)
2736
label, filename = os.path.split(relpath)
@@ -48,13 +57,25 @@ class SPEECHCOMMANDS(Dataset):
4857
The top-level directory of the dataset. (default: ``"SpeechCommands"``)
4958
download (bool, optional):
5059
Whether to download the dataset if it is not found at root path. (default: ``False``).
60+
subset (Optional[str]):
61+
Select a subset of the dataset [None, "training", "validation", "testing"]. None means
62+
the whole dataset. "validation" and "testing" are defined in "validation_list.txt" and
63+
"testing_list.txt", respectively, and "training" is the rest. (default: ``None``)
5164
"""
5265

5366
def __init__(self,
5467
root: str,
5568
url: str = URL,
5669
folder_in_archive: str = FOLDER_IN_ARCHIVE,
57-
download: bool = False) -> None:
70+
download: bool = False,
71+
subset: Optional[str] = None,
72+
) -> None:
73+
74+
assert subset is None or subset in ["training", "validation", "testing"], (
75+
"When `subset` not None, it must take a value from "
76+
+ "{'training', 'validation', 'testing'}."
77+
)
78+
5879
if url in [
5980
"speech_commands_v0.01",
6081
"speech_commands_v0.02",
@@ -79,9 +100,22 @@ def __init__(self,
79100
download_url(url, root, hash_value=checksum, hash_type="md5")
80101
extract_archive(archive, self._path)
81102

82-
walker = walk_files(self._path, suffix=".wav", prefix=True)
83-
walker = filter(lambda w: HASH_DIVIDER in w and EXCEPT_FOLDER not in w, walker)
84-
self._walker = list(walker)
103+
if subset == "validation":
104+
self._walker = _load_list(self._path, "validation_list.txt")
105+
elif subset == "testing":
106+
self._walker = _load_list(self._path, "testing_list.txt")
107+
elif subset == "training":
108+
excludes = set(_load_list(self._path, "validation_list.txt", "testing_list.txt"))
109+
walker = walk_files(self._path, suffix=".wav", prefix=True)
110+
self._walker = [
111+
w for w in walker
112+
if HASH_DIVIDER in w
113+
and EXCEPT_FOLDER not in w
114+
and os.path.normpath(w) not in excludes
115+
]
116+
else:
117+
walker = walk_files(self._path, suffix=".wav", prefix=True)
118+
self._walker = [w for w in walker if HASH_DIVIDER in w and EXCEPT_FOLDER not in w]
85119

86120
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, int]:
87121
"""Load the n-th sample from the dataset.

0 commit comments

Comments
 (0)